Commit da944e0

[mlir][tensor] Add shape inference support for tensor.concat op. (#140168)

## Description

`tensor.concat` requires operands and the result to match on all dimensions
except the concatenation dimension. If one operand is already static in those
dimensions, the other operands and result type may safely be refined to that
same static shape. This PR adds canonicalization patterns to refine
`tensor.concat` types and propagate static shapes to other canonicalization
patterns through casts.

## Example

```mlir
%2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
```

becomes:

```mlir
%cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
%2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
```

Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
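Because these rewrites are registered as canonicalization patterns, they run under any greedy rewrite driver, e.g. the `-canonicalize` pass. Below is a minimal sketch (not part of this commit) of applying them programmatically; `refineConcatTypes` is a hypothetical helper name:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect tensor.concat's canonicalization patterns (including the two added
// in this commit) and run them to a fixed point over `module`.
static mlir::LogicalResult refineConcatTypes(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::tensor::ConcatOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```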
1 parent: ba38e56

2 files changed: 131 additions, 3 deletions

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+105 −3)
```diff
@@ -33,6 +33,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
```
```diff
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result are ranked tensors with the same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
```
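A quick illustration of condition 2 in isolation, as a hedged sketch assuming the usual MLIR headers (`demo` is a hypothetical function): refining `tensor<?x12xi32>` to `tensor<4x12xi32>` preserves static information, while the reverse direction loses it.

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// preservesStaticInformation(source, target) asks whether `target` keeps
// every static dimension of `source`.
void demo(mlir::MLIRContext &ctx) {
  auto i32 = mlir::IntegerType::get(&ctx, 32);
  int64_t dyn = mlir::ShapedType::kDynamic;
  auto src = mlir::RankedTensorType::get({dyn, 12}, i32); // tensor<?x12xi32>
  auto dst = mlir::RankedTensorType::get({4, 12}, i32);   // tensor<4x12xi32>
  bool refines = mlir::tensor::preservesStaticInformation(src, dst); // true
  bool loses = mlir::tensor::preservesStaticInformation(dst, src);   // false
  (void)refines;
  (void)loses;
}
```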
```diff
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///         tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    LogicalResult matched = failure();
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        matched = success();
+
+        // Use refined operand type and create cast from original operand.
+        auto castOp =
+            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                    concatOp.getOperand(operandIdx));
+        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+          concatOp->setOperand(operandIdx, castOp->getResult(0));
+        });
+      }
+    }
+
+    return matched;
+  }
+};
```
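The refinement above is essentially a dimension-wise meet of the operand shapes: outside the concat dimension, a static size from any operand wins over a dynamic one. A standalone sketch of that idea (hypothetical `meetShapes` helper, with `kDynamic` standing in for `mlir::ShapedType::kDynamic`; the real logic lives in `ConcatOp::inferResultType`):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for mlir::ShapedType::kDynamic

// Merge two operand shapes of a concat along `dim`: sizes add on the concat
// dimension (dynamic if either side is dynamic) and must agree elsewhere,
// where a static size refines a dynamic one.
std::vector<int64_t> meetShapes(const std::vector<int64_t> &a,
                                const std::vector<int64_t> &b, size_t dim) {
  assert(a.size() == b.size() && dim < a.size());
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (i == dim)
      out[i] = (a[i] == kDynamic || b[i] == kDynamic) ? kDynamic : a[i] + b[i];
    else
      out[i] = (a[i] == kDynamic) ? b[i] : a[i];
  }
  return out;
}

// e.g. meetShapes({kDynamic, 12}, {kDynamic, kDynamic}, 0) == {kDynamic, 12}
```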
```diff
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///         tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %concat = tensor.concat dim(0) %0, %1 :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+///   %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result
+    // type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
```
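Note the direction of both guards: each pattern fires only when the inferred type strictly adds static information over the existing one, i.e. when `preservesStaticInformation` returns false. Re-running either pattern on its own output therefore fails to match, which is what lets the greedy canonicalization driver reach a fixed point instead of looping on the casts it introduces.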
```diff
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
```

mlir/test/Dialect/Tensor/canonicalize.mlir (+26 −0)
```diff
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 
 // -----
 
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  return %0 : tensor<?x12xi32>
+  // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+  // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
```
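These cases run under the test file's existing RUN line; assuming the usual setup for this file, that means `mlir-opt` with the `-canonicalize` pass under `-split-input-file`, with the output verified by `FileCheck` and `// -----` separating the independent cases.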
