Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 86dcabb

Browse filesBrowse files
committed
Fix issue where ndims was incorrectly used to calculate shape of input
1 parent b05da69 commit 86dcabb
Copy full SHA for 86dcabb

File tree

Expand file treeCollapse file tree

1 file changed

+14
-8
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+14
-8
lines changed

‎src/api/c/convolve.cpp

Copy file name to clipboardExpand all lines: src/api/c/convolve.cpp
+14-8Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,14 +344,17 @@ af_err af_convolve2_nn(af_array *out, const af_array signal,
344344

345345
const af_dtype signalType = sInfo.getType();
346346

347-
ARG_ASSERT(3, stride_dims > 0 && stride_dims <= 2);
348-
ARG_ASSERT(5, padding_dims > 0 && padding_dims <= 2);
349-
ARG_ASSERT(7, dilation_dims > 0 && dilation_dims <= 2);
350-
351347
dim4 stride(stride_dims, strides);
352348
dim4 padding(padding_dims, paddings);
353349
dim4 dilation(dilation_dims, dilations);
354350

351+
size_t stride_ndims = stride.ndims();
352+
size_t padding_ndims = padding.ndims();
353+
size_t dilation_ndims = dilation.ndims();
354+
ARG_ASSERT(3, stride_ndims > 0 && stride_ndims <= 2);
355+
ARG_ASSERT(5, padding_ndims >= 0 && padding_ndims <= 2);
356+
ARG_ASSERT(7, dilation_ndims > 0 && dilation_ndims <= 2);
357+
355358
// assert number of features matches between signal and filter
356359
DIM_ASSERT(1, sDims[2] == fDims[2]);
357360

@@ -424,14 +427,17 @@ af_err af_convolve2_gradient_nn(
424427

425428
af_array output;
426429

427-
ARG_ASSERT(3, stride_dims > 0 && stride_dims <= 2);
428-
ARG_ASSERT(5, padding_dims > 0 && padding_dims <= 2);
429-
ARG_ASSERT(7, dilation_dims > 0 && dilation_dims <= 2);
430-
431430
af::dim4 stride(stride_dims, strides);
432431
af::dim4 padding(padding_dims, paddings);
433432
af::dim4 dilation(dilation_dims, dilations);
434433

434+
size_t stride_ndims = stride.ndims();
435+
size_t padding_ndims = padding.ndims();
436+
size_t dilation_ndims = dilation.ndims();
437+
ARG_ASSERT(3, stride_ndims > 0 && stride_ndims <= 2);
438+
ARG_ASSERT(5, padding_ndims > 0 && padding_ndims <= 2);
439+
ARG_ASSERT(7, dilation_ndims > 0 && dilation_ndims <= 2);
440+
435441
af_dtype type = oinfo.getType();
436442
switch (type) {
437443
case f32:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.