Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 97d097b

Browse filesBrowse files
committed
Fix return type trait bound on reduce all functions
1 parent 6670d5c commit 97d097b
Copy full SHA for 97d097b

File tree

Expand file treeCollapse file tree

1 file changed

+115
-25
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+115
-25
lines changed

‎src/algorithm/mod.rs

Copy file name to clipboardExpand all lines: src/algorithm/mod.rs
+115-25Lines changed: 115 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,17 @@ where
518518
}
519519

520520
macro_rules! all_reduce_func_def {
521-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
521+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
522522
#[doc=$doc_str]
523-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
523+
pub fn $fn_name<T>(input: &Array<T>)
524+
-> (
525+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
526+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType
527+
)
524528
where
525529
T: HasAfEnum,
526-
$out_type: HasAfEnum + Fromf64
530+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
531+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
527532
{
528533
let mut real: f64 = 0.0;
529534
let mut imag: f64 = 0.0;
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533538
);
534539
HANDLE_ERROR(AfError::from(err_val));
535540
}
536-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
541+
(
542+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
543+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
544+
)
537545
}
538546
};
539547
}
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564572
",
565573
sum_all,
566574
af_sum_all,
567-
T::AggregateOutType
575+
AggregateOutType
568576
);
569577

570578
all_reduce_func_def!(
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594602
",
595603
product_all,
596604
af_product_all,
597-
T::ProductOutType
605+
ProductOutType
598606
);
599607

600608
all_reduce_func_def!(
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623631
",
624632
min_all,
625633
af_min_all,
626-
T::InType
634+
InType
627635
);
628636

629637
all_reduce_func_def!(
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652660
",
653661
max_all,
654662
af_max_all,
655-
T::InType
663+
InType
656664
);
657665

658-
all_reduce_func_def!(
666+
macro_rules! all_reduce_func_def2 {
667+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
668+
#[doc=$doc_str]
669+
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
670+
where
671+
T: HasAfEnum,
672+
$out_type: HasAfEnum + Fromf64
673+
{
674+
let mut real: f64 = 0.0;
675+
let mut imag: f64 = 0.0;
676+
unsafe {
677+
let err_val = $ffi_name(
678+
&mut real as *mut c_double, &mut imag as *mut c_double, input.get(),
679+
);
680+
HANDLE_ERROR(AfError::from(err_val));
681+
}
682+
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
683+
}
684+
};
685+
}
686+
687+
all_reduce_func_def2!(
659688
"
660689
Find if all values of Array are non-zero
661690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682711
bool
683712
);
684713

685-
all_reduce_func_def!(
714+
all_reduce_func_def2!(
686715
"
687716
Find if any value of Array is non-zero
688717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709738
bool
710739
);
711740

712-
all_reduce_func_def!(
741+
all_reduce_func_def2!(
713742
"
714743
Count number of non-zero values in the Array
715744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751780
/// A tuple of summation result.
752781
///
753782
/// Note: For non-complex data type Arrays, second value of tuple is zero.
754-
pub fn sum_nan_all<T>(input: &Array<T>, val: f64) -> (T::AggregateOutType, T::AggregateOutType)
783+
pub fn sum_nan_all<T>(
784+
input: &Array<T>,
785+
val: f64,
786+
) -> (
787+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
788+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
789+
)
755790
where
756791
T: HasAfEnum,
757-
T::AggregateOutType: HasAfEnum + Fromf64,
792+
<T as HasAfEnum>::AggregateOutType: HasAfEnum,
793+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
758794
{
759795
let mut real: f64 = 0.0;
760796
let mut imag: f64 = 0.0;
@@ -768,8 +804,8 @@ where
768804
HANDLE_ERROR(AfError::from(err_val));
769805
}
770806
(
771-
<T::AggregateOutType>::fromf64(real),
772-
<T::AggregateOutType>::fromf64(imag),
807+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(real),
808+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag),
773809
)
774810
}
775811

@@ -788,10 +824,17 @@ where
788824
/// A tuple of product result.
789825
///
790826
/// Note: For non-complex data type Arrays, second value of tuple is zero.
791-
pub fn product_nan_all<T>(input: &Array<T>, val: f64) -> (T::ProductOutType, T::ProductOutType)
827+
pub fn product_nan_all<T>(
828+
input: &Array<T>,
829+
val: f64,
830+
) -> (
831+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
832+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
833+
)
792834
where
793835
T: HasAfEnum,
794-
T::ProductOutType: HasAfEnum + Fromf64,
836+
<T as HasAfEnum>::ProductOutType: HasAfEnum,
837+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
795838
{
796839
let mut real: f64 = 0.0;
797840
let mut imag: f64 = 0.0;
@@ -805,8 +848,8 @@ where
805848
HANDLE_ERROR(AfError::from(err_val));
806849
}
807850
(
808-
<T::ProductOutType>::fromf64(real),
809-
<T::ProductOutType>::fromf64(imag),
851+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(real),
852+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(imag),
810853
)
811854
}
812855

@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858901
", imax, af_imax, InType);
859902

860903
macro_rules! all_ireduce_func_def {
861-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
904+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
862905
#[doc=$doc_str]
863-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type, u32)
906+
pub fn $fn_name<T>(input: &Array<T>)
907+
-> (
908+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
909+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
910+
u32
911+
)
864912
where
865913
T: HasAfEnum,
866-
$out_type: HasAfEnum + Fromf64
914+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
915+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
867916
{
868917
let mut real: f64 = 0.0;
869918
let mut imag: f64 = 0.0;
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875924
);
876925
HANDLE_ERROR(AfError::from(err_val));
877926
}
878-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp)
927+
(
928+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
929+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
930+
temp,
931+
)
879932
}
880933
};
881934
}
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898951
",
899952
imin_all,
900953
af_imin_all,
901-
T::InType
954+
InType
902955
);
903956
all_ireduce_func_def!(
904957
"
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918971
",
919972
imax_all,
920973
af_imax_all,
921-
T::InType
974+
InType
922975
);
923976

924977
/// Locate the indices of non-zero elements.
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
13861439
af_product_by_key_nan,
13871440
ValueType::ProductOutType
13881441
);
1442+
1443+
#[cfg(test)]
1444+
mod tests {
1445+
use super::super::core::c32;
1446+
use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all};
1447+
use crate::randu;
1448+
1449+
#[test]
1450+
fn all_reduce_api() {
1451+
let a = randu!(c32; 10, 10);
1452+
println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));
1453+
1454+
let b = randu!(bool; 10, 10);
1455+
println!("reduction of bool matrix: {:?}", sum_all(&b));
1456+
1457+
println!(
1458+
"reduction of complex f32 matrix after replacing nan with {}: {:?}",
1459+
1.0,
1460+
product_nan_all(&a, 1.0)
1461+
);
1462+
1463+
println!(
1464+
"reduction of bool matrix after replacing nan with {}: {:?}",
1465+
0.0,
1466+
sum_nan_all(&b, 0.0)
1467+
);
1468+
}
1469+
1470+
#[test]
1471+
fn all_ireduce_api() {
1472+
let a = randu!(c32; 10);
1473+
println!("Reduction of complex f32 matrix: {:?}", imin_all(&a));
1474+
1475+
let b = randu!(u32; 10);
1476+
println!("reduction of bool matrix: {:?}", imax_all(&b));
1477+
}
1478+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.