Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5e7fc62

Browse files
committed
Restore [mlir][bufferization] implement BufferizableOpInterface for concat op (llvm#140171)
This restores the previously reverted commit, with forward fixes.
1 parent 7b8bc1b commit 5e7fc62
Copy full SHA for 5e7fc62

File tree

3 files changed

+222
-2
lines changed
Filter options

3 files changed

+222
-2
lines changed

‎mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+2 −2 Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
4949
>();
5050
addInterfaces<TensorInlinerInterface>();
5151
declarePromisedInterfaces<
52-
bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
53-
EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
52+
bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
53+
DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
5454
GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
5555
ReshapeOp, SplatOp>();
5656
declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,

‎mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+129 Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
10481048
}
10491049
};
10501050

1051+
/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
1052+
/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
1053+
/// on subviews instead of memref.store.
1054+
struct ConcatOpInterface
1055+
: public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
1056+
tensor::ConcatOp> {
1057+
1058+
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1059+
1060+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1061+
const AnalysisState &state) const {
1062+
return false;
1063+
}
1064+
1065+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1066+
const AnalysisState &state) const {
1067+
return true;
1068+
}
1069+
1070+
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1071+
const AnalysisState &state) const {
1072+
return {};
1073+
}
1074+
1075+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1076+
const BufferizationOptions &options) const {
1077+
OpBuilder::InsertionGuard g(rewriter);
1078+
auto concatOp = cast<tensor::ConcatOp>(op);
1079+
1080+
// Allocate memory.
1081+
Location loc = op->getLoc();
1082+
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1083+
rewriter, loc, concatOp.getResult(), options,
1084+
/*copy=*/false);
1085+
if (failed(tensorAlloc))
1086+
return failure();
1087+
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1088+
1089+
// TODO: Implement memory space for this op.
1090+
if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1091+
return op->emitError("memory space not implemented yet");
1092+
1093+
MemRefLayoutAttrInterface layout;
1094+
MemRefType memrefType =
1095+
MemRefType::get(concatOp.getResultType().getShape(),
1096+
concatOp.getResultType().getElementType(), layout);
1097+
Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
1098+
op->getLoc(), memrefType, *tensorAlloc);
1099+
1100+
// Extract the dimension for the concat op
1101+
uint64_t concatDim = concatOp.getDim();
1102+
bool dynamicConcatDim = false;
1103+
1104+
SmallVector<OpFoldResult> offsets(tensorType.getRank(),
1105+
rewriter.getIndexAttr(0));
1106+
SmallVector<OpFoldResult> strides(tensorType.getRank(),
1107+
rewriter.getIndexAttr(1));
1108+
SmallVector<OpFoldResult> sizes;
1109+
1110+
for (const auto &[dimIdx, dimSize] :
1111+
llvm::enumerate(tensorType.getShape())) {
1112+
if (dimSize == ShapedType::kDynamic) {
1113+
auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
1114+
sizes.push_back(dimOp.getResult());
1115+
if (dimIdx == concatDim)
1116+
dynamicConcatDim = true;
1117+
} else {
1118+
sizes.push_back(rewriter.getIndexAttr(dimSize));
1119+
}
1120+
}
1121+
1122+
int64_t concatDimOffset = 0;
1123+
std::optional<Value> dynamicOffset;
1124+
std::optional<Value> dynamicSize;
1125+
if (dynamicConcatDim) {
1126+
// One or more operands have dynamic size, so we must accumulate the
1127+
// offset with arith ops.
1128+
dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1129+
}
1130+
1131+
for (auto operand : concatOp.getInputs()) {
1132+
// Get the buffer for the operand.
1133+
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
1134+
if (failed(srcBuffer))
1135+
return failure();
1136+
1137+
// Each operand may have a different size along the concat dimension,
1138+
// so the offset on that axis must accumulate through the loop, and the
1139+
// size must change to the size of the current operand.
1140+
auto operandTensorType = cast<RankedTensorType>(operand.getType());
1141+
int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
1142+
1143+
if (dynamicConcatDim) {
1144+
offsets[concatDim] = dynamicOffset.value();
1145+
dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
1146+
.getResult();
1147+
sizes[concatDim] = dynamicSize.value();
1148+
} else {
1149+
sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
1150+
offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
1151+
}
1152+
1153+
// Create a subview of the destination buffer.
1154+
auto dstMemrefType = cast<MemRefType>(memrefType);
1155+
MemRefType subviewMemRefType =
1156+
memref::SubViewOp::inferRankReducedResultType(
1157+
operandTensorType.getShape(), dstMemrefType, offsets, sizes,
1158+
strides);
1159+
Value subview = rewriter.create<memref::SubViewOp>(
1160+
loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
1161+
1162+
// Copy the source buffer into the destination subview.
1163+
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
1164+
return failure();
1165+
1166+
if (dynamicConcatDim) {
1167+
dynamicOffset = rewriter.create<arith::AddIOp>(
1168+
loc, dynamicOffset.value(), dynamicSize.value());
1169+
} else {
1170+
concatDimOffset += operandConcatDimSize;
1171+
}
1172+
}
1173+
1174+
replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
1175+
return success();
1176+
}
1177+
};
1178+
10511179
} // namespace
10521180
} // namespace tensor
10531181
} // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
10571185
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
10581186
CastOp::attachInterface<CastOpInterface>(*ctx);
10591187
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1188+
ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
10601189
DimOp::attachInterface<DimOpInterface>(*ctx);
10611190
EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
10621191
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);

‎mlir/test/Dialect/Tensor/bufferize.mlir

Copy file name to clipboardExpand all lines: mlir/test/Dialect/Tensor/bufferize.mlir
+91 Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
615615

616616
// -----
617617

618+
// CHECK-LABEL: func @tensor.concat(
619+
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
620+
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
621+
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
622+
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
623+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
624+
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
625+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
626+
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
627+
// CHECK: return %[[RET]]
628+
// CHECK: }
629+
func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
630+
%t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
631+
return %t : tensor<16xf32>
632+
}
633+
634+
// -----
635+
636+
// CHECK-LABEL: func @tensor.concat_different_shapes(
637+
// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
638+
// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
639+
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
640+
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
641+
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
642+
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
643+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
644+
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
645+
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
646+
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
647+
// CHECK: return %[[RET]]
648+
// CHECK: }
649+
func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
650+
%t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
651+
return %t : tensor<8x9xf32>
652+
}
653+
654+
// -----
655+
656+
// CHECK-LABEL: func @tensor.concat_dynamic(
657+
// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
658+
// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
659+
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
660+
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
661+
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
662+
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
663+
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
664+
// CHECK: %[[ALLOC:.*]] = memref.alloc
665+
// CHECK-SAME: memref<8x?xf32>
666+
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
667+
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
668+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
669+
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
670+
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
671+
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
672+
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
673+
// CHECK: return %[[RET]]
674+
// CHECK: }
675+
func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
676+
%t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
677+
return %t : tensor<8x?xf32>
678+
}
679+
680+
// -----
681+
682+
// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
683+
// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
684+
// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
685+
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
686+
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
687+
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
688+
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
689+
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
690+
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
691+
// CHECK: %[[ALLOC:.*]] = memref.alloc
692+
// CHECK-SAME: memref<?x?xf32>
693+
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
694+
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
695+
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
696+
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
697+
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
698+
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
699+
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
700+
// CHECK: return %[[RET]]
701+
// CHECK: }
702+
func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
703+
%t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
704+
return %t : tensor<?x?xf32>
705+
}
706+
707+
// -----
708+
618709
// CHECK-LABEL: func @tensor.splat_dynamic(
619710
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
620711
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.