//
//===----------------------------------------------------------------------===//

+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
@@ -362,6 +366,163 @@ struct VectorSubgroupReduceToShuffles final
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
+
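+/// Lowers a subgroup reduction over a cluster of contiguous lanes to a
+/// sequence of `amdgpu.dpp` permutations (quad_perm, row_half_mirror,
+/// row_mirror, row_bcast_*), combining the partial results after each step
+/// with the arith op matching the reduction kind. For cluster sizes of 32 and
+/// 64, chipsets with major version <= 9 use row broadcasts, while later
+/// chipsets (up to gfx12) use ROCDL permlanex16/readlane ops instead.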
+static FailureOr<Value>
+createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
+                           Value input, gpu::AllReduceOperation mode,
+                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
+  Location loc = op.getLoc();
+  Value dpp;
+  Value res = input;
+  constexpr int allRows = 0xf;
+  constexpr int allBanks = 0xf;
+  const bool boundCtrl = true;
+  if (ci.clusterSize >= 2) {
+    // Perform reduction between all lanes N <-> N+1.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+
+  if (ci.clusterSize >= 4) {
+    // Perform reduction between all lanes N <-> N+2.
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
+        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 8) {
+    // Perform reduction between all lanes N <-> 7-N,
+    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 16) {
+    // Perform reduction between all lanes N <-> 15-N,
+    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
+    dpp = rewriter.create<amdgpu::DPPOp>(
+        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
+        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  if (ci.clusterSize >= 32) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast the last value from each row to the next row.
+      // Use a row mask to avoid polluting rows 1 and 3.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
+          rewriter.getUnitAttr(), 0xa, allBanks,
+          /*bound_ctrl*/ false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+    } else if (chipset.majorVersion <= 12) {
+      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
+      Value uint32Max = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
+      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
+                                                  uint32Max, uint32Max,
+                                                  /*fi=*/true,
+                                                  /*bound_ctrl=*/false);
+      res = vector::makeArithReduction(
+          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+      if (ci.subgroupSize == 32) {
+        Value lane0 = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+        res =
+            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+      }
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+  }
+  if (ci.clusterSize >= 64) {
+    if (chipset.majorVersion <= 9) {
+      // Broadcast lane 31's value to rows 2 and 3.
+      // Use a row mask to avoid polluting rows 0 and 1.
+      dpp = rewriter.create<amdgpu::DPPOp>(
+          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
+          rewriter.getUnitAttr(), 0xc, allBanks,
+          /*bound_ctrl*/ false);
+
+    } else if (chipset.majorVersion <= 12) {
+      // Assume the reduction across the first 32 lanes has already been done.
+      // Perform the final reduction by combining the values from lane 0 and
+      // lane 32 with the reduction op.
+      Value lane0 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+      Value lane32 = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
+      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
+      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reduce lowering to DPP not currently supported for "
+              "this device.");
+    }
+    res = vector::makeArithReduction(rewriter, loc,
+                                     gpu::convertReductionKind(mode), res, dpp);
+  }
+  assert(res.getType() == input.getType());
+  return res;
+}
+
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `amdgpu.dpp`
+/// ops over scalar types. Assumes that the subgroup has `subgroupSize` lanes.
+/// Applicable only to AMD GPUs.
+struct ScalarSubgroupReduceToDPP final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
+                            bool matchClustered, amdgpu::Chipset chipset,
+                            PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+
+    if (ci->clusterStride != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Subgroup reductions using DPP are currently only available for "
+              "clusters of contiguous lanes.");
+
+    Type valueTy = op.getType();
+    if (!valueTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Value type is not a compatible scalar.");
+
+    FailureOr<Value> dpp = createSubgroupDPPReduction(
+        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
+    if (failed(dpp))
+      return failure();
+
+    rewriter.replaceOp(op, dpp.value());
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+  amdgpu::Chipset chipset;
+};
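+
+// Note: this pattern only handles scalar int/float values; vector
+// `gpu.subgroup_reduce` ops need to be broken down to scalars first (for
+// example via populateGpuBreakDownSubgroupReducePatterns) before this
+// lowering applies.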
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
@@ -372,6 +533,22 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

+void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/false, chipset,
+                                          benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
+    PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
+                                          /*matchClustered=*/true, chipset,
+                                          benefit);
+}
+
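+// Illustrative usage (a sketch, not part of this patch): a conversion pass
+// could parse the target chipset and register the DPP-based patterns added
+// above. `chipsetName`, `ctx`, and the greedy rewrite driver call are
+// placeholder assumptions here:
+//
+//   FailureOr<amdgpu::Chipset> maybeChipset =
+//       amdgpu::Chipset::parse(chipsetName);
+//   if (failed(maybeChipset))
+//     return signalPassFailure();
+//   RewritePatternSet patterns(ctx);
+//   populateGpuLowerSubgroupReduceToDPPPatterns(
+//       patterns, /*subgroupSize=*/64, *maybeChipset, PatternBenefit(1));
+//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+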
void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {