Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 91f3cdb

Browse filesBrowse files
authored
[mlir][gpu] Pattern to promote gpu.shuffle to specialized AMDGPU ops (#137109)
Only swizzle promotion for now, may add DPP ops support later.
1 parent 382ad6f commit 91f3cdb
Copy full SHA for 91f3cdb

File tree

Expand file treeCollapse file tree

8 files changed

+129
-18
lines changed
Filter options
Expand file treeCollapse file tree

8 files changed

+129
-18
lines changed

‎mlir/include/mlir/Conversion/Passes.td

Copy file name to clipboardExpand all lines: mlir/include/mlir/Conversion/Passes.td
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
601601
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
602602
let dependentDialects = [
603603
"ROCDL::ROCDLDialect",
604+
"amdgpu::AMDGPUDialect",
604605
"cf::ControlFlowDialect",
605606
"memref::MemRefDialect",
606607
];
@@ -1415,7 +1416,7 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14151416
"bool", /*default=*/"false",
14161417
"Use the preferred alignment of a vector type in load/store "
14171418
"operations instead of the alignment of the element type of the "
1418-
"memref. This flag is intended for use with hardware which requires"
1419+
"memref. This flag is intended for use with hardware which requires"
14191420
"vector alignment, or in application contexts where it is known all "
14201421
"vector access are naturally aligned. ">,
14211422
Option<"amx", "enable-amx",

‎mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Copy file name to clipboardExpand all lines: mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+19-8Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,24 @@ def MapNestedForallToThreads :
132132
TransformEachOpTrait,
133133
TransformOpInterface]> {
134134
let description = [{
135-
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
135+
Target the `gpu.launch` op and rewrite all `scf.forall` nested in it to
136136
distributed `gpu.thread_id` attribute.
137137

138138
The operation searches for `scf.forall` ops nested under `target` and maps
139-
each such op to GPU threads.
140-
139+
each such op to GPU threads.
140+
141141
`scf.forall` induction variables are rewritten to `gpu.thread_id` according
142142
to the `mapping` attribute.
143143

144144
Different types of mappings attributes are supported:
145145
- the block_dims is a list of integers that specifies the number of
146146
threads in each dimension. This is a mandatory attribute that is used
147-
to constrain the number of threads in each dimension. If an
147+
to constrain the number of threads in each dimension. If an
148148
`scf.forall` op is mapped to fewer threads, predication occurs.
149149
- the warp_dims is a list of integers that specifies the number of
150150
warps in each dimension. This is an optional attribute that is used
151151
to constrain the number of warps in each dimension. When present, this
152-
attribute must be specified in a way that is compatible with the
152+
attribute must be specified in a way that is compatible with the
153153
block_dims attribute. If an `scf.forall` op is mapped to fewer warps,
154154
predication occurs.
155155

@@ -164,7 +164,7 @@ def MapNestedForallToThreads :
164164
inserted after each scf.forall op. At this time, this is an all or nothing
165165
choice. This will need to be tightened in the future.
166166

167-
The operation alters the block size of the given gpu_launch using the
167+
The operation alters the block size of the given gpu_launch using the
168168
mandatory block_dims argument.
169169

170170
#### Return modes:
@@ -268,7 +268,7 @@ def MapForallToBlocks :
268268
Only scf.forall distributed to **at most 3 dimensions** are
269269
currently supported.
270270

271-
The operation alters the block size of the given gpu_launch using the
271+
The operation alters the grid size of the given gpu_launch using the
272272
grid_dims argument.
273273

274274
#### Return modes:
@@ -300,7 +300,7 @@ def MapForallToBlocks :
300300
`:` functional-type($target, $result)
301301
}];
302302
let hasVerifier = 1;
303-
303+
304304
let extraClassDeclaration = [{
305305
::mlir::DiagnosedSilenceableFailure applyToOne(
306306
::mlir::transform::TransformRewriter &rewriter,
@@ -310,4 +310,15 @@ def MapForallToBlocks :
310310
}];
311311
}
312312

313+
def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
314+
"apply_patterns.gpu.gpu_shuffle_to_amdgpu",
315+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
316+
let description = [{
317+
Collects patterns that are trying to promote `gpu.shuffle`s to specialized
318+
AMDGPU intrinsics.
319+
}];
320+
let assemblyFormat = "attr-dict";
321+
}
322+
323+
313324
#endif // GPU_TRANSFORM_OPS

‎mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Copy file name to clipboardExpand all lines: mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
113113
/// Erase barriers that do not enforce conflicting memory side effects.
114114
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
115115

116+
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
117+
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
118+
116119
/// Generate the code for registering passes.
117120
#define GEN_PASS_REGISTRATION
118121
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

