Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a37ab7d

Browse filesBrowse files
committed
[MLIR][AMDGPU] Clean up and redo after other recent patches here.
1 parent cba8da0 commit a37ab7d
Copy full SHA for a37ab7d

File tree

Expand file treeCollapse file tree

3 files changed

+22
-13
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+22
-13
lines changed

‎mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Copy file name to clipboardExpand all lines: mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ struct Chipset {
4949
#undef DEFINE_COMP_OPERATOR
5050

5151
bool isGfx940() const {
52-
return majorVersion == 9 && minorVersion >= 0x40 && minorVersion < 0x50;
52+
return majorVersion == 9 && minorVersion >= 4 && minorVersion < 5;
5353
}
5454
bool hasOcpFp8() const {
55-
return (majorVersion == 9 && minorVersion >= 0x50) || majorVersion >= 12;
55+
return (majorVersion == 9 && minorVersion >= 5) || majorVersion >= 12;
5656
}
5757
};
5858

‎mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
771771
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
772772
ConversionPatternRewriter &rewriter) const {
773773
Location loc = op.getLoc();
774-
if (chipset.majorVersion != 9 || chipset < kGfx940)
774+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
775775
return rewriter.notifyMatchFailure(
776776
loc, "Fp8 conversion instructions are not available on target "
777777
"architecture and their emulation is not implemented");
@@ -815,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
815815
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
816816
ConversionPatternRewriter &rewriter) const {
817817
Location loc = op.getLoc();
818-
if (chipset.majorVersion != 9 || chipset < kGfx940)
818+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
819819
return rewriter.notifyMatchFailure(
820820
loc, "Fp8 conversion instructions are not available on target "
821821
"architecture and their emulation is not implemented");
@@ -852,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
852852
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
853853
ConversionPatternRewriter &rewriter) const {
854854
Location loc = op.getLoc();
855-
if (chipset.majorVersion != 9 || chipset < kGfx940)
855+
if (!(chipset.isGfx940() || chipset.hasOcpFp8()))
856856
return rewriter.notifyMatchFailure(
857857
loc, "Fp8 conversion instructions are not available on target "
858858
"architecture and their emulation is not implemented");

‎mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Copy file name to clipboardExpand all lines: mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+17-8Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
4141
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4242
using OpRewritePattern::OpRewritePattern;
4343

44+
Chipset chipset;
45+
ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
46+
: OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
47+
4448
LogicalResult match(arith::ExtFOp op) const override;
4549
void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
4650
};
@@ -68,6 +72,15 @@ struct TruncfToFloat16RewritePattern final
6872

6973
} // end namespace
7074

75+
static LogicalResult isSupportedFp8(Type elementType, Chipset chipset) {
76+
if (chipset.isGfx940())
77+
return success(elementType.isFloat8E5M2FNUZ() ||
78+
elementType.isFloat8E4M3FNUZ());
79+
if (chipset.hasOcpFp8())
80+
return success(elementType.isFloat8E5M2() || elementType.isFloat8E4M3FN());
81+
return failure();
82+
}
83+
7184
static Value castF32To(Type elementType, Value f32, Location loc,
7285
PatternRewriter &rewriter) {
7386
if (elementType.isF32())
@@ -86,8 +99,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8699
return failure();
87100
inType = inVecType.getElementType();
88101
}
89-
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ() ||
90-
inType.isFloat8E5M2() || inType.isFloat8E4M3FN());
102+
return isSupportedFp8(inType, chipset);
91103
}
92104

93105
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -218,10 +230,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
218230
// Conversion between 8-bit floats is not supported with truncation enabled.
219231
return failure();
220232

221-
return success((((outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()) &&
222-
chipset.isGfx940()) ||
223-
((outType.isFloat8E5M2() || outType.isFloat8E4M3FN()) &&
224-
chipset.hasOcpFp8())));
233+
return isSupportedFp8(outType, chipset);
225234
}
226235

227236
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -370,7 +379,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
370379
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
371380

372381
if (convertFP8Arithmetic) {
373-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
382+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
374383
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
375384
saturateFP8Truncf, chipset);
376385
}
@@ -389,7 +398,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
389398
}
390399

391400
bool convertFP8Arithmetic =
392-
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
401+
maybeChipset->isGfx940() || maybeChipset->hasOcpFp8();
393402
arith::populateArithToAMDGPUConversionPatterns(
394403
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
395404
*maybeChipset);

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.