From 86dcabb24c81eaeb1b83dc3fc0ef10f35f814487 Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Sat, 23 Jul 2022 16:56:06 -0400 Subject: [PATCH] Fix issue where ndims was incorrectly used to calculate shape of input --- src/api/c/convolve.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/api/c/convolve.cpp b/src/api/c/convolve.cpp index ddcd916ae6..9a496633b0 100644 --- a/src/api/c/convolve.cpp +++ b/src/api/c/convolve.cpp @@ -344,14 +344,17 @@ af_err af_convolve2_nn(af_array *out, const af_array signal, const af_dtype signalType = sInfo.getType(); - ARG_ASSERT(3, stride_dims > 0 && stride_dims <= 2); - ARG_ASSERT(5, padding_dims > 0 && padding_dims <= 2); - ARG_ASSERT(7, dilation_dims > 0 && dilation_dims <= 2); - dim4 stride(stride_dims, strides); dim4 padding(padding_dims, paddings); dim4 dilation(dilation_dims, dilations); + size_t stride_ndims = stride.ndims(); + size_t padding_ndims = padding.ndims(); + size_t dilation_ndims = dilation.ndims(); + ARG_ASSERT(3, stride_ndims > 0 && stride_ndims <= 2); + ARG_ASSERT(5, padding_ndims >= 0 && padding_ndims <= 2); + ARG_ASSERT(7, dilation_ndims > 0 && dilation_ndims <= 2); + // assert number of features matches between signal and filter DIM_ASSERT(1, sDims[2] == fDims[2]); @@ -424,14 +427,17 @@ af_err af_convolve2_gradient_nn( af_array output; - ARG_ASSERT(3, stride_dims > 0 && stride_dims <= 2); - ARG_ASSERT(5, padding_dims > 0 && padding_dims <= 2); - ARG_ASSERT(7, dilation_dims > 0 && dilation_dims <= 2); - af::dim4 stride(stride_dims, strides); af::dim4 padding(padding_dims, paddings); af::dim4 dilation(dilation_dims, dilations); + size_t stride_ndims = stride.ndims(); + size_t padding_ndims = padding.ndims(); + size_t dilation_ndims = dilation.ndims(); + ARG_ASSERT(3, stride_ndims > 0 && stride_ndims <= 2); + ARG_ASSERT(5, padding_ndims > 0 && padding_ndims <= 2); + ARG_ASSERT(7, dilation_ndims > 0 && dilation_ndims <= 2); + af_dtype type = oinfo.getType(); switch (type) { case f32: