diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b9cef003fa365..c7169c5297d9a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -109,17 +109,110 @@ struct LinearizeVectorizable final } }; -/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works -/// on a linearized vector. -/// Following, +template +static bool stridesAllOne(TOp op) { + static_assert( + std::is_same_v || + std::is_same_v, + "expected vector.extract_strided_slice or vector.insert_strided_slice"); + ArrayAttr strides = op.getStrides(); + return llvm::all_of( + strides, [](auto stride) { return isConstantIntValue(stride, 1); }); +} + +/// Convert an array of attributes into a vector of integers, if possible. +static FailureOr> intsFromArrayAttr(ArrayAttr attrs) { + if (!attrs) + return failure(); + SmallVector ints; + ints.reserve(attrs.size()); + for (auto attr : attrs) { + if (auto intAttr = dyn_cast(attr)) { + ints.push_back(intAttr.getInt()); + } else { + return failure(); + } + } + return ints; +} + +/// Consider inserting a vector of shape `small` into a vector of shape `large`, +/// at position `offsets`: this function enumeratates all the indices in `large` +/// that are written to. The enumeration is with row-major ordering. +/// +/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 +/// positions written to are (1,3) and (1,4), which have linearized indices 8 +/// and 9. So [8,9] is returned. +/// +/// The length of the returned vector is equal to the number of elements in +/// the shape `small` (i.e. the product of dimensions of `small`). +SmallVector static getStridedSliceInsertionIndices( + ArrayRef small, ArrayRef large, + ArrayRef offsets) { + + // Example of alignment between, `large`, `small` and `offsets`: + // large = 4, 5, 6, 7, 8 + // small = 1, 6, 7, 8 + // offsets = 2, 3, 0 + // + // `offsets` has implicit trailing 0s, `small` has implicit leading 1s. + assert((large.size() >= small.size()) && + "rank of 'large' cannot be lower than rank of 'small'"); + assert((large.size() >= offsets.size()) && + "rank of 'large' cannot be lower than the number of offsets"); + unsigned delta = large.size() - small.size(); + unsigned nOffsets = offsets.size(); + auto getSmall = [&](int64_t i) -> int64_t { + return i >= delta ? small[i - delta] : 1; + }; + auto getOffset = [&](int64_t i) -> int64_t { + return i < nOffsets ? offsets[i] : 0; + }; + + // Using 2 vectors of indices, at each iteration populate the updated set of + // indices based on the old set of indices, and the size of the small vector + // in the current iteration. + SmallVector indices{0}; + int64_t stride = 1; + for (int i = large.size() - 1; i >= 0; --i) { + int64_t currentSize = indices.size(); + int64_t smallSize = getSmall(i); + int64_t nextSize = currentSize * smallSize; + SmallVector nextIndices(nextSize); + int64_t *base = nextIndices.begin(); + int64_t offset = getOffset(i) * stride; + for (int j = 0; j < smallSize; ++j) { + for (int k = 0; k < currentSize; ++k) { + base[k] = indices[k] + offset; + } + offset += stride; + base += currentSize; + } + stride *= large[i]; + indices = std::move(nextIndices); + } + return indices; +} + +/// This pattern converts a vector.extract_strided_slice operation into a +/// vector.shuffle operation that has a rank-1 (linearized) operand and result. +/// +/// For example, the following: +/// +/// ``` /// vector.extract_strided_slice %source /// { offsets = [..], strides = [..], sizes = [..] } +/// ``` +/// /// is converted to : +/// ``` /// %source_1d = vector.shape_cast %source -/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] -/// %out_nd = vector.shape_cast %out_1d -/// `shuffle_indices_1d` is computed using the offsets and sizes of the -/// extraction. +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d +/// ``` +/// +/// `shuffle_indices_1d` is computed using the offsets and sizes of the original +/// vector.extract_strided_slice operation. struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -129,88 +222,116 @@ struct LinearizeVectorExtractStridedSlice final : OpConversionPattern(typeConverter, context, benefit) {} LogicalResult - matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, + matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType dstType = - getTypeConverter()->convertType(extractOp.getType()); - assert(dstType && "vector type destination expected."); - if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) - return rewriter.notifyMatchFailure(extractOp, - "scalable vectors are not supported."); - ArrayAttr offsets = extractOp.getOffsets(); - ArrayAttr sizes = extractOp.getSizes(); - ArrayAttr strides = extractOp.getStrides(); - if (!isConstantIntValue(strides[0], 1)) + VectorType flatOutputType = getTypeConverter()->convertType( + extractStridedSliceOp.getType()); + assert(flatOutputType && "vector type expected"); + + // Expect a legalization failure if the strides are not all 1 (if ever the + // verifier for extract_strided_slice allows non-1 strides). + if (!stridesAllOne(extractStridedSliceOp)) { return rewriter.notifyMatchFailure( - extractOp, "Strided slice with stride != 1 is not supported."); - Value srcVector = adaptor.getVector(); - // If kD offsets are specified for nD source vector (n > k), the granularity - // of the extraction is greater than 1. In this case last (n-k) dimensions - // form the extraction granularity. - // Example : - // vector.extract_strided_slice %src { - // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : - // vector<4x8x8xf32> to vector<2x2x8xf32> - // Here, extraction granularity is 8. - int64_t extractGranularitySize = 1; - int64_t nD = extractOp.getSourceVectorType().getRank(); - int64_t kD = (int64_t)offsets.size(); - int64_t k = kD; - while (k < nD) { - extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; - ++k; + extractStridedSliceOp, + "extract_strided_slice with strides != 1 not supported"); } - // Get total number of extracted slices. - int64_t nExtractedSlices = 1; - for (Attribute size : sizes) { - nExtractedSlices *= cast(size).getInt(); + + FailureOr> offsets = + intsFromArrayAttr(extractStridedSliceOp.getOffsets()); + if (failed(offsets)) { + return rewriter.notifyMatchFailure(extractStridedSliceOp, + "failed to get integer offsets"); } - // Compute the strides of the source vector considering first k dimensions. - llvm::SmallVector sourceStrides(kD, extractGranularitySize); - for (int i = kD - 2; i >= 0; --i) { - sourceStrides[i] = sourceStrides[i + 1] * - extractOp.getSourceVectorType().getShape()[i + 1]; + + ArrayRef inputShape = + extractStridedSliceOp.getSourceVectorType().getShape(); + + ArrayRef outputShape = extractStridedSliceOp.getType().getShape(); + + SmallVector indices = getStridedSliceInsertionIndices( + outputShape, inputShape, offsets.value()); + + Value srcVector = adaptor.getVector(); + rewriter.replaceOpWithNewOp( + extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); + return success(); + } +}; + +/// This pattern converts a vector.insert_strided_slice operation into a +/// vector.shuffle operation that has rank-1 (linearized) operands and result. +/// +/// For example, the following: +/// ``` +/// %0 = vector.insert_strided_slice %to_store, %into +/// {offsets = [1, 0, 0, 0], strides = [1, 1]} +/// : vector<2x2xi8> into vector<2x1x3x2xi8> +/// ``` +/// +/// is converted to +/// ``` +/// %to_store_1d +/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> +/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> +/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] +/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> +/// ``` +/// +/// where shuffle_indices_1d in this case is +/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. +/// ^^^^^^^^^^^^^^ +/// to_store_1d +/// +struct LinearizeVectorInsertStridedSlice final + : public mlir::OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Expect a legalization failure if the strides are not all 1 (if ever the + // verifier for insert_strided_slice allows non-1 strides). + if (!stridesAllOne(insertStridedSliceOp)) { + return rewriter.notifyMatchFailure( + insertStridedSliceOp, + "insert_strided_slice with strides != 1 not supported"); } - // Final shuffle indices has nExtractedSlices * extractGranularitySize - // elements. - llvm::SmallVector indices(nExtractedSlices * - extractGranularitySize); - // Compute the strides of the extracted kD vector. - llvm::SmallVector extractedStrides(kD, 1); - // Compute extractedStrides. - for (int i = kD - 2; i >= 0; --i) { - extractedStrides[i] = - extractedStrides[i + 1] * cast(sizes[i + 1]).getInt(); + + VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); + ArrayRef inputShape = inputType.getShape(); + + VectorType outputType = insertStridedSliceOp.getType(); + ArrayRef outputShape = outputType.getShape(); + int64_t nOutputElements = outputType.getNumElements(); + + FailureOr> offsets = + intsFromArrayAttr(insertStridedSliceOp.getOffsets()); + if (failed(offsets)) { + return rewriter.notifyMatchFailure(insertStridedSliceOp, + "failed to get integer offsets"); } - // Iterate over all extracted slices from 0 to nExtractedSlices - 1 - // and compute the multi-dimensional index and the corresponding linearized - // index within the source vector. - for (int64_t i = 0; i < nExtractedSlices; ++i) { - int64_t index = i; - // Compute the corresponding multi-dimensional index. - llvm::SmallVector multiDimIndex(kD, 0); - for (int64_t j = 0; j < kD; ++j) { - multiDimIndex[j] = (index / extractedStrides[j]); - index -= multiDimIndex[j] * extractedStrides[j]; - } - // Compute the corresponding linearized index in the source vector - // i.e. shift the multiDimIndex by the offsets. - int64_t linearizedIndex = 0; - for (int64_t j = 0; j < kD; ++j) { - linearizedIndex += - (cast(offsets[j]).getInt() + multiDimIndex[j]) * - sourceStrides[j]; - } - // Fill the indices array form linearizedIndex to linearizedIndex + - // extractGranularitySize. - for (int64_t j = 0; j < extractGranularitySize; ++j) { - indices[i * extractGranularitySize + j] = linearizedIndex + j; - } + SmallVector sliceIndices = getStridedSliceInsertionIndices( + inputShape, outputShape, offsets.value()); + + SmallVector indices(nOutputElements); + std::iota(indices.begin(), indices.end(), 0); + for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { + indices[sliceIndex] = index + nOutputElements; } - // Perform a shuffle to extract the kD vector. - rewriter.replaceOpWithNewOp( - extractOp, dstType, srcVector, srcVector, indices); + + Value flatToStore = adaptor.getValueToStore(); + Value flatDest = adaptor.getDest(); + rewriter.replaceOpWithNewOp(insertStridedSliceOp, + flatDest.getType(), flatDest, + flatToStore, indices); return success(); } }; @@ -296,7 +417,7 @@ struct LinearizeVectorExtract final // Skip if result is not a vector type if (!isa(extractOp.getType())) return rewriter.notifyMatchFailure(extractOp, - "scalar extract is not supported."); + "scalar extract not supported"); Type dstTy = getTypeConverter()->convertType(extractOp.getType()); assert(dstTy && "expected 1-D vector type"); @@ -453,8 +574,8 @@ struct LinearizeVectorSplat final static bool isNotLinearizableBecauseScalable(Operation *op) { bool unsupported = - isa( - op); + isa(op); if (!unsupported) return false; @@ -539,6 +660,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + LinearizeVectorInsert, LinearizeVectorExtractStridedSlice, + LinearizeVectorInsertStridedSlice>(typeConverter, + patterns.getContext()); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 01ad1ac48b012..3cdbef8db604b 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s // CHECK-LABEL: test_linearize // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) @@ -131,9 +131,9 @@ func.func @test_0d_vector() -> vector { // ----- -// CHECK-LABEL: test_extract_strided_slice_1 +// CHECK-LABEL: test_extract_strided_slice_2D // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> { -func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> { +func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> { // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32> // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] @@ -147,13 +147,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf // ----- -// CHECK-LABEL: func.func @test_extract_strided_slice_1_scalable( +// CHECK-LABEL: func.func @test_extract_strided_slice_2D_scalable( // CHECK-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> { -func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { +func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> { // CHECK-NOT: vector.shuffle // CHECK-NOT: vector.shape_cast - // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32> + // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32> // CHECK: return %[[RES]] : vector<2x[8]xf32> @@ -162,9 +162,9 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve // ----- -// CHECK-LABEL: test_extract_strided_slice_2 +// CHECK-LABEL: test_extract_strided_slice_3D // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> { -func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> { +func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> { // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32> // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] @@ -178,6 +178,76 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4 // ----- +// Test of insert_strided_slice -> shuffle. +// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements. +// CHECK-LABEL: insert_strided_slice_2D_into_4D +func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> { + +// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<4xi8> +// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<12xi8> +// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]] +// CHECK-SAME: [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8> + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8> + +// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8> +// CHECK: return %[[RES]] : vector<2x1x3x2xi8> + return %0 : vector<2x1x3x2xi8> +} + +// ----- + +// Test of insert_strided_slice -> shuffle. +// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15]], [[16, 17]]] +// ^ ^ +// | | +// where the 2 elements are inserted into the 3x3x2 vector +// CHECK-LABEL: insert_strided_slice_3D +func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg1 : vector<3x3x2xi8>) -> vector<3x3x2xi8> { + +// CHECK-DAG: %[[ARG0:.*]] = vector.shape_cast {{.*}} to vector<2xi8> +// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8> +// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]] +// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8> + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8> + +// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8> +// CHECK: return %[[RES]] : vector<3x3x2xi8> + return %0 : vector<3x3x2xi8> +} + +// ----- + +// CHECK-LABEL: insert_strided_slice_2D_higher_offsets +func.func @insert_strided_slice_2D_higher_offsets(%arg0 : vector<2x1xi8>, %arg1 : vector<2x2xi8>, %arg2 : vector<5x2xi8>) -> vector<5x2xi8> { + + // CHECK: [0, 1, 2, 3, 10, 11, 12, 13, 8, 9] + // ^^^ ^^^ ^^^ ^^^ + // insertion indices + %0 = vector.insert_strided_slice %arg1, %arg2 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x2xi8> into vector<5x2xi8> + + // CHECK: [0, 1, 2, 3, 10, 5, 11, 7, 8, 9] + // ^^^ ^^^ + %1 = vector.insert_strided_slice %arg0, %0 {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8> + + // CHECK: [0, 1, 2, 3, 4, 5, 6, 10, 8, 11] + // ^^^ ^^^ + %2 = vector.insert_strided_slice %arg0, %1 {offsets = [3, 1], sizes = [2, 1], strides = [1, 1]} : vector<2x1xi8> into vector<5x2xi8> + + return %2 : vector<5x2xi8> +} + +// ----- + +// CHECK-LABEL: negative_insert_strided_slice_scalable +// CHECK-NOT: vector.shuffle +// CHECK: return +func.func @negative_insert_strided_slice_scalable(%arg0 : vector<1x[2]xi8>, %arg1 : vector<2x[2]xi8>) -> vector<2x[2]xi8> { + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0], strides = [1,1]} : vector<1x[2]xi8> into vector<2x[2]xi8> + return %0 : vector<2x[2]xi8> +} + +// ----- + // CHECK-LABEL: test_vector_shuffle // CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> { func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> { @@ -345,3 +415,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { %0 = vector.splat %arg0 : vector<4x[2]xi32> return %0 : vector<4x[2]xi32> } +