Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion 41 src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use num_traits::Float;
use num_traits::One;
use num_traits::{FromPrimitive, Zero};
use std::ops::{Add, Div, Mul, Sub};
use std::ops::{Add, Div, Mul, MulAssign, Sub};

use crate::imp_prelude::*;
use crate::numeric_util;
Expand Down Expand Up @@ -97,6 +97,45 @@ where D: Dimension
sum
}

/// Return the cumulative product of elements along a given axis.
///
/// ```
/// use ndarray::{arr2, Axis};
///
/// let a = arr2(&[[1., 2., 3.],
/// [4., 5., 6.]]);
///
/// // Cumulative product along rows (axis 0)
/// assert_eq!(
/// a.cumprod(Axis(0)),
/// arr2(&[[1., 2., 3.],
/// [4., 10., 18.]])
/// );
///
/// // Cumulative product along columns (axis 1)
/// assert_eq!(
/// a.cumprod(Axis(1)),
/// arr2(&[[1., 2., 6.],
/// [4., 20., 120.]])
/// );
/// ```
///
/// **Panics** if `axis` is out of bounds.
#[track_caller]
pub fn cumprod(&self, axis: Axis) -> Array<A, D>
where
A: Clone + Mul<Output = A> + MulAssign,
D: Dimension + RemoveAxis,
{
if axis.0 >= self.ndim() {
panic!("axis is out of bounds for array of dimension");
}

let mut result = self.to_owned();
result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone());
result
}

/// Return variance of elements in the array.
///
/// The variance is computed using the [Welford one-pass
Expand Down
70 changes: 70 additions & 0 deletions 70 tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,76 @@ fn sum_mean_prod_empty()
assert_eq!(a, None);
}

#[test]
fn test_cumprod_1d()
{
let a = array![1, 2, 3, 4];
let result = a.cumprod(Axis(0));
assert_eq!(result, array![1, 2, 6, 24]);
}

#[test]
fn test_cumprod_2d()
{
let a = array![[1, 2], [3, 4]];

let result_axis0 = a.cumprod(Axis(0));
assert_eq!(result_axis0, array![[1, 2], [3, 8]]);

let result_axis1 = a.cumprod(Axis(1));
assert_eq!(result_axis1, array![[1, 2], [3, 12]]);
}

#[test]
fn test_cumprod_3d()
{
let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];

let result_axis0 = a.cumprod(Axis(0));
assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]);

let result_axis1 = a.cumprod(Axis(1));
assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]);

let result_axis2 = a.cumprod(Axis(2));
assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]);
}

#[test]
fn test_cumprod_empty()
{
// For 2D empty array
let b: Array2<i32> = Array2::zeros((0, 0));
let result_axis0 = b.cumprod(Axis(0));
assert_eq!(result_axis0, Array2::zeros((0, 0)));
let result_axis1 = b.cumprod(Axis(1));
assert_eq!(result_axis1, Array2::zeros((0, 0)));
}

#[test]
fn test_cumprod_1_element()
{
// For 1D array with one element
let a = array![5];
let result = a.cumprod(Axis(0));
assert_eq!(result, array![5]);

// For 2D array with one element
let b = array![[5]];
let result_axis0 = b.cumprod(Axis(0));
let result_axis1 = b.cumprod(Axis(1));
assert_eq!(result_axis0, array![[5]]);
assert_eq!(result_axis1, array![[5]]);
}

#[test]
#[should_panic(expected = "axis is out of bounds for array of dimension")]
fn test_cumprod_axis_out_of_bounds()
{
let a = array![[1, 2], [3, 4]];
let _result = a.cumprod(Axis(2));
}

#[test]
#[cfg(feature = "std")]
fn var()
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.