Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Add arith expansion of f8E8M0 type for extf/trunc ops #140332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
Loading
from

Conversation

umangyadav
Copy link
Contributor

@umangyadav umangyadav commented May 17, 2025

F8E8M0 floating type is supposed to represent unbiased exponent bits of F32 type in OCP Micro scaling floating point formats.

https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR expands arith.truncf and arith.extf to support this behavior.

For the arith.truncf thing to note here is that F8E8M0FNU type has one NaN representation which is encoded as 0xFF. Therefore alll kinds of NaNs and +/-Inf in Float32Type would map to NaN in F8E8M0FNU. F8E8M0FNU doesn't have a sign bit therefore it is a lossy and irreversible downcast.

cc: @krzysz00 @MaheshRavishankar @Muzammiluddin-Syed-ECE

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:arith labels May 17, 2025
@llvmbot
Copy link
Member

llvmbot commented May 17, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Umang Yadav (umangyadav)

Changes

F8E8M0 floating type is supposed to represent unbiased exponent bits of F32 type in OCP Micro scaling floating point formats.

https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR expands arith.truncf and arith.extf to support this behavior.

For the arith.truncf thing to note here is that F8E8M0FNU type has one NaN representation which is encoded as 0xFF. Therefore alll kinds of NaNs and +/-Inf in Float32Type would map to NaN in F8E8M0FNU. F8E8M0FNU doesn't have a sign bit therefore it is a lossy and irreversible downcast.


Full diff: https://github.com/llvm/llvm-project/pull/140332.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+5-3)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+129-9)
  • (modified) mlir/lib/IR/Types.cpp (+1-1)
  • (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+129-1)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
-  let options = [
-    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
-           "Enable the BF16 expansion patterns">,
+  let options =
+      [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+              "Enable the BF16 expansion patterns">,
+       Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+              "Enable the F8E8M0 expansion patterns">,
   ];
 }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 4ffdbfa5b1224..55a7c6bb11784 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -109,6 +109,7 @@ class Type {
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isF8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..f5240cf92bdc4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // create constants for NaNs
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // select for NaNs
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.isBF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.isF16()) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
+Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
+they all map to NaN in F8E8M0 Type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!resultETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+    if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
+      return rewriter.notifyMatchFailure(
+          op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+    if (!operandETy.isF32()) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
       arith::MinNumFOp
     >();
 
-    if (includeBf16) {
+    if(includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
+    }
+    if(includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+    if (includeBf16 || includeF8E8M0) {
       target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
+        [=](arith::ExtFOp op) {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
+          if(includeBf16 && includeF8E8M0)
+            return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
+          if(includeBf16)
+            return !(inETy.isBF16() && outETy.isF32());
+          return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
         });
 
       target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op)  {
+        [=](arith::TruncFOp op)  {
           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
           Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
+          if(includeBf16 && includeF8E8M0) 
+            return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
+          if(includeBf16)
+            return !(inETy.isF32() && outETy.isBF16());
+          return 
+            !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16())); 
         });
     }
-
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -389,6 +504,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index 765b787d3d17a..975b26ae4369f 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -33,7 +33,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 //===----------------------------------------------------------------------===//
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-
+bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL:       func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
 
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+    %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+    return %0 : f8E8M0FNU
+}
+// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+    %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+    return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABLE: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+    %0 = arith.extf %arg0 : f8E8M0FNU to f16
+    return %0 : f16
+}
+
+// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+    return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+    %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+    return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
 // -----
 
 func.func @maxsi(%a: i32, %b: i32) -> i32 {

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments, but this does seem like a good fit for ExpandOps

@@ -291,7 +291,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Constant used to make the rounding bias.
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
// Constant used to generate a quiet NaN.
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change?

LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto operand = op.getOperand();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can probably be Value

return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
}

if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd not hardcode a list of targets here - I'd just plow forward with a cast to f32 and if the final target has a bitwidth less than 32 you truncate and for > 32 you extend

if (!resultETy.isF8E8M0FNU()) {
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
}
if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note: extend or truncate to f32 as needed

@@ -351,23 +453,36 @@ struct ArithExpandOpsPass
arith::MinNumFOp
>();

if (includeBf16) {
if(includeBf16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put the space back?

Type inETy = getElementTypeOrSelf(op.getOperand().getType());
Type outETy = getElementTypeOrSelf(op.getType());
return !(inETy.isF32() && outETy.isBF16());
if(includeBf16 && includeF8E8M0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the condition can be simplified here

@@ -109,6 +109,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isF8E8M0FNU() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all the isF* for small types were recently removed: #123326

@tgymnich
Copy link
Member

tgymnich commented May 17, 2025

When converting f32 to f8e8m0, should we map negative numbers to 0 (e.g. underflow to smallest normalized value)?

@krzysz00
Copy link
Contributor

As far as I'm aware, the cast down is meant to drop the sign but, not go to 0

Especially since f8E8M0FNU doesn't have a 0

@tgymnich
Copy link
Member

tgymnich commented May 17, 2025

Just dropping the sign makes sense to me.
I meant 0 as in smallest normalized value.

https://github.com/iree-org/iree/blob/c447638dae70fc21f5d84ad4cf402ca034a60cda/runtime/src/iree/base/internal/math.h#L596

@krzysz00 I assume this is wrong then and needs to be changed?

@krzysz00
Copy link
Contributor

Yeah, given that this meant to be an exponent part for scaling other floats by, I figure fabs() might be a better thing to have there

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:arith mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.