diff --git a/include/ndarray/arange.h b/include/ndarray/arange.h index 593beaed..6cc241d6 100644 --- a/include/ndarray/arange.h +++ b/include/ndarray/arange.h @@ -30,11 +30,11 @@ namespace ndarray { */ template <> struct ExpressionTraits { - typedef int Element; + typedef std::size_t Element; typedef boost::mpl::int_<1> ND; - typedef boost::counting_iterator Iterator; - typedef int Value; - typedef int Reference; + typedef boost::counting_iterator Iterator; + typedef std::size_t Value; + typedef std::size_t Reference; }; namespace detail { @@ -52,11 +52,11 @@ class CountingExpression : public ExpressionBase { typedef ExpressionTraits::Iterator Iterator; typedef ExpressionTraits::Value Value; typedef ExpressionTraits::Reference Reference; - typedef Vector Index; + typedef Vector Index; - CountingExpression(int stop=0) : _stop(stop) { NDARRAY_ASSERT(stop >= 0); } + CountingExpression(std::size_t stop=0) : _stop(stop) { NDARRAY_ASSERT(stop >= 0); } - Reference operator[](int n) const { + Reference operator[](std::size_t n) const { return n; } @@ -68,7 +68,7 @@ class CountingExpression : public ExpressionBase { return Iterator(_stop); } - template int getSize() const { + template std::size_t getSize() const { BOOST_STATIC_ASSERT(P==0); return _stop; } @@ -78,7 +78,7 @@ class CountingExpression : public ExpressionBase { } private: - int _stop; + std::size_t _stop; }; template @@ -86,31 +86,31 @@ class RangeTransformer { T _offset; T _scale; public: - typedef int argument_type; + typedef std::size_t argument_type; typedef T result_type; explicit RangeTransformer(T const & offset, T const & scale) : _offset(offset), _scale(scale) {} - T operator()(int n) const { return static_cast(n) * _scale + _offset; } + T operator()(std::size_t n) const { return static_cast(n) * _scale + _offset; } }; } // namespace detail /// @brief Create 1D Expression that contains integer values in the range [0,stop). -inline detail::CountingExpression arange(int stop) { +inline detail::CountingExpression arange(std::size_t stop) { return detail::CountingExpression(stop); } /// @brief Create 1D Expression that contains integer values in the range [start,stop) with increment step. -inline detail::UnaryOpExpression< detail::CountingExpression, detail::RangeTransformer > -arange(int start, int stop, int step = 1) { +template +detail::UnaryOpExpression< detail::CountingExpression, detail::RangeTransformer > +arange(T start, T stop, T step = 1) { NDARRAY_ASSERT(step != 0); - int size = stop - start; - if (step < -1) ++size; - if (step > 1) --size; - size /= step; + T const diff = stop - start; + NDARRAY_ASSERT((diff > 0 && step > 0) || (diff < 0 && step < 0)); + std::size_t const size = diff/step; return vectorize( - detail::RangeTransformer(start,step), + detail::RangeTransformer(start,step), detail::CountingExpression(size) ); } diff --git a/tests/ndarray.cc b/tests/ndarray.cc index 7729e4ad..e5f3e851 100644 --- a/tests/ndarray.cc +++ b/tests/ndarray.cc @@ -628,3 +628,57 @@ BOOST_AUTO_TEST_CASE(issue3) { BOOST_CHECK_EQUAL(s5.getSize<3>(), 2); BOOST_CHECK_EQUAL(s5.getSize<4>(), 1); } + +BOOST_AUTO_TEST_CASE(arangeInt) { + std::size_t const size = 10; + // Basic use + { + ndarray::Array values = ndarray::copy(ndarray::arange(size)); + BOOST_CHECK_EQUAL(values.getNumElements(), size); + for (std::size_t ii = 0; ii < size; ++ii) { + BOOST_CHECK_EQUAL(values[ii], ii); + } + } + { + ndarray::Array values = ndarray::copy(ndarray::arange(0UL, size)); + BOOST_CHECK_EQUAL(values.getNumElements(), size); + for (std::size_t ii = 0; ii < size; ++ii) { + BOOST_CHECK_EQUAL(values[ii], ii); + } + + } + + // Expanded use + std::vector startList = {0, 3, 7, 123}; + std::vector stepList = {1, 2, 3, 10, -1, -2, -3, -10}; + { + for (int step : stepList) { + for (int start : startList) { + int stop = start + size*step; + ndarray::Array values = ndarray::copy(ndarray::arange(start, stop, step)); + BOOST_CHECK_EQUAL(values.getNumElements(), size); + for (std::size_t ii = 0; ii < size; ++ii) { + int expect = start + ii*step; + BOOST_CHECK_EQUAL(values[ii], expect); + } + } + } + } +} + +BOOST_AUTO_TEST_CASE(arangeFloat) { + std::size_t const size = 10; + std::vector startList = {0.0, 123.45}; + std::vector stepList = {1.0, 0.1, 1.23}; + for (float step : stepList) { + for (float start : startList) { + float stop = start + size*step; + ndarray::Array values = ndarray::copy(ndarray::arange(start, stop, step)); + BOOST_CHECK_EQUAL(values.getNumElements(), size); + for (std::size_t ii = 0; ii < size; ++ii) { + float expect = start + ii*step; + BOOST_CHECK_EQUAL(values[ii], expect); + } + } + } +}