‎mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2828
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2929
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
30+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
3031
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
3132
#include "mlir/Dialect/Func/IR/FuncOps.h"
3233
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -197,8 +198,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
197198
Value widthOrZeroIfOutside =
198199
rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
199200
Value dstLane;
200-
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
201-
// perf.
201+
202202
switch (op.getMode()) {
203203
case gpu::ShuffleMode::UP:
204204
dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId,
@@ -319,6 +319,7 @@ struct LowerGpuOpsToROCDLOpsPass final
319319
{
320320
RewritePatternSet patterns(ctx);
321321
populateGpuRewritePatterns(patterns);
322+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
322323
(void)applyPatternsGreedily(m, std::move(patterns));
323324
}
324325

‎mlir/lib/Dialect/GPU/CMakeLists.txt

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/GPU/CMakeLists.txt
+6-5Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ add_mlir_dialect_library(MLIRGPUTransforms
3737
Transforms/ModuleToBinary.cpp
3838
Transforms/NVVMAttachTarget.cpp
3939
Transforms/ParallelLoopMapper.cpp
40+
Transforms/PromoteShuffleToAMDGPU.cpp
4041
Transforms/ROCDLAttachTarget.cpp
41-
Transforms/ShuffleRewriter.cpp
4242
Transforms/SPIRVAttachTarget.cpp
43+
Transforms/ShuffleRewriter.cpp
4344
Transforms/SubgroupIdRewriter.cpp
4445
Transforms/SubgroupReduceLowering.cpp
4546

@@ -53,8 +54,8 @@ add_mlir_dialect_library(MLIRGPUTransforms
5354
MLIRParallelLoopMapperEnumsGen
5455

5556
LINK_LIBS PUBLIC
56-
MLIRAffineUtils
5757
MLIRAMDGPUDialect
58+
MLIRAffineUtils
5859
MLIRArithDialect
5960
MLIRAsyncDialect
6061
MLIRBufferizationDialect
@@ -68,12 +69,12 @@ add_mlir_dialect_library(MLIRGPUTransforms
6869
MLIRMemRefDialect
6970
MLIRNVVMTarget
7071
MLIRPass
72+
MLIRROCDLDialect
73+
MLIRROCDLTarget
7174
MLIRSCFDialect
72-
MLIRSideEffectInterfaces
7375
MLIRSPIRVTarget
76+
MLIRSideEffectInterfaces
7477
MLIRSupport
75-
MLIRROCDLDialect
76-
MLIRROCDLTarget
7778
MLIRTransformUtils
7879
MLIRVectorDialect
7980
)

‎mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+9-2Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1415
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -136,6 +137,11 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
136137
populateGpuRewritePatterns(patterns);
137138
}
138139

140+
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
141+
RewritePatternSet &patterns) {
142+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
143+
}
144+
139145
//===----------------------------------------------------------------------===//
140146
// ApplyUnrollVectorsSubgroupMmaOp
141147
//===----------------------------------------------------------------------===//
@@ -914,9 +920,10 @@ class GPUTransformDialectExtension
914920
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
915921

916922
GPUTransformDialectExtension() {
917-
declareGeneratedDialect<scf::SCFDialect>();
918-
declareGeneratedDialect<arith::ArithDialect>();
919923
declareGeneratedDialect<GPUDialect>();
924+
declareGeneratedDialect<amdgpu::AMDGPUDialect>();
925+
declareGeneratedDialect<arith::ArithDialect>();
926+
declareGeneratedDialect<scf::SCFDialect>();
920927
registerTransformOps<
921928
#define GET_OP_LIST
922929
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
+64Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains patterns to try to promote `gpu.shuffle`s to specialized
10+
// AMDGPU intrinsics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
/// Try to promote a `gpu.shuffle` in xor mode to `amdgpu.swizzle_bitmode`. The
25+
/// width must be 64 and the offset a constant integer in the range [0, 31].
26+
struct PromoteShuffleToSwizzlePattern
27+
: public OpRewritePattern<gpu::ShuffleOp> {
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
31+
PatternRewriter &rewriter) const override {
32+
if (op.getMode() != gpu::ShuffleMode::XOR)
33+
return rewriter.notifyMatchFailure(op,
34+
"only xor shuffle mode is supported");
35+
36+
if (!isConstantIntValue(op.getWidth(), 64))
37+
return rewriter.notifyMatchFailure(op,
38+
"only 64 width shuffle is supported");
39+
40+
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
41+
if (!offset)
42+
return rewriter.notifyMatchFailure(op,
43+
"offset must be a constant integer");
44+
45+
int64_t offsetValue = *offset;
46+
if (offsetValue < 0 || offsetValue >= 32)
47+
return rewriter.notifyMatchFailure(op,
48+
"offset must be in the range [0, 31]");
49+
50+
Location loc = op.getLoc();
51+
Value res = rewriter.create<amdgpu::SwizzleBitModeOp>(
52+
loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
53+
/*orMask=*/0, /*xorMask=*/offsetValue);
54+
Value valid = rewriter.create<arith::ConstantIntOp>(loc, 1, /*width*/ 1);
55+
rewriter.replaceOp(op, {res, valid});
56+
return success();
57+
}
58+
};
59+
} // namespace
60+
61+
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
62+
RewritePatternSet &patterns) {
63+
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
64+
}
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
2+
3+
module attributes {transform.with_named_sequence} {
4+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
5+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
6+
transform.apply_patterns to %func {
7+
transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
8+
} : !transform.any_op
9+
transform.yield
10+
}
11+
}
12+
13+
// CHECK-LABEL: func @gpu_shuffle_swizzle
14+
// CHECK-SAME: (%[[ARG:.*]]: i32)
15+
func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
16+
// CHECK: %[[TRUE:.*]] = arith.constant true
17+
// CHECK: %[[RES:.*]] = amdgpu.swizzle_bitmode %[[ARG]] 31 0 23 : i32
18+
// CHECK: return %[[RES]], %[[TRUE]] : i32, i1
19+
%width = arith.constant 64 : i32
20+
%offset = arith.constant 23 : i32
21+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
22+
func.return %shfl, %pred : i32, i1
23+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.