@@ -344,14 +344,17 @@ af_err af_convolve2_nn(af_array *out, const af_array signal,
344
344
345
345
const af_dtype signalType = sInfo .getType ();
346
346
347
- ARG_ASSERT (3 , stride_dims > 0 && stride_dims <= 2 );
348
- ARG_ASSERT (5 , padding_dims > 0 && padding_dims <= 2 );
349
- ARG_ASSERT (7 , dilation_dims > 0 && dilation_dims <= 2 );
350
-
351
347
dim4 stride (stride_dims, strides);
352
348
dim4 padding (padding_dims, paddings);
353
349
dim4 dilation (dilation_dims, dilations);
354
350
351
+ size_t stride_ndims = stride.ndims ();
352
+ size_t padding_ndims = padding.ndims ();
353
+ size_t dilation_ndims = dilation.ndims ();
354
+ ARG_ASSERT (3 , stride_ndims > 0 && stride_ndims <= 2 );
355
+ ARG_ASSERT (5 , padding_ndims >= 0 && padding_ndims <= 2 );
356
+ ARG_ASSERT (7 , dilation_ndims > 0 && dilation_ndims <= 2 );
357
+
355
358
// assert number of features matches between signal and filter
356
359
DIM_ASSERT (1 , sDims [2 ] == fDims [2 ]);
357
360
@@ -424,14 +427,17 @@ af_err af_convolve2_gradient_nn(
424
427
425
428
af_array output;
426
429
427
- ARG_ASSERT (3 , stride_dims > 0 && stride_dims <= 2 );
428
- ARG_ASSERT (5 , padding_dims > 0 && padding_dims <= 2 );
429
- ARG_ASSERT (7 , dilation_dims > 0 && dilation_dims <= 2 );
430
-
431
430
af::dim4 stride (stride_dims, strides);
432
431
af::dim4 padding (padding_dims, paddings);
433
432
af::dim4 dilation (dilation_dims, dilations);
434
433
434
+ size_t stride_ndims = stride.ndims ();
435
+ size_t padding_ndims = padding.ndims ();
436
+ size_t dilation_ndims = dilation.ndims ();
437
+ ARG_ASSERT (3 , stride_ndims > 0 && stride_ndims <= 2 );
438
+ ARG_ASSERT (5 , padding_ndims > 0 && padding_ndims <= 2 );
439
+ ARG_ASSERT (7 , dilation_ndims > 0 && dilation_ndims <= 2 );
440
+
435
441
af_dtype type = oinfo.getType ();
436
442
switch (type) {
437
443
case f32 :
0 commit comments