Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[mlir][AMDGPU] Add scaled floating point conversion ops #141554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[mlir][AMDGPU] implement ScaledExtPackedOp and PackedScaledTruncOp
  • Loading branch information
tgymnich committed Jun 11, 2025
commit dcd7a039053863f51ad75629238a06fa3ebc5cec
62 changes: 62 additions & 0 deletions 62 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,38 @@ def AMDGPU_ExtPackedFp8Op :
}];
}

def AMDGPU_ScaledExtPackedOp
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
Arguments<(
ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,
tgymnich marked this conversation as resolved.
Show resolved Hide resolved
VectorOfLengthAndType<[2, 3, 4, 5, 6, 7, 8],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I suppose odd numbers work, since you can just zext

[F4E2M1FN]>]>:$source,
F32:$scale,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
Results<(
outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>,
FixedVectorOfLengthAndType<[2], [F16]>,
FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> {
let summary = "Extend a vector of packed floating point values";

let description = [{
Extend and scale two packed floats in `source[index]` to two floats and
return them.
Comment on lines +130 to +131
Copy link
Contributor

@umangyadav umangyadav Jun 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this a bit confusing source[index] would only point to one element. How do we get two floats in return ?

Ans: It is selecting byte of a 32-bit word. For F4 each byte would be two floats.


This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
extending 8-bit floats (which are, currently, the only way to work with
this operation) take packed vectors of 2 such floats.

If the passed-in vector has fewer than two elements, or the input is scalar,
the remaining values in the <2 x i8> will be filled with
undefined values as needed.
}];
let assemblyFormat = [{
attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
}];
}

def AMDGPU_PackedTrunc2xFp8Op :
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins F32:$sourceA,
Expand Down Expand Up @@ -139,6 +171,36 @@ def AMDGPU_PackedTrunc2xFp8Op :
let hasVerifier = 1;
}

def AMDGPU_PackedScaledTruncOp
: AMDGPU_Op<"packed_scaled_trunc", [Pure]>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
for the ext it is named as scaled_ext_packed
here it is packed_scaled_trunc. I find it better to call it scaled_trunc_packed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That follows existing convention in the dialect

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have ext_packed (the things being extended are packed) but packed_trunc (the result is packed)

Arguments<(ins VectorOfLengthAndType<[2], [F32, F16, BF16]>:$source,
F32:$scale,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index,
Optional<AnyTypeOf<
[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>>:$existing)>,
Results<(
outs AnyTypeOf<[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>,
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> {
let summary = "Round two floats into a packed vector of floats";
let description = [{
Scale and round the inputs `sourceA` and `sourceB` (which is undefined if not
specified) into the low or high word (bottom two or top two) elements
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a selector for selecting low or high word ?
What's the role of index attribute here ? Can you explain ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is either low/high half or byte 0-3 depending, probably

of the returned vector, keeping the other two elements of `existing`
unchanged if present (or undefined if it was not passed in).

The reason for this odd signature is that AMD GPUs cannot easily work with
sub-registers, and so the conversion intrinsics take 32-bit wide
packed vectors of float values.
}];
let assemblyFormat = [{
attr-dict $source `into` ($existing^):(`undef`)? `[` `index` $index `]`
tgymnich marked this conversation as resolved.
Show resolved Hide resolved
`,` $scale
`:` type($source) `to` type($res) (`into` type($existing)^)?
}];
let hasVerifier = 0;
}

def AMDGPU_PackedStochRoundFp8Op :
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
Arguments<(ins F32:$source,
Expand Down
181 changes: 180 additions & 1 deletion 181 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
Expand Down Expand Up @@ -1174,6 +1175,32 @@ struct PackedStochRoundFp8OpLowering final
PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
: public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Expand Down Expand Up @@ -1230,6 +1257,157 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}

LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

Value source = adaptor.getSource();
Value scale = adaptor.getScale();

VectorType sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
VectorType destVecType = dyn_cast<VectorType>(op.getResult().getType());
Type destElemType = getElementTypeOrSelf(op.getResult());

VectorType packedVecType;
if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
} else {
llvm_unreachable("invalid element type for scaled ext");
}

// Extend to a packedVectorType
if (!sourceVecType ||
sourceVecType.getNumElements() < packedVecType.getNumElements()) {
Value longVec = rewriter.create<LLVM::UndefOp>(loc, packedVecType);
tgymnich marked this conversation as resolved.
Show resolved Hide resolved
if (!sourceVecType) {
longVec = rewriter.create<LLVM::InsertElementOp>(
loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
longVec =
rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
}
}
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);

if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else
return failure();
Comment on lines +1307 to +1335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
else
return failure();
if (isa<Float8E5M2Type>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float8E4M3FNType>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float4E2M1FNType>(sourceElemType)) {
if (destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else if (destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
op, destVecType, i32Source, scale, op.getIndex());
} else {
return failure();
}
} else {
return failure();
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this makes the code less readable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, that what i suggested is feels better to read. Anyways it's personal opinion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a static std::tuple<OperationName, StringAttr, IntegerType> getConversionOpNameAndSelectorArg(Type inType, Type outType); that'd return the information needed to construct this operation generically. That'd save us all the duplicated op, destVecType, i32Source, scale, op.getIndex() sections

Basically, have we considered doing it the way the MFMA/WMMA code handles all this?

(And, based on my experience with said code, the nested ifs Umang proposed end up a bit easier to work with ... though, in retrospect, a lot of those could've been a llvm::TypeSwitch.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Said note is non-blocknig if you'd really rather not rewrite all this again, but I thought I'd raise it.)


return success();
}

LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset != kGfx950)
return rewriter.notifyMatchFailure(
loc, "Scaled fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Type v2i16 = getTypeConverter()->convertType(
VectorType::get(2, rewriter.getI16Type()));
Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

Type resultType = op.getResult().getType();
Type resultElemType = getElementTypeOrSelf(resultType);
Type sourceElemType = getElementTypeOrSelf(op.getSource());

Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

Value source = adaptor.getSource();
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, intResultType);

Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
}

Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
Comment on lines +1384 to +1412
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
if (isa<Float8E5M2Type>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float8E4M3FNType>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else if (isa<Float4E2M1FNType>(resultElemType)) {
if (sourceElemType.isF32()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
} else if (sourceElemType.isF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else if (sourceElemType.isBF16()) {
result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
loc, intResultType, existing, source, scale, op.getIndex());
} else {
return failure();
}
} else {
return failure();
}


result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -1547,7 +1725,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.