[mlir][bufferization] implement BufferizableOpInterface for concat op #140171
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Jeremy Kun (j2kun)

Changes

Lowers tensor.concat to an alloc with a series of memref.copy ops to copy the operands to the alloc.

Example:

func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
%t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
return %t : tensor<16xf32>
}

Produces:

module {
  func.func @tensor.concat(%arg0: tensor<8xf32>) -> tensor<16xf32> {
// initialization
%0 = bufferization.to_memref %arg0 : tensor<8xf32> to memref<8xf32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8xf32>
memref.copy %0, %alloc : memref<8xf32> to memref<8xf32>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8xf32>
memref.copy %0, %alloc_0 : memref<8xf32> to memref<8xf32>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<16xf32>
// one copy for each operand
%subview = memref.subview %alloc_1[0] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1]>>
memref.copy %alloc, %subview : memref<8xf32> to memref<8xf32, strided<[1]>>
%subview_2 = memref.subview %alloc_1[8] [8] [1] : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>>
memref.copy %alloc_0, %subview_2 : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>>
%1 = bufferization.to_tensor %alloc_1 : memref<16xf32> to tensor<16xf32>
return %1 : tensor<16xf32>
}
}

This is my first time implementing BufferizableOpInterface, so I'm looking for some advice on how I can avoid the extra memref.copy ops in the // initialization section above when handling duplicate tensor.concat operands.
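One plausible source of those extra copies (my reading of the code below, not something confirmed in this thread): the interface reports every operand as bufferizing to a memory write, so the analysis conservatively gives each operand its own private copy before the op. Since tensor.concat only reads its inputs and writes into a fresh allocation, a minimal sketch of the alternative would be:

// Hedged sketch, not part of this patch as posted: tensor.concat never
// writes through an operand buffer; the only write targets the newly
// allocated destination, so operands are read-only.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                             const AnalysisState &state) const {
  return false;
}

Relatedly, because the result is a brand-new allocation, it does not alias any operand buffer, so having getAliasingValues return an empty AliasingValueList rather than BufferRelation::Equivalent would be consistent with that reading.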
Full diff: https://github.com/llvm/llvm-project/pull/140171.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index 8af087cbf0f61..e7d8f52d309c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
>();
addInterfaces<TensorInlinerInterface>();
declarePromisedInterfaces<
- bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
- EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+ bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+ DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
ReshapeOp, SplatOp>();
declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 31014172a9555..e19d6a50e706a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,103 @@ struct SplatOpInterface
}
};
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+ : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+ tensor::ConcatOp> {
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {{op->getResult(0), BufferRelation::Equivalent}};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto concatOp = cast<tensor::ConcatOp>(op);
+
+ // Allocate memory.
+ Location loc = op->getLoc();
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+ rewriter, loc, concatOp.getResult(), options,
+ /*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
+ auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+ // TODO: Implement memory space for this op.
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ return op->emitError("memory space not implemented yet");
+
+ MemRefLayoutAttrInterface layout;
+ MemRefType memrefType =
+ MemRefType::get(concatOp.getResultType().getShape(),
+ concatOp.getResultType().getElementType(), layout);
+ Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+ op->getLoc(), memrefType, *tensorAlloc);
+
+ // Extract the dimension for the concat op
+ uint64_t concatDim = concatOp.getDim();
+
+ SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(tensorType.getRank(),
+ rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes;
+ for (auto dimSize : tensorType.getShape()) {
+ sizes.push_back(rewriter.getIndexAttr(dimSize));
+ }
+
+ int concatDimOffset = 0;
+ for (auto operand : concatOp.getInputs()) {
+ // Get the buffer for the operand.
+ FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+ if (failed(srcBuffer))
+ return failure();
+
+ // Each operand may have a different size along the concat dimension,
+ // so the offset on that axis must accumulate through the loop, and the
+ // size must change to the size of the current operand.
+ auto operandTensorType = cast<RankedTensorType>(operand.getType());
+ int operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+ sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+ offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+
+ // Create a subview of the destination buffer.
+ auto dstMemrefType = cast<MemRefType>(memrefType);
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+ strides);
+ Value subview = rewriter.create<memref::SubViewOp>(
+ loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+ // Copy the source buffer into the destination subview.
+ if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+ return failure();
+
+ concatDimOffset += operandConcatDimSize;
+ }
+
+ replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1057,6 +1154,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+ ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c1beed95f2006..a9ee707c670b9 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,48 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+ %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+ return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+ return %t : tensor<8x9xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
I completely forgot about dynamic tensors, but in trying to support this I may need some advice. The current snapshot has my attempt. It transforms

func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
%t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
return %t : tensor<8x?xf32>
}

into

#map = affine_map<()[s0, s1] -> (s0 + s1)>
module {
func.func @tensor.concat_dynamic(%arg0: tensor<8x?xf32>, %arg1: tensor<8x?xf32>) -> tensor<8x?xf32> {
%0 = bufferization.to_memref %arg1 : tensor<8x?xf32> to memref<8x?xf32>
%1 = bufferization.to_memref %arg0 : tensor<8x?xf32> to memref<8x?xf32>
%c1 = arith.constant 1 : index
%dim = memref.dim %1, %c1 : memref<8x?xf32>
%dim_0 = memref.dim %0, %c1 : memref<8x?xf32>
%2 = affine.apply #map()[%dim, %dim_0]
%alloc = memref.alloc(%2) {alignment = 64 : i64} : memref<8x?xf32>
%c-9223372036854775808 = arith.constant -9223372036854775808 : index
%c0 = arith.constant 0 : index
%subview = memref.subview %alloc[0, %c0] [8, %dim] [1, 1] : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
memref.copy %1, %subview : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
%3 = arith.addi %c0, %dim : index
%subview_1 = memref.subview %alloc[0, %3] [8, %dim_0] [1, 1] : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
memref.copy %0, %subview_1 : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
%4 = bufferization.to_tensor %alloc : memref<8x?xf32> to tensor<8x?xf32>
return %4 : tensor<8x?xf32>
}
}

I am mostly ignorant of how dynamic dimensions are supposed to work in MLIR in general, and of bufferization acutely. The alloc seems fine to me (using the affine map with memref.dim), but the existence of this %c-9223372036854775808 = arith.constant -9223372036854775808 : index looks suspicious.
Yes, I'd be curious where that %c-9223372036854775808 constant comes from.
I will take a close look tomorrow.
Turns out this was a bug in my code, and I added another test where the dimension not being concatenated along is dynamic to exercise the fix.

@@ -1110,7 +1110,7 @@ struct ConcatOpInterface
for (const auto &[dimIdx, dimSize] :
llvm::enumerate(tensorType.getShape())) {
if (dimSize == ShapedType::kDynamic) {
- auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimSize);
+ auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
sizes.push_back(dimOp.getResult());
if (dimIdx == concatDim)
          dynamicConcatDim = true;
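For context on the stray constant: ShapedType::kDynamic is std::numeric_limits<int64_t>::min(), i.e. -9223372036854775808, so passing dimSize where a dimension index was expected materialized that sentinel as the arith.constant index seen in the output above. A minimal reconstruction of the fixed size-collection loop, with the surrounding names (tensorType, dstBuffer, rewriter, loc, concatDim) assumed from the patch:

// Collect the destination sizes. Dynamic extents are queried at runtime via
// memref.dim, using the dimension's position (dimIdx), never its sentinel
// size (dimSize == ShapedType::kDynamic).
bool dynamicConcatDim = false;
SmallVector<OpFoldResult> sizes;
for (const auto &[dimIdx, dimSize] : llvm::enumerate(tensorType.getShape())) {
  if (dimSize == ShapedType::kDynamic) {
    auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
    sizes.push_back(dimOp.getResult());
    if (dimIdx == concatDim)
      dynamicConcatDim = true;
  } else {
    sizes.push_back(rewriter.getIndexAttr(dimSize));
  }
}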
LLVM Buildbot has detected new failures on several builders. Full details are available at:

https://lab.llvm.org/buildbot/#/builders/204/builds/9632
https://lab.llvm.org/buildbot/#/builders/138/builds/13256
https://lab.llvm.org/buildbot/#/builders/129/builds/20914
https://lab.llvm.org/buildbot/#/builders/116/builds/12940
https://lab.llvm.org/buildbot/#/builders/205/builds/9610
https://lab.llvm.org/buildbot/#/builders/203/builds/10819
https://lab.llvm.org/buildbot/#/builders/53/builds/16010
https://lab.llvm.org/buildbot/#/builders/153/builds/32061
https://lab.llvm.org/buildbot/#/builders/117/builds/9670

(build-log excerpts elided)
@j2kun I've reverted your PR. In my local build, I'm getting an error (error output elided).
I'll push a revert.
Someone beat me to it in 6d9ce67.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/28293
Revert "[mlir][bufferization] implement BufferizableOpInterface for concat op (#140171)": This reverts commit 6d9ce67. Multiple buildbot failures have been reported: llvm/llvm-project#140171
"[mlir][bufferization] implement BufferizableOpInterface for concat op" (llvm#140171): This restores the previously reverted commit with forward fixes.