diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 6fd992afbf043..2f083b55d4904 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [ def TileLoadOp : ArmSME_Op<"tile_load", [ ArmSMETileOpInterface, AttrSizedOperandSegments, + AllElementTypesMatch<["result", "base"]>, OptionalTypesMatchWith< "padding type matches element type of result", "result", "padding", @@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ ``` }]; let arguments = (ins - Arg:$base, + Arg, "the reference to load from", [MemRead]>:$base, Variadic:$indices, Optional:$padding, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout @@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ def TileStoreOp : ArmSME_Op<"tile_store", [ ArmSMETileOpInterface, AttrSizedOperandSegments, + AllElementTypesMatch<["valueToStore", "base"]>, HasMatchingMaskTypeConstraint<"valueToStore", "mask">, ]> { let summary = "Tile store operation"; @@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ ``` }]; let arguments = (ins SMETile:$valueToStore, - Arg:$base, + Arg, "the reference to store to", [MemWrite]>:$base, Variadic:$indices, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); @@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ ArmSMETileOpInterface, + AllElementTypesMatch<["tile", "base"]>, AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask"> ]> { let summary = "Tile slice load and update operation"; @@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ ArmSMETileOpInterface, + AllElementTypesMatch<["tile", "base"]>, TileSliceMaskConstraint<"tile", "mask"> ]> { let summary = "Tile slice store operation"; diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 630414030d98b..458628c29c6ac 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -33,20 +33,15 @@ SmallVector getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, Value tileSliceNumElts, Location loc, PatternRewriter &rewriter) { - assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); + assert(rank == 2 && "memref has unexpected rank!"); SmallVector outIndices; auto tileSliceOffset = tileSliceIndex; - if (rank == 1) - tileSliceOffset = - rewriter.create(loc, tileSliceOffset, tileSliceNumElts); auto baseIndexPlusTileSliceOffset = rewriter.create(loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); - - if (rank == 2) - outIndices.push_back(indices[1]); + outIndices.push_back(indices[1]); return outIndices; } @@ -60,6 +55,10 @@ FailureOr createLoadStoreForOverTileSlices( makeLoopBody) { PatternRewriter::InsertionGuard guard(rewriter); + // TODO: This case should be captured and rejected by a verifier. + if (memrefIndices.size() != 2) + return rewriter.notifyMatchFailure(loc, "invalid number of indices"); + auto minTileSlices = rewriter.create( loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 700b2412ff7a7..8c5a098a0c785 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> { // ----- -func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> { +func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> { %c0 = arith.constant 0 : index // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}} %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8> @@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8 // ----- -func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> { +func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> { %c0 = arith.constant 0 : index // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}} %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32> return %0 : vector<[4]x[4]xf32> } +// ----- + +func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}} + %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} + //===----------------------------------------------------------------------===// // arm_sme.extract_tile_slice //===----------------------------------------------------------------------===// // ----- -func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> { +func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> { + // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}} + %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32> + return %0 : vector<[2]xf32> +} + +// ----- + +func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> { // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}} - %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32> - return %0 : vector<[2]xf64> + %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32> + return %0 : vector<[4]xf64> } //===----------------------------------------------------------------------===// @@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64 return } +// ----- + +func.func @arm_sme_tile_load__bad_memref_rank(%src : memref, %pad : f64) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref'}} + %tile = arm_sme.tile_load %src[%c0], %pad, : memref, vector<[2]x[2]xf64> + return +} + +// ----- + +func.func @arm_sme_tile_load__bad_element_type(%src : memref) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{failed to verify that all of {result, base} have same element type}} + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.load_tile_slice //===----------------------------------------------------------------------===// @@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref, %mask : return } +// ----- + +func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}} + %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref, vector<[16]xi1>, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.tile_store //===----------------------------------------------------------------------===// @@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask return } +// ----- + +func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref'}} + arm_sme.tile_store %tile, %dest[%c0] : memref, vector<[16]x[16]xi8> + return +} + +// ----- + +func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that all of {valueToStore, base} have same element type}} + arm_sme.tile_store %tile, %dest[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.store_tile_slice //===----------------------------------------------------------------------===// @@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>, return } +// ----- + +func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref) -> () { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}} + arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref, vector<[16]xi1>, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.outerproduct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir index b7144be08a853..8d4b4a07994e2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir @@ -17,7 +17,7 @@ func.func @entry() { %za_s_size = arith.muli %svl_s, %svl_s : index // Allocate memory. - %mem1 = memref.alloca(%za_s_size) : memref + %mem1 = memref.alloca(%svl_s, %svl_s) : memref // Fill each "row" of "mem1" with row number. // @@ -29,15 +29,15 @@ func.func @entry() { // 3, 3, 3, 3 // %init_0 = arith.constant 0 : i32 - scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) { + scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) { %splat_val = vector.broadcast %val : i32 to vector<[4]xi32> - vector.store %splat_val, %mem1[%i] : memref, vector<[4]xi32> + vector.store %splat_val, %mem1[%i, %c0] : memref, vector<[4]xi32> %val_next = arith.addi %val, %c1_i32 : i32 scf.yield %val_next : i32 } // Load tile from "mem1" vertically. - %0 = arm_sme.tile_load %mem1[%c0, %c0] layout : memref, vector<[4]x[4]xi32> + %0 = arm_sme.tile_load %mem1[%c0, %c0] layout : memref, vector<[4]x[4]xi32> // 1. ORIGINAL HORIZONTAL LAYOUT // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least @@ -50,8 +50,8 @@ func.func @entry() { // CHECK-NEXT: ( 3, 3, 3, 3 // CHECK: TILE END vector.print str "TILE BEGIN\n" - scf.for %i = %c0 to %za_s_size step %svl_s { - %tileslice = vector.load %mem1[%i] : memref, vector<[4]xi32> + scf.for %i = %c0 to %svl_s step %c1 { + %tileslice = vector.load %mem1[%i, %c0] : memref, vector<[4]xi32> vector.print %tileslice : vector<[4]xi32> } vector.print str "TILE END\n"