diff --git a/include/ndarray/pybind11.h b/include/ndarray/pybind11.h index 22e59720..8d25ae14 100644 --- a/include/ndarray/pybind11.h +++ b/include/ndarray/pybind11.h @@ -59,6 +59,8 @@ template class type_caster< ndarray::Array > { public: bool load(handle src, bool) { + _none = src.is_none(); + if (_none) return true; _src.reset(src.ptr(), true); // keep alive for stage 2 if (!ndarray::PyConverter< ndarray::Array >::fromPythonStage1(_src)) { PyErr_Clear(); @@ -80,12 +82,20 @@ class type_caster< ndarray::Array > { protected: ndarray::PyPtr _src; ndarray::Array _value; + bool _none = false; public: static PYBIND11_DESCR name() { return type_descr(_>()); } static handle cast(const ndarray::Array *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } - operator ndarray::Array * () { set_value(); return &_value; } + operator ndarray::Array * () { + if (_none) { + return nullptr; + } else { + set_value(); + return &_value; + } + } operator ndarray::Array & () { set_value(); return _value; } template using cast_op_type = pybind11::detail::cast_op_type<_T>; }; @@ -96,6 +106,8 @@ template class type_caster< ndarray::EigenView > { public: bool load(handle src, bool) { + _none = src.is_none(); + if (_none) return true; _src.reset(src.ptr(), true); // keep alive for stage 2 if (!ndarray::PyConverter< ndarray::EigenView >::fromPythonStage1(_src)) { PyErr_Clear(); @@ -117,12 +129,20 @@ class type_caster< ndarray::EigenView > { protected: ndarray::PyPtr _src; ndarray::EigenView _value; + bool _none = false; public: static PYBIND11_DESCR name() { return type_descr(_>()); } static handle cast(const ndarray::EigenView *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } - operator ndarray::EigenView * () { set_value(); return &_value; } + operator ndarray::EigenView * () { + if (_none) { + return nullptr; + } else { + set_value(); + return &_value; + } + } operator ndarray::EigenView & () { set_value(); return _value; } template using cast_op_type = pybind11::detail::cast_op_type<_T>; }; @@ -133,6 +153,8 @@ template class type_caster< Eigen::Array > { public: bool load(handle src, bool) { + _none = src.is_none(); + if (_none) return true; _src.reset(src.ptr(), true); // keep alive for stage 2 if (!ndarray::PyConverter< Eigen::Array >::fromPythonStage1(_src)) { PyErr_Clear(); @@ -154,12 +176,20 @@ class type_caster< Eigen::Array > { protected: ndarray::PyPtr _src; Eigen::Array _value; + bool _none = false; public: static PYBIND11_DESCR name() { return type_descr(_>()); } static handle cast(const Eigen::Array *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } - operator Eigen::Array * () { set_value(); return &_value; } + operator Eigen::Array * () { + if (_none) { + return nullptr; + } else { + set_value(); + return &_value; + } + } operator Eigen::Array & () { set_value(); return _value; } template using cast_op_type = pybind11::detail::cast_op_type<_T>; }; @@ -170,6 +200,8 @@ template class type_caster< Eigen::Matrix > { public: bool load(handle src, bool) { + _none = src.is_none(); + if (_none) return true; _src.reset(src.ptr(), true); // keep alive for stage 2 if (!ndarray::PyConverter< Eigen::Matrix >::fromPythonStage1(_src)) { PyErr_Clear(); @@ -191,12 +223,20 @@ class type_caster< Eigen::Matrix > { protected: ndarray::PyPtr _src; Eigen::Matrix _value; + bool _none = false; public: static PYBIND11_DESCR name() { return type_descr(_>()); } static handle cast(const Eigen::Matrix *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } - operator Eigen::Matrix * () { set_value(); return &_value; } + operator Eigen::Matrix * () { + if (_none) { + return nullptr; + } else { + set_value(); + return &_value; + } + } operator Eigen::Matrix & () { set_value(); return _value; } template using cast_op_type = pybind11::detail::cast_op_type<_T>; }; diff --git a/tests/pybind11_test.py b/tests/pybind11_test.py index 8855acb3..82e65b5f 100755 --- a/tests/pybind11_test.py +++ b/tests/pybind11_test.py @@ -73,6 +73,20 @@ def testStrideHandling(self): table = numpy.zeros(3, dtype=dtype) self.assertRaises(TypeError, pybind11_test_mod.acceptArray10, table['f1']) + def testNone(self): + array = numpy.zeros(10, dtype=float) + self.assertEqual(pybind11_test_mod.acceptNoneArray(array), 0) + self.assertEqual(pybind11_test_mod.acceptNoneArray(None), 1) + self.assertEqual(pybind11_test_mod.acceptNoneArray(), 1) + m1 = pybind11_test_mod.returnMatrixXd() + self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(m1), 2) + self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(None), 3) + self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(), 3) + + m2 = pybind11_test_mod.returnMatrix2d() + self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(m2), 4) + self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(None), 5) + self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(), 5) if __name__ == "__main__": unittest.main() diff --git a/tests/pybind11_test_mod.cc b/tests/pybind11_test_mod.cc index 37703a22..22118c2e 100644 --- a/tests/pybind11_test_mod.cc +++ b/tests/pybind11_test_mod.cc @@ -5,6 +5,7 @@ #include "ndarray/converter.h" namespace py = pybind11; +using namespace py::literals; Eigen::MatrixXd returnMatrixXd() { Eigen::MatrixXd r(5, 3); @@ -94,6 +95,30 @@ int acceptOverload(Eigen::Matrix2d const & m) { return 2; } +int acceptNoneArray(ndarray::Array const * array = nullptr) { + if (array) { + return 0; + } else { + return 1; + } +} + +int acceptNoneMatrixXd(Eigen::MatrixXd const * matrix = nullptr) { + if (matrix) { + return 2; + } else { + return 3; + } +} + +int acceptNoneMatrix2d(Eigen::Matrix2d const * matrix = nullptr) { + if (matrix) { + return 4; + } else { + return 5; + } +} + struct MatrixOwner { typedef Eigen::Matrix MemberMatrix; MemberMatrix member; @@ -130,6 +155,9 @@ PYBIND11_PLUGIN(pybind11_test_mod) { mod.def("acceptOverload", (int (*)(int)) acceptOverload); mod.def("acceptOverload", (int (*)(Eigen::Matrix2d const &)) acceptOverload); mod.def("acceptOverload", (int (*)(Eigen::Matrix3d const &)) acceptOverload); + mod.def("acceptNoneArray", acceptNoneArray, "array"_a = nullptr); + mod.def("acceptNoneMatrixXd", acceptNoneMatrixXd, "matrix"_a = nullptr); + mod.def("acceptNoneMatrix2d", acceptNoneMatrix2d, "matrix"_a = nullptr); return mod.ptr(); } \ No newline at end of file