Commit 613f72a

[rfc][mlir][gpu] Add operations to extract/insert/rotate within subgroup
Add gpu.rotate, gpu.subgroup_mma_extract, and gpu.subgroup_mma_insert operations.
1 parent 316a6ff commit 613f72a

File tree

5 files changed: +257 −1 lines changed

‎mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

102 additions, 0 deletions
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
  ];
}

def GPU_RotateOp : GPU_Op<
    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
  let summary = "Rotate values within a subgroup.";
  let description = [{
    The "rotate" op moves values across lanes (a.k.a., invocations, work
    items) circularly within the same subgroup. The `width` argument
    specifies the number of lanes that participate in the rotation, and must
    be uniform across all lanes. Further, the first `width` lanes of the
    subgroup must be active.

    Example:

    ```mlir
    %cst1 = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    %1 = gpu.rotate %0, %cst1, %width : f32
    ```

    For lane `k` with 0 < `k` < 16, returns the value from lane
    `(k - 1) % width`; for lane `k` == 0, returns the value from lane 15.
  }];

  let assemblyFormat = [{
    $value `,` $offset `,` $width attr-dict `:` type($value)
  }];
}

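For illustration, here is a hedged sketch generalizing the documented offset-1 example to a larger offset. The general `(k - offset) % width` mapping is extrapolated from that example and is an assumption, not something the op definition states explicitly:

```mlir
// Hypothetical: rotate by 4 within the first 16 lanes. Extrapolating from the
// offset-1 case above, lane k (4 <= k < 16) would read the value from lane
// k - 4, and lanes 0..3 would wrap around to lanes 12..15.
%offset4 = arith.constant 4 : i32
%width16 = arith.constant 16 : i32
%rot = gpu.rotate %v, %offset4, %width16 : f32
```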
def GPU_BarrierOp : GPU_Op<"barrier"> {
  let summary = "Synchronizes all work items of a workgroup.";
  let description = [{
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
  }];
}

def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
    [Pure,
     TypesMatchWith<"value type matches element type of mma_matrix",
                    "matrix", "res",
                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {

  let summary = "Extract a value from a GPU warp by invocation and indices";

  let description = [{
    The `gpu.subgroup_mma_extract` operation extracts a value from a
    `!gpu.mma_matrix` held by an invocation in a subgroup.

    This operation takes a `!gpu.mma_matrix` as its first operand. It is the
    source matrix across a subgroup. The op returns a scalar value stored in
    the invocation in the subgroup. If there are multiple values packed in an
    invocation, use `indices` to specify which element to extract.

    Example:

    ```mlir
    %c0 = arith.constant 0 : index
    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
    ```
  }];

  let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);

  let results = (outs AnyIntegerOrFloat:$res);

  let assemblyFormat = [{
    $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
  }];
}

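A brief usage sketch: how many elements one invocation packs is target-specific, so the index `%c1` below assumes the invocation holds at least two packed elements.

```mlir
// Hypothetical: extract the second element packed in this invocation.
%c1 = arith.constant 1 : index
%e1 = gpu.subgroup_mma_extract %m[%c1] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
```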
def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
    [Pure,
     TypesMatchWith<"value type matches element type of mma_matrix",
                    "matrix", "value",
                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {

  let summary = "Insert a value into a GPU warp by invocation and indices";

  let description = [{
    The `gpu.subgroup_mma_insert` operation inserts a value into a
    `!gpu.mma_matrix` held by an invocation in a subgroup.

    This operation takes a scalar value as its first operand and a
    `!gpu.mma_matrix` as its second operand. The matrix is distributed across
    a subgroup. The op inserts the scalar value held by the invocation into
    the matrix. If there are multiple values packed in an invocation, use
    `indices` to specify the location to insert in the packing.

    The op returns a `!gpu.mma_matrix` with the updated value.

    Example:

    ```mlir
    %c0 = arith.constant 0 : index
    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
          -> !gpu.mma_matrix<16x16xf16, "COp">
    ```
  }];

  let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
                   Variadic<Index>:$indices);

  let results = (outs GPU_MMAMatrix:$res);

  let assemblyFormat = [{
    $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
  }];
}

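Combined with `gpu.subgroup_mma_extract`, this supports read-modify-write of individual packed elements. A minimal sketch, assuming index 0 is valid for the target's packing:

```mlir
%c0 = arith.constant 0 : index
%two = arith.constant 2.0 : f16
// Read one packed element of the accumulator, scale it, and write it back.
%old = gpu.subgroup_mma_extract %acc[%c0] : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
%new = arith.mulf %old, %two : f16
%acc2 = gpu.subgroup_mma_insert %new, %acc[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
        -> !gpu.mma_matrix<16x16xf16, "COp">
```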
def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;

‎mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

40 additions, 1 deletion
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
  return success();
}

//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

LogicalResult GPURotateConversion::matchAndRewrite(
    gpu::RotateOp rotateOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the rotate width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        rotateOp, "rotate width and target subgroup size mismatch");

  Location loc = rotateOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);

  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
      loc, scope, adaptor.getValue(), adaptor.getOffset(), rotateOp.getWidth());

  rewriter.replaceOp(rotateOp, result);
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
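To make the width constraint concrete, a sketch under a target env with `subgroup_size = 16` (value names are illustrative):

```mlir
%off = arith.constant 1 : i32
%w16 = arith.constant 16 : i32
%w8  = arith.constant 8 : i32
// Width matches the subgroup size: converts to spirv.GroupNonUniformRotateKHR.
%a = gpu.rotate %x, %off, %w16 : f32
// Width mismatch: the pattern fails to match, reporting
// "rotate width and target subgroup size mismatch".
%b = gpu.rotate %x, %off, %w8 : f32
```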
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-     GPUReturnOpConversion, GPUShuffleConversion,
+     GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

‎mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

63 additions, 0 deletions
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
  }
};

/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaExtractOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value matrix = adaptor.getMatrix();
    auto coopType =
        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
            matrix.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    SmallVector<int32_t> intValues;
    for (Value val : op.getIndices()) {
      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
        intValues.push_back(static_cast<int32_t>(constOp.value()));
      } else {
        return rewriter.notifyMatchFailure(op, "Indices must be constants");
      }
    }

    Type elementType = coopType.getElementType();
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
    return success();
  }
};

/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaInsertOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value value = adaptor.getValue();
    Value matrix = adaptor.getMatrix();
    auto coopType = getTypeConverter()->convertType(matrix.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    SmallVector<int32_t> intValues;
    for (Value val : op.getIndices()) {
      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
        intValues.push_back(static_cast<int32_t>(constOp.value()));
      } else {
        return rewriter.notifyMatchFailure(op, "Indices must be constants");
      }
    }

    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
        op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
    return success();
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
  MLIRContext *context = patterns.getContext();
  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+              WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
  // Give the following patterns higher benefit to prevail over the default one.
  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
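In effect, and as exercised by the tests below, each op lowers to a composite op over the invocation's cooperative-matrix value. A hedged before/after sketch:

```mlir
// Input:
%val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32

// After -convert-gpu-to-spirv (the matrix type is converted alongside):
%val = spirv.CompositeExtract %m[0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
```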
New test file: 25 additions, 0 deletions
@@ -0,0 +1,25 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
  // CHECK-LABEL: spirv.func @rotate()
  gpu.func @rotate() kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [4, 4, 1]>} {
    // CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
    // CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
    // CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
    %offset = arith.constant 8 : i32
    %width = arith.constant 16 : i32
    %val = arith.constant 42.0 : f32

    // CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
    %result = gpu.rotate %val, %offset, %width : f32
    gpu.return
  }
}

}

‎mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

27 additions, 0 deletions
@@ -93,6 +93,33 @@ module attributes {
    gpu.return
  }

  // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
  //  CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
  gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
                                %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
    // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
    %c0 = arith.constant 0 : index
    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
    memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
    gpu.return
  }

  // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
  //  CHECK-SAME: %[[ARG0:.+]]: f16
  //  CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
  gpu.func @gpu_wmma_insert_op(%val: f16,
                               %m: !gpu.mma_matrix<16x16xf16, "COp">,
                               %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
    // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
    %c0 = arith.constant 0 : index
    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} :
      !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
    gpu.return
  }

  // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
  //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
  //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
