[mlir][gpu] Add GPU subgroup MMA extract and insert operations #139048

Merged (4 commits, May 28, 2025)
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (89 additions, 0 deletions)

@@ -1919,6 +1919,95 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}

def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "res",
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{

let summary = "Extract a value from GPU warp by invocation and indices";

let description = [{
The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from a
`!gpu.mma_matrix` that is stored at the subgroup level.

This operation takes a `!gpu.mma_matrix` as its first operand, the source
matrix distributed across the subgroup. The op returns the scalar value stored
by the current invocation in the subgroup.

Since `matrix` is packed into the threads within a subgroup, `indices` are the
indices into the values stored by each thread. That is, an index of 0 (or
[0, 0]) does not necessarily refer to the first element of the matrix, but to
the first element held by a particular thread.

The mapping of matrix elements to threads is not defined by this operation and
may not be defined by some lowerings (such as the lowering to SPIR-V). However,
if the size of the subgroup is S, then extracting at every index in
`[0, (M * N) / S)` yields the entire matrix across the subgroup.

Example:

```mlir
%c0 = arith.constant 0 : index
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
```
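
As a sketch of the contract above, assume a subgroup size of 32, so each
invocation holds (16 * 16) / 32 = 8 thread-local elements; extracting at
indices 0 through 7 then recovers the whole matrix across the subgroup. Note
that the SPIR-V lowering in this patch accepts only constant indices, so a
loop over the indices would have to be unrolled:

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%v0 = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
%v1 = gpu.subgroup_mma_extract_thread_local %m[%c1] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
// ... and likewise for indices 2 through 7.
```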
}];

let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);

let results = (outs AnyIntegerOrFloat:$res);

let assemblyFormat = [{
$matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
}];
}

def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "value",
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{

let summary = "Insert a value into GPU warp by invocation and indices";

let description = [{
The `gpu.subgroup_mma_insert_thread_local` operation inserts a value into a
`!gpu.mma_matrix` that is stored at the subgroup level.

This operation takes a scalar value as its first operand and a
`!gpu.mma_matrix` as its second operand. The op inserts the scalar value into
the matrix at the given thread-local indices.

Since `matrix` is packed into the threads within a subgroup, `indices` are the
indices into the values stored by each thread. That is, an index of 0 (or
[0, 0]) does not necessarily refer to the first element of the matrix, but to
the first element held by a particular thread.

The mapping of matrix elements to threads is not defined by this operation and
may not be defined by some lowerings (such as the lowering to SPIR-V). However,
if the size of the subgroup is S, then inserting at every index in
`[0, (M * N) / S)` updates the entire matrix across the subgroup.

The op returns a new `!gpu.mma_matrix` with the value inserted.

Example:

```mlir
%c0 = arith.constant 0 : index
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
-> !gpu.mma_matrix<16x16xf16, "COp">
```
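
Combined with `gpu.subgroup_mma_extract_thread_local`, this allows
element-wise updates of a matrix without a round trip through memory. A
minimal sketch, assuming a hypothetical scaling factor `%alpha : f16`:

```mlir
%c0 = arith.constant 0 : index
%v = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
%scaled = arith.mulf %v, %alpha : f16
%s0 = gpu.subgroup_mma_insert_thread_local %scaled, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
```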
}];

let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
Variadic<Index>:$indices);

let results = (outs GPU_MMAMatrix:$res);

let assemblyFormat = [{
$value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
}];
}

def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (63 additions, 0 deletions)

@@ -111,6 +111,68 @@
}
};

/// Converts the GPU MMA ExtractOp to SPIR-V CompositeExtract on KHR/NV
/// cooperative matrix types.
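/// As a sketch (mirroring the conversion test in this patch), this rewrites
///
///   %val = gpu.subgroup_mma_extract_thread_local %m[%c0]
///       : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
///
/// (with %c0 a constant index) into
///
///   %val = spirv.CompositeExtract %m[0 : i32]
///       : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>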
struct WmmaExtractOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value matrix = adaptor.getMatrix();
auto coopType =
getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
matrix.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

SmallVector<int32_t> intValues;
for (Value val : op.getIndices()) {
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
intValues.push_back(static_cast<int32_t>(constOp.value()));
} else {
return rewriter.notifyMatchFailure(op, "indices must be constants");
}
}

Type elementType = coopType.getElementType();
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
return success();
}
};

/// Converts the GPU MMA InsertOp to SPIR-V CompositeInsert on KHR/NV
/// cooperative matrix types.
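/// For example (mirroring the conversion test in this patch), this rewrites
///
///   %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0]
///       : f16, !gpu.mma_matrix<16x16xf16, "COp">
///       -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// (with %c0 a constant index) into
///
///   %s0 = spirv.CompositeInsert %val, %m[0 : i32]
///       : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>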
struct WmmaInsertOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value value = adaptor.getValue();
Value matrix = adaptor.getMatrix();
auto coopType = getTypeConverter()->convertType(matrix.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

SmallVector<int32_t> intValues;
for (Value val : op.getIndices()) {
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
intValues.push_back(static_cast<int32_t>(constOp.value()));
} else {
return rewriter.notifyMatchFailure(op, "indices must be constants");
}
}

rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
@@ -93,6 +93,33 @@ module attributes {
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_extract_thread_local_op
// CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
gpu.func @gpu_wmma_extract_thread_local_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
%c0 = arith.constant 0 : index
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_insert_thread_local_op
// CHECK-SAME: %[[ARG0:.+]]: f16
// CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_insert_thread_local_op(%val: f16,
%m: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%c0 = arith.constant 0 : index
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
mlir/test/Dialect/GPU/ops.mlir (14 additions, 0 deletions)

@@ -430,6 +430,20 @@ module attributes {gpu.container_module} {
gpu.wait [%token16]
return
}

// CHECK-LABEL: func @extract_insert_mma
func.func @extract_insert_mma(%src : !gpu.mma_matrix<16x16xf32, "COp">,
%ptr: memref<16x16xf32>) {
%zero = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
// CHECK: gpu.subgroup_mma_extract_thread_local
%val = gpu.subgroup_mma_extract_thread_local %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
%m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp">
// CHECK: gpu.subgroup_mma_insert_thread_local
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
return
}
}

// Just check that this doesn't crash.