From 0afdaa634e7d51cc40616a904bac5a9b4c7a1834 Mon Sep 17 00:00:00 2001
From: Pim Schellart
Date: Fri, 9 Jun 2017 11:28:27 -0400
Subject: [PATCH] Fix pybind11 type_casters to accept None as nullptr
---
include/ndarray/pybind11.h | 48 ++++++++++++++++++++++++++++++++++----
tests/pybind11_test.py | 14 +++++++++++
tests/pybind11_test_mod.cc | 28 ++++++++++++++++++++++
3 files changed, 86 insertions(+), 4 deletions(-)
diff --git a/include/ndarray/pybind11.h b/include/ndarray/pybind11.h
index 22e59720..8d25ae14 100644
--- a/include/ndarray/pybind11.h
+++ b/include/ndarray/pybind11.h
@@ -59,6 +59,8 @@ template
class type_caster< ndarray::Array > {
public:
bool load(handle src, bool) {
+ _none = src.is_none();
+ if (_none) return true;
_src.reset(src.ptr(), true); // keep alive for stage 2
if (!ndarray::PyConverter< ndarray::Array >::fromPythonStage1(_src)) {
PyErr_Clear();
@@ -80,12 +82,20 @@ class type_caster< ndarray::Array > {
protected:
ndarray::PyPtr _src;
ndarray::Array _value;
+ bool _none = false;
public:
static PYBIND11_DESCR name() { return type_descr(_>()); }
static handle cast(const ndarray::Array *src, return_value_policy policy, handle parent) {
return cast(*src, policy, parent);
}
- operator ndarray::Array * () { set_value(); return &_value; }
+ operator ndarray::Array * () {
+ if (_none) {
+ return nullptr;
+ } else {
+ set_value();
+ return &_value;
+ }
+ }
operator ndarray::Array & () { set_value(); return _value; }
template using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
@@ -96,6 +106,8 @@ template
class type_caster< ndarray::EigenView > {
public:
bool load(handle src, bool) {
+ _none = src.is_none();
+ if (_none) return true;
_src.reset(src.ptr(), true); // keep alive for stage 2
if (!ndarray::PyConverter< ndarray::EigenView >::fromPythonStage1(_src)) {
PyErr_Clear();
@@ -117,12 +129,20 @@ class type_caster< ndarray::EigenView > {
protected:
ndarray::PyPtr _src;
ndarray::EigenView _value;
+ bool _none = false;
public:
static PYBIND11_DESCR name() { return type_descr(_>()); }
static handle cast(const ndarray::EigenView *src, return_value_policy policy, handle parent) {
return cast(*src, policy, parent);
}
- operator ndarray::EigenView * () { set_value(); return &_value; }
+ operator ndarray::EigenView * () {
+ if (_none) {
+ return nullptr;
+ } else {
+ set_value();
+ return &_value;
+ }
+ }
operator ndarray::EigenView & () { set_value(); return _value; }
template using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
@@ -133,6 +153,8 @@ template
class type_caster< Eigen::Array > {
public:
bool load(handle src, bool) {
+ _none = src.is_none();
+ if (_none) return true;
_src.reset(src.ptr(), true); // keep alive for stage 2
if (!ndarray::PyConverter< Eigen::Array >::fromPythonStage1(_src)) {
PyErr_Clear();
@@ -154,12 +176,20 @@ class type_caster< Eigen::Array > {
protected:
ndarray::PyPtr _src;
Eigen::Array _value;
+ bool _none = false;
public:
static PYBIND11_DESCR name() { return type_descr(_>()); }
static handle cast(const Eigen::Array *src, return_value_policy policy, handle parent) {
return cast(*src, policy, parent);
}
- operator Eigen::Array * () { set_value(); return &_value; }
+ operator Eigen::Array * () {
+ if (_none) {
+ return nullptr;
+ } else {
+ set_value();
+ return &_value;
+ }
+ }
operator Eigen::Array & () { set_value(); return _value; }
template using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
@@ -170,6 +200,8 @@ template
class type_caster< Eigen::Matrix > {
public:
bool load(handle src, bool) {
+ _none = src.is_none();
+ if (_none) return true;
_src.reset(src.ptr(), true); // keep alive for stage 2
if (!ndarray::PyConverter< Eigen::Matrix >::fromPythonStage1(_src)) {
PyErr_Clear();
@@ -191,12 +223,20 @@ class type_caster< Eigen::Matrix > {
protected:
ndarray::PyPtr _src;
Eigen::Matrix _value;
+ bool _none = false;
public:
static PYBIND11_DESCR name() { return type_descr(_>()); }
static handle cast(const Eigen::Matrix *src, return_value_policy policy, handle parent) {
return cast(*src, policy, parent);
}
- operator Eigen::Matrix * () { set_value(); return &_value; }
+ operator Eigen::Matrix * () {
+ if (_none) {
+ return nullptr;
+ } else {
+ set_value();
+ return &_value;
+ }
+ }
operator Eigen::Matrix & () { set_value(); return _value; }
template using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
diff --git a/tests/pybind11_test.py b/tests/pybind11_test.py
index 8855acb3..82e65b5f 100755
--- a/tests/pybind11_test.py
+++ b/tests/pybind11_test.py
@@ -73,6 +73,20 @@ def testStrideHandling(self):
table = numpy.zeros(3, dtype=dtype)
self.assertRaises(TypeError, pybind11_test_mod.acceptArray10, table['f1'])
+ def testNone(self):
+ array = numpy.zeros(10, dtype=float)
+ self.assertEqual(pybind11_test_mod.acceptNoneArray(array), 0)
+ self.assertEqual(pybind11_test_mod.acceptNoneArray(None), 1)
+ self.assertEqual(pybind11_test_mod.acceptNoneArray(), 1)
+ m1 = pybind11_test_mod.returnMatrixXd()
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(m1), 2)
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(None), 3)
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrixXd(), 3)
+
+ m2 = pybind11_test_mod.returnMatrix2d()
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(m2), 4)
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(None), 5)
+ self.assertEqual(pybind11_test_mod.acceptNoneMatrix2d(), 5)
if __name__ == "__main__":
unittest.main()
diff --git a/tests/pybind11_test_mod.cc b/tests/pybind11_test_mod.cc
index 37703a22..22118c2e 100644
--- a/tests/pybind11_test_mod.cc
+++ b/tests/pybind11_test_mod.cc
@@ -5,6 +5,7 @@
#include "ndarray/converter.h"
namespace py = pybind11;
+using namespace py::literals;
Eigen::MatrixXd returnMatrixXd() {
Eigen::MatrixXd r(5, 3);
@@ -94,6 +95,30 @@ int acceptOverload(Eigen::Matrix2d const & m) {
return 2;
}
+int acceptNoneArray(ndarray::Array const * array = nullptr) {
+ if (array) {
+ return 0;
+ } else {
+ return 1;
+ }
+}
+
+int acceptNoneMatrixXd(Eigen::MatrixXd const * matrix = nullptr) {
+ if (matrix) {
+ return 2;
+ } else {
+ return 3;
+ }
+}
+
+int acceptNoneMatrix2d(Eigen::Matrix2d const * matrix = nullptr) {
+ if (matrix) {
+ return 4;
+ } else {
+ return 5;
+ }
+}
+
struct MatrixOwner {
typedef Eigen::Matrix MemberMatrix;
MemberMatrix member;
@@ -130,6 +155,9 @@ PYBIND11_PLUGIN(pybind11_test_mod) {
mod.def("acceptOverload", (int (*)(int)) acceptOverload);
mod.def("acceptOverload", (int (*)(Eigen::Matrix2d const &)) acceptOverload);
mod.def("acceptOverload", (int (*)(Eigen::Matrix3d const &)) acceptOverload);
+ mod.def("acceptNoneArray", acceptNoneArray, "array"_a = nullptr);
+ mod.def("acceptNoneMatrixXd", acceptNoneMatrixXd, "matrix"_a = nullptr);
+ mod.def("acceptNoneMatrix2d", acceptNoneMatrix2d, "matrix"_a = nullptr);
return mod.ptr();
}
\ No newline at end of file