diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index cba35bbca1f83..484cea84f669b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,8 +41,8 @@ class AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
 
 def AMDGPU_ExtPackedFp8Op :
     AMDGPU_Op<"ext_packed_fp8", [Pure]>,
-    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ,
-      VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ]>]>:$source,
+    Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
+      VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
     Results<(outs F32:$res)> {
   let summary = "Extend one of a vector of packed fp8 values to a float";
@@ -68,8 +68,8 @@ def AMDGPU_PackedTrunc2xFp8Op :
     Arguments<(ins F32:$sourceA,
       Optional<F32>:$sourceB,
       ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round two floats into a packed vector of 8-bit floats";
   let description = [{
     Round the inputs `sourceA` and `sourceB` (which is undefined if not
@@ -95,8 +95,8 @@ def AMDGPU_PackedStochRoundFp8Op :
     Arguments<(ins F32:$source,
       I32:$stochiasticParam,
      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$storeIndex,
-      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>>:$existing)>,
-    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ]>:$res)> {
+      Optional<FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
   let summary = "Round float stochiastically into a packed vector of 8-bit floats";
   let description = [{
     Round the input `source`, adding in `stochiasticParam`, and place it into
@@ -546,7 +546,7 @@ def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
                              VectorOfLengthAndType<[4], [F16]>,
                              VectorOfLengthAndType<[2, 4], [BF16]>,
                              VectorOfLengthAndType<[4, 8], [I8]>,
-                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
+                             VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
index a5dab1ab89630..768b390ed5381 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
@@ -49,6 +49,14 @@ struct Chipset {
 #undef DEFINE_COMP_OPERATOR
 };
 
+inline bool isGfx940Series(const Chipset &chipset) {
+  return chipset.majorVersion == 9 && chipset.minorVersion == 4;
+}
+inline bool hasOcpFp8(const Chipset &chipset) {
+  return (chipset.majorVersion == 9 && chipset.minorVersion >= 5) ||
+         chipset.majorVersion >= 12;
+}
+
 } // namespace mlir::amdgpu
 
 #endif
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0f..7feab4d966d59 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -140,6 +140,9 @@ class Type {
   bool isF64() const;
   bool isF80() const;
   bool isF128() const;
+  /// Return true if this is a float type (with the specified width).
+  bool isFloat() const;
+  bool isFloat(unsigned width) const;
 
   /// Return true if this is an integer type (with the specified width).
   bool isInteger() const;
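Note (illustrative, not part of the patch): together with the MFMAOp verifier change in AMDGPUDialect.cpp below, the widened MFMAInTypes lets OCP fp8 operands through verification. The snippet follows the attribute and blgp syntax of the existing amdgpu.mfma tests and assumes a chipset where hasOcpFp8 holds and MFMA instructions exist (gfx950 here; gfx12 exposes WMMA rather than MFMA):

// Expected to select rocdl.mfma.f32.16x16x32.fp8.fp8 via mfmaOpToIntrinsic
// (m = 16, n = 16, k = 32, blocks = 1, both operands E4M3FN).
func.func @mfma_ocp_fp8(%a: vector<8xf8E4M3FN>, %b: vector<8xf8E4M3FN>,
                        %c: vector<4xf32>) -> vector<4xf32> {
  %d = amdgpu.mfma %a * %b + %c {
    abid = 0 : i32, cbsz = 0 : i32,
    m = 16 : i32, n = 16 : i32, k = 32 : i32, blocks = 1 : i32
  } blgp = none : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<4xf32>
  func.return %d : vector<4xf32>
}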
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f80d2793eaef5..4a76739c7a06a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -454,6 +454,20 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
   }
 }
 
+/// Return true if `type` is the E5M2 variant of an 8-bit float that is
+/// supported by the `_bf8` instructions on the given `chipset`.
+static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
+  return (isGfx940Series(chipset) && type.isFloat8E5M2FNUZ()) ||
+         (hasOcpFp8(chipset) && type.isFloat8E5M2());
+}
+
+/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
+/// supported by the `_fp8` instructions on the given `chipset`.
+static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
+  return (isGfx940Series(chipset) && type.isFloat8E4M3FNUZ()) ||
+         (hasOcpFp8(chipset) && type.isFloat8E4M3FN());
+}
+
 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -550,38 +564,38 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
 
-  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
     }
   }
 
-  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
+  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
     Type sourceBElem =
         cast<VectorType>(mfma.getSourceB().getType()).getElementType();
     if (m == 16 && n == 16 && k == 32 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
     }
     if (m == 32 && n == 32 && k == 16 && b == 1) {
-      if (sourceBElem.isFloat8E5M2FNUZ())
+      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
-      if (sourceBElem.isFloat8E4M3FNUZ())
+      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
         return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
     }
   }
@@ -757,7 +771,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
     ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -787,10 +801,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
   Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
-  if (sourceElemType.isFloat8E5M2FNUZ()) {
+  if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                     wordSel);
-  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
+  } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
     rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                     wordSel);
   }
@@ -801,7 +815,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
     PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
              "architecture and their emulation is not implemented");
@@ -822,10 +836,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                    existing, wordSel);
 
@@ -838,7 +852,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
     PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  if (chipset.majorVersion != 9 || chipset < kGfx940)
+  if (!(isGfx940Series(chipset) || hasOcpFp8(chipset)))
     return rewriter.notifyMatchFailure(
         loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
@@ -857,10 +871,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
 
   Value result;
-  if (resultElemType.isFloat8E5M2FNUZ())
+  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
-  else if (resultElemType.isFloat8E4M3FNUZ())
+  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
     result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                    existing, byteSel);
 
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0..e16f9f65cc919 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,6 +41,10 @@ struct ArithToAMDGPUConversionPass final
 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  Chipset chipset;
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };
 
@@ -68,6 +72,14 @@ struct TruncfToFloat16RewritePattern final
 
 } // end namespace
 
+static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+  if (isGfx940Series(chipset))
+    return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+  if (hasOcpFp8(chipset))
+    return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
+  return failure();
+}
+
 static Value castF32To(Type elementType, Value f32, Location loc,
                        PatternRewriter &rewriter) {
   if (elementType.isF32())
@@ -86,7 +98,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
       return failure();
     inType = inVecType.getElementType();
   }
-  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
+  return isSupportedF8(inType, chipset);
 }
 
 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +228,8 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (inType && inType.getWidth() <= 8 && saturateFP8)
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
-  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
+
+  return isSupportedF8(outType, chipset);
 }
 
 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
@@ -365,7 +378,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
 
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
     patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                                saturateFP8Truncf, chipset);
   }
@@ -384,7 +397,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
   }
 
   bool convertFP8Arithmetic =
-      maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
+      isGfx940Series(*maybeChipset) || hasOcpFp8(*maybeChipset);
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
       *maybeChipset);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 63447baa31eb0..48fb1dfb0a003 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
   }
 
   Type sourceBType = getSourceB().getType();
-  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+  if (sourceElem.isFloat(8)) {
     int64_t sourceBLen = 1;
     Type sourceBElem = sourceBType;
     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
       sourceBLen = sourceBVector.getNumElements();
       sourceBElem = sourceBVector.getElementType();
     }
-    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+    if (!sourceBElem.isFloat(8))
       return emitOpError("expected both source operands to have f8 elements");
     if (sourceLen != sourceBLen)
       return emitOpError(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b78c372af77e6..9d621e46126c5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -509,7 +509,8 @@ bool TosaValidation::isValidElementType(Type type) {
   if (isa<FloatType>(type)) {
     if (profile == TosaProfileEnum::BaseInference)
       return false;
-    return type.isF32() || type.isF16() || type.isBF16();
+    return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
+               Float8E5M2Type>(type);
   }
   if (auto intTy = dyn_cast<IntegerType>(type)) {
     if (intTy.isUnsigned()) {
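Note (illustrative, not part of the patch): isSupportedF8 above makes the ExtF/TruncF rewrites chipset-sensitive rather than unconditional. A sketch of the intended gating, with hypothetical function names:

// Rewritten to amdgpu.ext_packed_fp8 under convert-arith-to-amdgpu{chipset=gfx942},
// where isGfx940Series admits the FNUZ variants, but left untouched under
// chipset=gfx1200, where isSupportedF8 fails for FNUZ types.
func.func @fnuz_ext(%v: f8E4M3FNUZ) -> f32 {
  %w = arith.extf %v : f8E4M3FNUZ to f32
  return %w : f32
}

// The mirror case: only rewritten where hasOcpFp8 holds (gfx950 and gfx12xx),
// as exercised by the new tests below.
func.func @ocp_ext(%v: f8E4M3FN) -> f32 {
  %w = arith.extf %v : f8E4M3FN to f32
  return %w : f32
}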
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f..6d8f9af6ad6a5 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -56,6 +56,15 @@ bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
 bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
 bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
 
+bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
+
+/// Return true if this is a float type with the specified width.
+bool Type::isFloat(unsigned width) const {
+  if (auto fltTy = llvm::dyn_cast<FloatType>(*this))
+    return fltTy.getWidth() == width;
+  return false;
+}
+
 bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
 
 bool Type::isInteger() const { return llvm::isa<IntegerType>(*this); }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
new file mode 100644
index 0000000000000..70775a603e54d
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 | FileCheck %s
+
+// CHECK-LABEL: func @ext_scalar
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_scalar(%v: f8E5M2) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_short_vec
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
+// CHECK: return [[EXT]]
+func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @ext_full_vec(
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
+// CHECK: return [[EXT]] : f32
+
+func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
+  %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @packed_trunc
+// CHECK-SAME: ([[V:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, undef into undef[word 0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into undef[word 0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_trunc_2xfp8 %v, %w into %existing[word 1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
+
+// CHECK-LABEL: func @packed_stoch_round
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
+// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into undef[0] : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_stoch_round_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
+// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_stoch_round_fp8 %v + %s into %existing[1] : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
new file mode 100644
index 0000000000000..2df5f2fa1965f
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation-ocp.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx950 saturate-fp8-truncf=true}))' \
+// RUN:   | FileCheck %s
+
+// RUN: mlir-opt --split-input-file %s \
+// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx1200 saturate-fp8-truncf=true}))' \
+// RUN:   | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2>
+// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2>
+// CHECK: return [[W]] : f8E5M2
+func.func @scalar_trunc(%v: f16) -> f8E5M2 {
+  %w = arith.truncf %v : f16 to f8E5M2
+  return %w : f8E5M2
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+ +// ----- + +// CHECK-LABEL: func.func @vector_trunc +// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FN> { +// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-4.480000e+02> : vector<2xf32> +// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<4.480000e+02> : vector<2xf32> +// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32> +// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32> +// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]] +// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]] +// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]] +// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]] +// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]] +// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]] +// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]] +// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]] +// CHECK: [[F0:%.+]] = vector.extract [[SATURATED]][0] +// CHECK: [[F1:%.+]] = vector.extract [[SATURATED]][1] +// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FN> +// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FN> to vector<2xf8E4M3FN> +// CHECK: return [[W]] : vector<2xf8E4M3FN> +func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FN> { + %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN> + return %w : vector<2xf8E4M3FN> +} diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir new file mode 100644 index 0000000000000..0e7f58c9e6749 --- /dev/null +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats-ocp.mlir @@ -0,0 +1,176 @@ +// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s +// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1200" | FileCheck %s + +// CHECK-LABEL: func.func @scalar_ext +// CHECK-SAME: ([[V:%.+]]: f8E5M2) +// CHECK: [[FLOAT:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : f8E5M2 to f32 +// CHECK: [[W:%.+]] = arith.truncf [[FLOAT]] : f32 to f16 +// CHECK: return [[W]] +func.func @scalar_ext(%v: f8E5M2) -> f16 { + %w = arith.extf %v : f8E5M2 to f16 + return %w : f16 +} + +// No 0-D test because arith.extf hasn't been extended to support it. 
+ +// ----- + +// CHECK-LABEL: func.func @vector_ext_short +// CHECK-SAME: ([[V:%.+]]: vector<2xf8E5M2>) +// CHECK-DAG: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<2xf64> +// CHECK: [[FLOAT0:%.+]] = amdgpu.ext_packed_fp8 [[V]][0] : vector<2xf8E5M2> to f32 +// CHECK: [[EXT0:%.+]] = arith.extf [[FLOAT0]] : f32 to f64 +// CHECK: [[W0:%.+]] = vector.insert [[EXT0]], [[ZEROES]] [0] +// CHECK: [[FLOAT1:%.+]] = amdgpu.ext_packed_fp8 [[V]][1] : vector<2xf8E5M2> to f32 +// CHECK: [[EXT1:%.+]] = arith.extf [[FLOAT1]] +// CHECK: [[W1:%.+]] = vector.insert [[EXT1]], [[W0]] [1] +// CHECK: return [[W1]] : vector<2xf64> + +func.func @vector_ext_short(%v: vector<2xf8E5M2>) -> vector<2xf64> { + %w = arith.extf %v : vector<2xf8E5M2> to vector<2xf64> + return %w : vector<2xf64> +} + +// ----- + +// CHECK-LABEL: func.func @vector_ext_long +// CHECK-SAME: ([[V:%.+]]: vector<9xf8E4M3FN>) +// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[V]] {offsets = [0], sizes = [4], strides = [1]} +// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0] +// CHECK: [[W0:%.+]] = vector.insert [[F0]] +// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1] +// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]] +// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2] +// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]] +// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3] +// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]] + +// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[V]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN> +// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0] +// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]] +// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1] +// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]] +// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2] +// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]] +// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3] +// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]] + +// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[V]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN> +// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0] +// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]] +// CHECK: return [[W8]] +func.func @vector_ext_long(%v: vector<9xf8E4M3FN>) -> vector<9xf32> { + %w = arith.extf %v : vector<9xf8E4M3FN> to vector<9xf32> + return %w : vector<9xf32> +} + +// ----- + +// CHECK-LABEL: func.func @scalar_trunc +// CHECK-SAME: ([[V:%.+]]: f16) +// CHECK: [[FLOAT:%.+]] = arith.extf [[V]] : f16 to f32 +// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2> +// CHECK: [[W:%.+]] = vector.extract [[TRUNCV]][0] : f8E5M2 from vector<4xf8E5M2> +// CHECK: return [[W]] : f8E5M2 +func.func @scalar_trunc(%v: f16) -> f8E5M2 { + %w = arith.truncf %v : f16 to f8E5M2 + return %w : f8E5M2 +} + +// No 0-D test because arith.truncf hasn't been extended to support it. 
+ +// ----- + +// CHECK-LABEL: func.func @vector_trunc_short +// CHECK-SAME: ([[V:%.+]]: vector<2xf64>) -> vector<2xf8E5M2> { +// CHECK: [[V0:%.+]] = vector.extract [[V]][0] +// CHECK: [[F0:%.+]] = arith.truncf [[V0]] : f64 to f32 +// CHECK: [[V1:%.+]] = vector.extract [[V]][1] +// CHECK: [[F1:%.+]] = arith.truncf [[V1]] : f64 to f32 +// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E5M2> +// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK: return [[W]] : vector<2xf8E5M2> +func.func @vector_trunc_short(%v: vector<2xf64>) -> vector<2xf8E5M2> { + %w = arith.truncf %v : vector<2xf64> to vector<2xf8E5M2> + return %w : vector<2xf8E5M2> +} + +// ----- + +// CHECK-LABEL: func.func @vector_trunc_long +// CHECK-SAME: ([[V:%.+]]: vector<9xf32>) +// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN> +// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1] +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]} + +// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1] +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]} + +// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0] +// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]} +// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]} +// CHECK: return [[W]] +func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FN> { + %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FN> + return %w : vector<9xf8E4M3FN> +} + +// ----- + +// CHECK-LABEL: func.func @vector_trunc_long_2d +// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>) +// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FN> +// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1] +// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]} + +// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0] +// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1] +// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]} + +// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0] +// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]} +// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]} +// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FN> to vector<1x9xf8E4M3FN> +// CHECK: return [[RE]] +func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FN> { + %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FN> + return %w : vector<1x9xf8E4M3FN> +} + +// ----- + +// CHECK-LABEL: func.func @vector_ext_long_2d +// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FN>) +// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FN> to 
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FN> to vector<4xf8E4M3FN>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FN> to vector<1xf8E4M3FN>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FN>) -> vector<1x9xf32> {
+  %w = arith.extf %v : vector<1x9xf8E4M3FN> to vector<1x9xf32>
+  return %w : vector<1x9xf32>
+}
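Note (illustrative, not part of the patch): the TosaValidation change means OCP fp8 element types now pass isValidElementType outside the BaseInference profile, so a function like the following (tosa.identity chosen arbitrarily) should no longer be rejected for its element type alone:

func.func @tosa_fp8(%arg0: tensor<4x4xf8E5M2>) -> tensor<4x4xf8E5M2> {
  %0 = tosa.identity %arg0 : (tensor<4x4xf8E5M2>) -> tensor<4x4xf8E5M2>
  return %0 : tensor<4x4xf8E5M2>
}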