#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
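For intuition (an illustrative snippet, not part of the diff): a cast that this predicate accepts is one that only adds static shape information, e.g.

```mlir
// %s : tensor<?x?xf32>; the cast refines one dimension, so the predicate
// returns true and a producer of %s may absorb the cast.
%t = tensor.cast %s : tensor<?x?xf32> to tensor<4x?xf32>
```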
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
    return success();
  }
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///   tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    LogicalResult matched = failure();
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        matched = success();
+
+        // Use refined operand type and create cast from original operand.
+        auto castOp =
+            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                    concatOp.getOperand(operandIdx));
+        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+          concatOp->setOperand(operandIdx, castOp->getResult(0));
+        });
+      }
+    }
+
+    return matched;
+  }
+};
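An aside on what this buys (illustrative values, not from the patch): the refined operand shape is built from the *inferred result* shape, so static sizes flow between operands. Note also that `matched` starts as `failure()` and only flips to `success()` once some operand is actually refined, so the driver is told truthfully whether the IR changed.

```mlir
// Before: the static sizes are split across the two operands.
%r = tensor.concat dim(0) %a, %b
    : (tensor<?x12xi32>, tensor<3x?xi32>) -> tensor<?x?xi32>

// After InferConcatOperandTypes: %b picks up the 12 known from %a; %a is
// already as static as the inferred shape allows, so it is left untouched.
%b_cast = tensor.cast %b : tensor<3x?xi32> to tensor<3x12xi32>
%r = tensor.concat dim(0) %a, %b_cast
    : (tensor<?x12xi32>, tensor<3x12xi32>) -> tensor<?x?xi32>
```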
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///   tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///   -> tensor<?x12xi32>
+///   %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result
+    // type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
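Taken together, the two patterns compose: refining the operands first can expose a more static inferred result type. An illustrative sequence (hypothetical values, not from the patch):

```mlir
// Starting point: all static information lives in %a.
%r = tensor.concat dim(0) %a, %b
    : (tensor<4x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>

// InferConcatOperandTypes refines %b; InferConcatResultType then tightens the
// concat's result type and casts back, so uses of %r keep their original type.
%b_cast = tensor.cast %b : tensor<?x?xi32> to tensor<?x12xi32>
%c = tensor.concat dim(0) %a, %b_cast
    : (tensor<4x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
%r = tensor.cast %c : tensor<?x12xi32> to tensor<?x?xi32>
```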
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
}
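To see the new patterns fire, one could run the canonicalizer over a small input. A hypothetical lit-style test (the function name is made up; `mlir-opt --canonicalize` and `FileCheck` are the standard tools):

```mlir
// RUN: mlir-opt %s --canonicalize | FileCheck %s

// Expect the second operand and the concat result to be refined to a static
// trailing dimension, with tensor.cast ops bridging the original types.
func.func @infer_concat_types(%a: tensor<?x12xi32>, %b: tensor<?x?xi32>)
    -> tensor<?x?xi32> {
  %0 = tensor.concat dim(0) %a, %b
      : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}
```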

//===----------------------------------------------------------------------===//