diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 47e285f..59ddc65 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -42,8 +42,20 @@ namespace pybind11 { using type = xt::pyarray; - bool load(handle src, bool) + bool load(handle src, bool convert) { + if (!convert) + { + if (!PyArray_Check(src.ptr())) + { + return false; + } + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + { + return false; + } + } value = type::ensure(src); return static_cast(value); } diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 8e9638d..f76f841 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -43,8 +43,21 @@ namespace pybind11 { using type = xt::pytensor; - bool load(handle src, bool) + bool load(handle src, bool convert) { + if (!convert) + { + if (!PyArray_Check(src.ptr())) + { + return false; + } + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + { + return false; + } + } + value = type::ensure(src); return static_cast(value); } diff --git a/test_python/main.cpp b/test_python/main.cpp index 135eff7..85c352f 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -42,6 +42,25 @@ double readme_example2(double i, double j) return std::sin(i) - std::cos(j); } +auto complex_overload(const xt::pyarray>& a) +{ + return a; +} +auto no_complex_overload(const xt::pyarray& a) +{ + return a; +} + +auto complex_overload_reg(const std::complex& a) +{ + return a; +} + +auto no_complex_overload_reg(const double& a) +{ + return a; +} + // Vectorize Examples int add(int i, int j) @@ -58,6 +77,11 @@ PYBIND11_PLUGIN(xtensor_python_test) m.def("example1", example1); m.def("example2", example2); + m.def("complex_overload", no_complex_overload); + m.def("complex_overload", complex_overload); + m.def("complex_overload_reg", no_complex_overload_reg); + m.def("complex_overload_reg", complex_overload_reg); + m.def("readme_example1", readme_example1); m.def("readme_example2", xt::pyvectorize(readme_example2)); diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index b1e64b1..d4625ee 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -36,6 +36,23 @@ def test_readme_example1(self): y = xt.readme_example1(v) np.testing.assert_allclose(y, 1.2853996391883833, 1e-12) + def test_complex_overload_reg(self): + a = 23.23 + c = 2.0 + 3.1j + self.assertEqual(xt.complex_overload_reg(a), a) + self.assertEqual(xt.complex_overload_reg(c), c) + + def test_complex_overload(self): + a = np.random.rand(3, 3) + b = np.random.rand(3, 3) + c = a + b * 1j + y = xt.complex_overload(c) + np.testing.assert_allclose(np.imag(y), np.imag(c)) + np.testing.assert_allclose(np.real(y), np.real(c)) + x = xt.complex_overload(b) + self.assertEqual(x.dtype, b.dtype) + np.testing.assert_allclose(x, b) + def test_readme_example2(self): x = np.arange(15).reshape(3, 5) y = [1, 2, 3, 4, 5]