/** * @file * @copyright Copyright 2020. Tom de Geus. All rights reserved. * @license This project is released under the GNU Public License (MIT). */ #include #include #define FORCE_IMPORT_ARRAY #include #include #include namespace py = pybind11; /** * Overrides the `__name__` of a module. * Classes defined by pybind11 use the `__name__` of the module as of the time they are defined, * which affects the `__repr__` of the class type objects. */ class ScopedModuleNameOverride { public: explicit ScopedModuleNameOverride(py::module m, std::string name) : module_(std::move(m)) { original_name_ = module_.attr("__name__"); module_.attr("__name__") = name; } ~ScopedModuleNameOverride() { module_.attr("__name__") = original_name_; } private: py::module module_; py::object original_name_; }; PYBIND11_MODULE(_xt, m) { // Ensure members to display as `xt.X` (not `xt._xt.X`) ScopedModuleNameOverride name_override(m, "xt"); xt::import_numpy(); m.doc() = "Python bindings of xtensor"; m.def("mean", [](const xt::pyarray& a) -> xt::pyarray { return xt::mean(a); }); m.def("average", [](const xt::pyarray& a, const xt::pyarray& w) -> xt::pyarray { return xt::average(a, w); }); m.def("average", [](const xt::pyarray& a, const xt::pyarray& w, const std::vector& axes) -> xt::pyarray { return xt::average(a, w, axes); }); m.def("flip", [](const xt::pyarray& a, ptrdiff_t axis) -> xt::pyarray { return xt::flip(a, axis); }); m.def("cos", [](const xt::pyarray& a) -> xt::pyarray { return xt::cos(a); }); m.def("isin", [](const xt::pyarray& a, const xt::pyarray& b) -> xt::pyarray { return xt::isin(a, b); }); m.def("in1d", [](const xt::pyarray& a, const xt::pyarray& b) -> xt::pyarray { return xt::in1d(a, b); }); }