[mlir][linalg] Produce canonical linalg.generic for im2col #134675

Draft · wants to merge 1 commit into base: main
Conversation

fabrizio-indirli
Contributor

Before this patch, the Im2Col transform produced a non-canonical linalg.generic whose input tensor was not reported among the inputs of the operation: instead, it was accessed manually from inside the op body, after an internal computation of the access offsets. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is correctly reported in its 'ins()', whose access offsets are computed through an indexing map, and whose body contains only a 'linalg.yield' op.


Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
@llvmbot
Member

llvmbot commented Apr 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (fabrizio-indirli)

Changes

(See the PR description above.)


Full diff: https://github.com/llvm/llvm-project/pull/134675.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+43-33)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-22)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 81d44ba04fa1d..4999e8bc4ecae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -64,6 +65,19 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
   return *multiIndex;
 }
 
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+                                   bool useSymbols = true) {
+  AffineExpr oExpr, fExpr;
+  if (useSymbols)
+    bindSymbols(b.getContext(), oExpr, fExpr);
+  else
+    bindDims(b.getContext(), oExpr, fExpr);
+  return AffineExpr(stride * oExpr + fExpr);
+}
+
 // Given indices corresponding to iterators in the output (oIndex) and filter
 // (fIndex) for a convolution, compute the convolved index for the
 // input as `oIndex * stride + fIndex`.
@@ -71,7 +85,7 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                                Value fIndex, int64_t stride) {
   AffineExpr oExpr, fExpr;
   bindSymbols(b.getContext(), oExpr, fExpr);
-  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+  AffineMap convMap = AffineMap::get(0, 2, getConvolvedExpr(b, stride));
   return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
 }
 
@@ -556,44 +570,40 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
   auto reduction = utils::IteratorType::reduction;
   SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
 
+  // Recover the original iteration indices from the problem/input sizes.
+  auto mIndicesExprs =
+      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+                                   ArrayRef<int64_t>{fw * ic, ic, 1});
+  auto hIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{mIndicesExprs[0], kIndicesExprs[0]}}, rewriter.getContext())[0];
+  auto wIndicesMap = AffineMap::inferFromExprList(
+      {ArrayRef{mIndicesExprs[1], kIndicesExprs[1]}}, rewriter.getContext())[0];
+  // Compute the input indexing map, to map the output indices to the input
+  // offsets
+  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+  auto hIndexExpr =
+      getConvolvedExpr(rewriter, convOp.getStrides().getValues<int64_t>()[0],
+                       /*useSymbols*/ false)
+          .compose(hIndicesMap);
+  auto wIndexExpr =
+      getConvolvedExpr(rewriter, convOp.getStrides().getValues<int64_t>()[1],
+                       /*useSymbols*/ false)
+          .compose(wIndicesMap);
+  auto cIndexExpr = kIndicesExprs[2];
+  auto inMap = AffineMap::inferFromExprList(
+      {ArrayRef{bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr}},
+      rewriter.getContext())[0];
+
   SmallVector<AffineMap> img2colIndexingMaps = {
-      AffineMap::getMultiDimIdentityMap(nloops, context)};
+      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
 
   auto img2ColTensor = rewriter.create<linalg::GenericOp>(
       loc, colTensor.getType(),
-      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
       img2colIterators,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-        // Get the iterators named based on the matmul (batch, m, k).
-        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
-        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
-        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
-        // Recover the original iteration indices from the problem/input sizes.
-        SmallVector<Value> mIndices = unrollIndex(
-            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
-        auto ohIndex = mIndices[0];
-        auto owIndex = mIndices[1];
-
-        SmallVector<Value> kIndices = unrollIndex(
-            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
-        auto fhIndex = kIndices[0];
-        auto fwIndex = kIndices[1];
-        auto icIndex = kIndices[2];
-
-        // Extract the input element corresponding to the expanded indices.
-        Value hIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
-                              convOp.getStrides().getValues<int64_t>()[0]);
-        Value wIndex =
-            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
-                              convOp.getStrides().getValues<int64_t>()[1]);
-
-        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
-        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
-        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
-            loc, input, extractionIndices);
-        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
       });
 
   // Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..80ff27da430bf 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -291,31 +291,19 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK: IR printer: tensor_producer
 // CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
 // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+//     CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+//     CHECK: linalg.yield %[[IN_DATA]] : f32
 
 // CHECK: IR printer: transformed
 // CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //      CHECK: @conv_2d_nhwc_fhwc
 //      CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
 //      CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +312,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
 //      CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
 //      CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP0]]
+//           CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
 //                CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
 //                CHECK: linalg.yield %{{.+}} : f32
 //      CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-//           CHECK-SAME: #[[MAP1]]
 //           CHECK-SAME: #[[MAP2]]
 //           CHECK-SAME: #[[MAP3]]
+//           CHECK-SAME: #[[MAP4]]
 //           CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
 //           CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
 //                CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)

@fabrizio-indirli
Contributor Author

fabrizio-indirli commented Apr 7, 2025

Before this patch, the following input IR:

func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %0 = linalg.conv_2d_nhwc_fhwc
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
       ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
      outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
    return %0 : tensor<1x14x14x16xf32>
}

would be converted to:

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map6 = affine_map<()[s0] -> (s0 mod 4)>
#map7 = affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>
#map8 = affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>
#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map10 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
    %collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
    %0 = tensor.empty() : tensor<1x196x36xf32>
    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} 
    outs(%0 : tensor<1x196x36xf32>) {
    ^bb0(%out: f32):
      %3 = linalg.index 0 : index
      %4 = linalg.index 1 : index
      %5 = linalg.index 2 : index
      %11 = affine.apply #map6()[%5]
      %12 = affine.apply #map7()[%4, %5]
      %13 = affine.apply #map8()[%4, %5]
      %extracted = tensor.extract %arg0[%3, %12, %13, %11] : tensor<1x16x16x4xf32>
      linalg.yield %extracted : f32
    } -> tensor<1x196x36xf32>
    %2 = linalg.generic {indexing_maps = [#map9, #map10, #map11], 
     iterator_types = ["parallel", "parallel", "parallel", "reduction"]} 
     ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) 
     outs(%collapsed_0 : tensor<1x196x16xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %3, %out : f32
      linalg.yield %4 : f32
    } -> tensor<1x196x16xf32>
    %expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
    return %expanded : tensor<1x14x14x16xf32>
}

while with this patch it is transformed to:

#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
    %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
    %collapsed_0 = tensor.collapse_shape %arg2 [[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
    %0 = tensor.empty() : tensor<1x196x36xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], 
     iterator_types = ["parallel", "parallel", "parallel"]} 
     ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x196x36xf32>) {
     ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x196x36xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], 
     iterator_types = ["parallel", "parallel", "parallel", "reduction"]} 
     ins(%1, %collapsed : tensor<1x196x36xf32>, tensor<16x36xf32>) outs(%collapsed_0 : tensor<1x196x16xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      %4 = arith.addf %3, %out : f32
      linalg.yield %4 : f32
    } -> tensor<1x196x16xf32>
    %expanded = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
    return %expanded : tensor<1x14x14x16xf32>
}

Thus, the input tensor %arg0 is now correctly reported in the linalg.generic's ins(), and the input access offsets are computed through a normal indexing map. This simplifies the code quite a bit and produces more canonical linalg.generic ops.
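As a hand-check of the new map, using the static sizes from the test (ow = 14, fw*ic = 12, ic = 4, unit strides):

d1 indexes the collapsed output dimension (oh*ow = 14*14), so delinearizing over the basis (ow, 1) = (14, 1) gives
  oh_idx = d1 floordiv 14,  ow_idx = d1 mod 14
d2 indexes the collapsed filter dimension (fh*fw*ic = 3*3*4 = 36), so delinearizing over (fw*ic, ic, 1) = (12, 4, 1) gives
  fh_idx = d2 floordiv 12,  fw_idx = (d2 mod 12) floordiv 4,  ic_idx = d2 mod 4
With unit strides, the convolved input offsets are
  h = oh_idx + fh_idx = d1 floordiv 14 + d2 floordiv 12
  w = ow_idx + fw_idx = d1 mod 14 + (d2 mod 12) floordiv 4
which reproduces #map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>.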

However, I see that the current approach is the result of a rewrite of previous code, which already produced a canonical linalg.generic. I also found a couple of old discussions on the current code (1 and 2) where it was mentioned that

I know this pattern exists in IREE, but it is a bit of a hack. The representation of the im2col as a linalg.generic doesn't work as well. In reality it is similar to a gather. If this is upstreamed, it might be worth doing this right and not use an unnecessarily higher-dimensional operation for representing the im2col.

Thus, I was wondering whether the current code produces the non-canonical linalg.generic on purpose, for a specific reason. Since I couldn't be sure, as a first step I modified only one of the 4 rewrites in the transform. If anyone can confirm that my approach is correct, I'll be happy to update the other ones as well.
@qcolombet @nicolasvasilache @MaheshRavishankar @ThomasRaoux

@fabrizio-indirli
Contributor Author

I was advised that a possible reason for the current code (which hides the input access pattern inside the op body) is that the indexing maps of a linalg.generic must be invertible, since linalg::GenericOp implements the TilingInterface.
However, an indexing map such as affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)> is not invertible and would break that requirement. In the current upstream implementation, this is "hidden" by sinking the input access pattern into the op body instead of exposing it in the linalg.generic's interface.
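For illustration only (not part of this patch), a minimal standalone sketch against the MLIR C++ API showing that this map fails the permutation checks that inversion-based tiling relies on:

#include <cassert>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  // The input indexing map produced by this patch (constants taken from
  // the 1x16x16x4 test case).
  AffineMap im2colMap = AffineMap::get(
      /*dimCount=*/3, /*symbolCount=*/0,
      {d0, d1.floorDiv(14) + d2.floorDiv(12),
       d1 % 14 + (d2 % 12).floorDiv(4), d2 % 4},
      &ctx);
  // floordiv/mod results are not plain dims, so the map is not even a
  // projected permutation, let alone invertible.
  assert(!im2colMap.isProjectedPermutation());
  // The identity output map, by contrast, is trivially invertible.
  AffineMap outMap = AffineMap::getMultiDimIdentityMap(3, &ctx);
  assert(outMap.isPermutation());
  return 0;
}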

@manupak
Contributor

manupak commented Apr 9, 2025

@fabrizio-indirli and I had a chat offline and came to the above conclusion.
@qedawkins @MaheshRavishankar, when you have some time, would you be able to take a look here?

Strictly speaking, it only breaks getIterationDomainTileFromOperandTile (a.k.a. the consumer fusion interface), which is right now obscured by not exposing the input as an actual input in the implementation.

This has some relevance to the proposal by @qedawkins here: https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155, where we could potentially allow conditional interface attachment based on the properties of the indexing maps of the generic.

@qedawkins
Contributor

There's no need for conditional interface attachment; getIterationDomainTileFromOperandTile, getIterationDomainTileFromResultTile, and any of the other methods that actually perform tiling are allowed to fail. Thus the choice to manually compute the input access offsets within the body of the generic is a conscious decision to enable at least some of the tiling interface methods to work on the im2col op. affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)> is not invertible (as noted) and thus will cause any attempt to tile the op to fail.

@qedawkins
Contributor

I am actually curious about the opposite here. Why do you think

#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x196x36xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x196x36xf32>

is the more canonical representation? Which patterns does it enable that aren't possible with the existing representation? Asking for clarity, not to argue one way or another.

@fabrizio-indirli
Contributor Author

fabrizio-indirli commented Apr 10, 2025

Thanks for taking a look at this.
@qedawkins "Canonical" might be a misused term on my side, since in both cases the input indexing maps are not invertible.
I just meant that the new format clearly shows all the accessed tensors (including the input one) and their access patterns (including the input indexing map) in the interface of the linalg.generic op, as one would normally expect. This allows other passes to analyze (e.g. retrieve uses) and manipulate the linalg op without having to inspect its body.
For example, a typical linalg fusion pattern such as FuseElementwiseOps commonly checks only the operands in the linalg.generic interface to traverse the def-use chains of a value to be fused. With the current implementation of img2col, the input tensor is not reported in the linalg's inputs, so the img2col op is never fused with a producer. By contrast, the fusion is applied with the "more canonical" linalg op I am proposing:

// INPUT IR
#mapI2c = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 52 + d2 floordiv 192, d1 mod 52 + (d2 mod 192) floordiv 64, d2 mod 64)>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], 
 iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 
 ins(%expanded : tensor<1x54x54x64xf32>) outs(%5 : tensor<1x54x54x64xf32>) {
  ^bb0(%in: f32, %out: f32):  // PRODUCER: ELEMENTWISE OP
    %14 = arith.minimumf %in, %cst_1 : f32
    %15 = arith.maximumf %14, %cst_2 : f32
    linalg.yield %15 : f32
  } -> tensor<1x54x54x64xf32>
  %9 = tensor.empty() : tensor<1x2704x576xf32>
  %10 = linalg.generic {indexing_maps = [#mapI2c, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], 
   iterator_types = ["parallel", "parallel", "parallel"]} 
   ins(%6 : tensor<1x54x54x64xf32>) outs(%9 : tensor<1x2704x576xf32>) {
  ^bb0(%in: f32, %out: f32):  // NEW SIMPLIFIED IM2COL LINALG
    linalg.yield %in : f32
  } -> tensor<1x2704x576xf32>

// AFTER FUSION
%7 = tensor.empty() : tensor<1x2704x576xf32>
%8 = linalg.generic {indexing_maps = [#mapI2c, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
 iterator_types = ["parallel", "parallel", "parallel"]} 
 ins(%expanded : tensor<1x54x54x64xf32>) outs(%7 : tensor<1x2704x576xf32>) {
  ^bb0(%in: f32, %out: f32):
    %12 = arith.minimumf %in, %cst_1 : f32
    %13 = arith.maximumf %12, %cst_2 : f32
    linalg.yield %13 : f32
  } -> tensor<1x2704x576xf32>

Moreover, the "new" format would be less verbose, as the body of the op contains only a linalg.yield rather than several linalg.index, affine.apply and tensor.extract ops, though this is not very important IMHO.

@manupak
Contributor

manupak commented Apr 10, 2025

There's no need for conditional interface attachment; getIterationDomainTileFromOperandTile, getIterationDomainTileFromResultTile, and any of the other methods that actually perform tiling are allowed to fail. Thus the choice to manually compute the input access offsets within the body of the generic is a conscious decision to enable at least some of the tiling interface methods to work on the im2col op. affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)> is not invertible (as noted) and thus will cause any attempt to tile the op to fail.

Thanks @qedawkins!
'getIterationDomainTileFromOperandTile' being allowed to fail != producer fusion being allowed, correct?

A solution that did not need to "manually compute the input access offsets within the body of the generic" in order to keep at least some of the tiling interface methods working would make the trivial fusions that @fabrizio-indirli points out possible.

@fabrizio-indirli
Contributor Author

Any thoughts on this @qedawkins @Groverkss? Is there any specific pass or test that I can run to check that this wouldn't break the current pipelines relying on tile-and-fuse?

@qedawkins
Contributor

Any thoughts on this @qedawkins @Groverkss ? Is there any specific pass or test that I can run to check that this wouldn't break the current pipelines relying on tile-and-fuse?

I'm not aware of anyone relying on these patterns (we aren't downstream) so I can't point you at a pipeline. Since it's always easier to go from indexing_map -> explicit indexing, this PR is fine to land. I would just caution against complex indexing maps as most transformations don't/can't support them.

@qedawkins
Contributor

One other consideration: while the im2col pattern currently doesn't support dynamic shapes, using an indexing map won't work with dynamic shapes either.
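For reference, the divisors baked into the map come from static sizes in the rewrite; an excerpt from the diff above:

// The delinearization bases are compile-time int64_t constants
// (fw * ic, ic, ...), so there is nothing to bake into the affine map
// when ow / fw / ic are dynamic.
auto mIndicesExprs =
    delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                 ArrayRef<int64_t>{fw * ic, ic, 1});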

@fabrizio-indirli
Contributor Author

Thanks for reviewing this. For consistency, I will adjust the other patterns as well in the coming days. In the meantime, I'll mark this as WIP.

@fabrizio-indirli fabrizio-indirli marked this pull request as draft May 13, 2025 12:50