From c3e7cec29695ea8c4694419e9caebeef817f3d14 Mon Sep 17 00:00:00 2001 From: Wolf Vollprecht Date: Mon, 24 Apr 2017 12:56:16 -0700 Subject: [PATCH 1/3] add test for complex overload --- test_python/main.cpp | 12 ++++++++++++ test_python/test_pyarray.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/test_python/main.cpp b/test_python/main.cpp index 135eff7..cd2edd8 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -42,6 +42,15 @@ double readme_example2(double i, double j) return std::sin(i) - std::cos(j); } +auto complex_overload(const xt::pyarray>& a) +{ + return a; +} +auto no_complex_overload(const xt::pyarray& a) +{ + return a; +} + // Vectorize Examples int add(int i, int j) @@ -58,6 +67,9 @@ PYBIND11_PLUGIN(xtensor_python_test) m.def("example1", example1); m.def("example2", example2); + m.def("complex_overload", no_complex_overload); + m.def("complex_overload", complex_overload); + m.def("readme_example1", readme_example1); m.def("readme_example2", xt::pyvectorize(readme_example2)); diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index b1e64b1..2558015 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -36,6 +36,18 @@ def test_readme_example1(self): y = xt.readme_example1(v) np.testing.assert_allclose(y, 1.2853996391883833, 1e-12) + def test_complex_overload(self): + a = np.random.rand(3, 3) + b = np.random.rand(3, 3) + c = a + b * 1j + y = xt.complex_overload(c) + print(y) + np.testing.assert_allclose(np.imag(y), np.imag(c)) + np.testing.assert_allclose(np.real(y), np.real(c)) + x = xt.complex_overload(b) + self.assertEqual(x.dtype, b.dtype) + np.testing.assert_allclose(x, b) + def test_readme_example2(self): x = np.arange(15).reshape(3, 5) y = [1, 2, 3, 4, 5] From 406b034bdeb4a9ff1082950196b67b116c63a5ea Mon Sep 17 00:00:00 2001 From: Wolf Vollprecht Date: Mon, 24 Apr 2017 13:03:03 -0700 Subject: [PATCH 2/3] add non failing complex overload test --- test_python/main.cpp | 12 ++++++++++++ test_python/test_pyarray.py | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/test_python/main.cpp b/test_python/main.cpp index cd2edd8..85c352f 100644 --- a/test_python/main.cpp +++ b/test_python/main.cpp @@ -51,6 +51,16 @@ auto no_complex_overload(const xt::pyarray& a) return a; } +auto complex_overload_reg(const std::complex& a) +{ + return a; +} + +auto no_complex_overload_reg(const double& a) +{ + return a; +} + // Vectorize Examples int add(int i, int j) @@ -69,6 +79,8 @@ PYBIND11_PLUGIN(xtensor_python_test) m.def("complex_overload", no_complex_overload); m.def("complex_overload", complex_overload); + m.def("complex_overload_reg", no_complex_overload_reg); + m.def("complex_overload_reg", complex_overload_reg); m.def("readme_example1", readme_example1); m.def("readme_example2", xt::pyvectorize(readme_example2)); diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index 2558015..c2ded1e 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -36,6 +36,12 @@ def test_readme_example1(self): y = xt.readme_example1(v) np.testing.assert_allclose(y, 1.2853996391883833, 1e-12) + def test_complex_overload_reg(self): + a = 23.23 + c = 2.0 + 3.1j + self.assertEqual(xt.complex_overload_reg(a), a) + self.assertEqual(xt.complex_overload_reg(c), c) + def test_complex_overload(self): a = np.random.rand(3, 3) b = np.random.rand(3, 3) From 0b0bea757a629776144f2e14f07c096725fe2f5c Mon Sep 17 00:00:00 2001 From: Wolf Vollprecht Date: Mon, 24 Apr 2017 13:12:59 -0700 Subject: [PATCH 3/3] fix complex overload loading --- include/xtensor-python/pyarray.hpp | 14 +++++++++++++- include/xtensor-python/pytensor.hpp | 15 ++++++++++++++- test_python/test_pyarray.py | 1 - 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/include/xtensor-python/pyarray.hpp b/include/xtensor-python/pyarray.hpp index 47e285f..59ddc65 100644 --- a/include/xtensor-python/pyarray.hpp +++ b/include/xtensor-python/pyarray.hpp @@ -42,8 +42,20 @@ namespace pybind11 { using type = xt::pyarray; - bool load(handle src, bool) + bool load(handle src, bool convert) { + if (!convert) + { + if (!PyArray_Check(src.ptr())) + { + return false; + } + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + { + return false; + } + } value = type::ensure(src); return static_cast(value); } diff --git a/include/xtensor-python/pytensor.hpp b/include/xtensor-python/pytensor.hpp index 8e9638d..f76f841 100644 --- a/include/xtensor-python/pytensor.hpp +++ b/include/xtensor-python/pytensor.hpp @@ -43,8 +43,21 @@ namespace pybind11 { using type = xt::pytensor; - bool load(handle src, bool) + bool load(handle src, bool convert) { + if (!convert) + { + if (!PyArray_Check(src.ptr())) + { + return false; + } + int type_num = xt::detail::numpy_traits::type_num; + if (PyArray_TYPE(reinterpret_cast(src.ptr())) != type_num) + { + return false; + } + } + value = type::ensure(src); return static_cast(value); } diff --git a/test_python/test_pyarray.py b/test_python/test_pyarray.py index c2ded1e..d4625ee 100644 --- a/test_python/test_pyarray.py +++ b/test_python/test_pyarray.py @@ -47,7 +47,6 @@ def test_complex_overload(self): b = np.random.rand(3, 3) c = a + b * 1j y = xt.complex_overload(c) - print(y) np.testing.assert_allclose(np.imag(y), np.imag(c)) np.testing.assert_allclose(np.real(y), np.real(c)) x = xt.complex_overload(b)