-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][linalg] Produce canonical linalg.generic for im2col #134675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Before this patch, the Img2Col transform produced a non-canonical linalg.generic whose input tensor was not reported in the inputs of the operation: instead, it was accessed manually from inside the op body, after an internal calculation of the access offsets. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is correctly reported in its 'ins()', whose access offsets are computed through an indexing map, and whose body contains only a 'linalg.yield' op. Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli@arm.com>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (fabrizio-indirli) ChangesBefore this patch, the Img2Col transform produced a non-canonical linalg.generic whose input tensor was not reported in the inputs of the operation: instead, it was accessed manually from inside the op body, after an internal calculation of the access offsets. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is correctly reported in its 'ins()', whose access offsets are computed through an indexing map, and whose body contains only a 'linalg.yield' op. Full diff: https://github.com/llvm/llvm-project/pull/134675.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 81d44ba04fa1d..4999e8bc4ecae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -64,6 +65,19 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
return *multiIndex;
}
+// Generate the affine expression to compute the convolved index
+// for the input as `oIndex * stride + fIndex`,
+// where oIndex: output iterator; fIndex: filter iterator.
+static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
+ bool useSymbols = true) {
+ AffineExpr oExpr, fExpr;
+ if (useSymbols)
+ bindSymbols(b.getContext(), oExpr, fExpr);
+ else
+ bindDims(b.getContext(), oExpr, fExpr);
+ return AffineExpr(stride * oExpr + fExpr);
+}
+
// Given indices corresponding to iterators in the output (oIndex) and filter
// (fIndex) for a convolution, compute the convolved index for the
// input as `oIndex * stride + fIndex`.
@@ -71,7 +85,7 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
Value fIndex, int64_t stride) {
AffineExpr oExpr, fExpr;
bindSymbols(b.getContext(), oExpr, fExpr);
- AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+ AffineMap convMap = AffineMap::get(0, 2, getConvolvedExpr(b, stride));
return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
}
@@ -556,44 +570,40 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
auto reduction = utils::IteratorType::reduction;
SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+ // Recover the original iteration indices from the problem/input sizes.
+ auto mIndicesExprs =
+ delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
+ auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
+ ArrayRef<int64_t>{fw * ic, ic, 1});
+ auto hIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{mIndicesExprs[0], kIndicesExprs[0]}}, rewriter.getContext())[0];
+ auto wIndicesMap = AffineMap::inferFromExprList(
+ {ArrayRef{mIndicesExprs[1], kIndicesExprs[1]}}, rewriter.getContext())[0];
+ // Compute the input indexing map, to map the output indices to the input
+ // offsets
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U);
+ auto hIndexExpr =
+ getConvolvedExpr(rewriter, convOp.getStrides().getValues<int64_t>()[0],
+ /*useSymbols*/ false)
+ .compose(hIndicesMap);
+ auto wIndexExpr =
+ getConvolvedExpr(rewriter, convOp.getStrides().getValues<int64_t>()[1],
+ /*useSymbols*/ false)
+ .compose(wIndicesMap);
+ auto cIndexExpr = kIndicesExprs[2];
+ auto inMap = AffineMap::inferFromExprList(
+ {ArrayRef{bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr}},
+ rewriter.getContext())[0];
+
SmallVector<AffineMap> img2colIndexingMaps = {
- AffineMap::getMultiDimIdentityMap(nloops, context)};
+ inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
auto img2ColTensor = rewriter.create<linalg::GenericOp>(
loc, colTensor.getType(),
- /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+ /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
img2colIterators,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- // Get the iterators named based on the matmul (batch, m, k).
- Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
- Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
- Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
-
- // Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value> mIndices = unrollIndex(
- nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
- auto ohIndex = mIndices[0];
- auto owIndex = mIndices[1];
-
- SmallVector<Value> kIndices = unrollIndex(
- nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
- auto fhIndex = kIndices[0];
- auto fwIndex = kIndices[1];
- auto icIndex = kIndices[2];
-
- // Extract the input element corresponding to the expanded indices.
- Value hIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
- convOp.getStrides().getValues<int64_t>()[0]);
- Value wIndex =
- getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
- convOp.getStrides().getValues<int64_t>()[1]);
-
- // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
- SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
- Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
- loc, input, extractionIndices);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
});
// Because we didn't transpose the filters we don't actually have a batched
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index c17f20b2d03ab..80ff27da430bf 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -291,31 +291,19 @@ module attributes {transform.with_named_sequence} {
// CHECK: IR printer: tensor_producer
// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
-
-// Collapsed indices.
-// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
-// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
-// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-
-// Compute input channel/convolved indices.
-// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 4)>()[%[[KINDEX]]]
-// CHECK: %[[CONVH:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 14 + s1 floordiv 12)>()[%[[MINDEX]], %[[KINDEX]]]
-// CHECK: %[[CONVW:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod 14 + (s1 mod 12) floordiv 4)>()[%[[MINDEX]], %[[KINDEX]]]
-
-// Extract from the input tensor.
-// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
-// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
-// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %[[IN_DATA]] : f32
// CHECK: IR printer: transformed
// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 14 + d2 floordiv 12, d1 mod 14 + (d2 mod 12) floordiv 4, d2 mod 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK: @conv_2d_nhwc_fhwc
// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
@@ -324,13 +312,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %{{.+}} : f32
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
-// CHECK-SAME: #[[MAP1]]
// CHECK-SAME: #[[MAP2]]
// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: #[[MAP4]]
// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
|
Before this patch, the following input IR:
would be converted to:
while with this patch it is transformed to:
Thus, the input tensor However I see that the current approach is the result of a rewrite of a previous code, which was already producing canonical linalg.generic. I also found a couple of old discussions on the current code (1 and 2) where it was mentioned that
Thus, I was wondering if the current code produces the non-canonical linalg on purpose for a specific reason? Since I couldn't be sure, as I first step I modified only one of the 4 rewrites in the transform. If anyone can confirm that my approach is correct, I'll be happy to update the other ones as well. |
I was advised that a possible reason for the current code (that hides the input access pattern inside the op body) is that the indexing maps of a linalg.generic must be invertible, as linalg::GenericOp implements the TilingInterface. |
Me and @fabrizio-indirli had a chat offline and came to the above conclusion. Strictly speaking it only breaks This has some relevance with the proposal here by @qedawkins : https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155 where we could potentially allow conditional interface attachment based on the properties of the indexing maps of the generic. |
There's no need for conditional interface attachment; |
I am actually curious about the opposite here. Why do you think
is the more canonical representation? Which patterns does it enable that aren't possible with the existing representation? Asking for clarity, not to argue one way or another. |
Thanks for taking a look at this.
Moreover, the "new" format would be less verbose, as the body of the op contains only a |
Thanks @qedawkins! a solution that would not need to do "manually compute the input access offsets within the body of the generic is a conscious decision to enable at least some of the tiling interface methods to work" would allow trivial fusions possible as @fabrizio-indirli points out. |
Any thoughts on this @qedawkins @Groverkss ? Is there any specific pass or test that I can run to check that this wouldn't break the current pipelines relying on tile-and-fuse? |
I'm not aware of anyone relying on these patterns (we aren't downstream) so I can't point you at a pipeline. Since it's always easier to go from indexing_map -> explicit indexing, this PR is fine to land. I would just caution against complex indexing maps as most transformations don't/can't support them. |
One other consideration, while the im2col pattern currently doesn't support dynamic shapes, using an indexing map won't work with dynamic shapes. |
Thanks for reviewing this. For consistency I will adjust the other patterns as well in the next days. In the meantime, I'll mark this as WIP |
Before this patch, the Img2Col transform produced a non-canonical linalg.generic whose input tensor was not reported in the inputs of the operation: instead, it was accessed manually from inside the op body, after an internal calculation of the access offsets. This patch modifies the Im2Col rewrite to produce a canonical linalg.generic whose input is correctly reported in its 'ins()', whose access offsets are computed through an indexing map, and whose body contains only a 'linalg.yield' op.