Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[CIR] Implement switch case simplify #140649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2025

Conversation

Andres-Salamanca
Copy link
Contributor

This PR introduces a new CIR simplify for switch cases, which folds multiple cascading Equal cases (that contain only a YieldOp) into a single CaseOp of kind AnyOf.

This logic is based on the suggestion from this discussion:
#138003 (comment)

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 20, 2025
@llvmbot
Copy link
Member

llvmbot commented May 20, 2025

@llvm/pr-subscribers-clangir

Author: None (Andres-Salamanca)

Changes

This PR introduces a new CIR simplify for switch cases, which folds multiple cascading Equal cases (that contain only a YieldOp) into a single CaseOp of kind AnyOf.

This logic is based on the suggestion from this discussion:
#138003 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/140649.diff

4 Files Affected:

  • (modified) clang/include/clang/CIR/MissingFeatures.h (-1)
  • (modified) clang/lib/CIR/CodeGen/CIRGenStmt.cpp (-6)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (+104-2)
  • (added) clang/test/CIR/Transforms/switch-fold.cir (+196)
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
   static bool opUnaryPromotionType() { return false; }
 
   // SwitchOp handling
-  static bool foldCascadingCases() { return false; }
   static bool foldRangeCase() { return false; }
 
   // Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
                                   cir::IntAttr::get(condType, endVal)});
     kind = cir::CaseOpKind::Range;
-
-    // We don't currently fold case range statements with other case statements.
-    // TODO(cir): Add this capability. Folding these cases is going to be
-    // implemented in CIRSimplify when it is upstreamed.
-    assert(!cir::MissingFeatures::foldRangeCase());
-    assert(!cir::MissingFeatures::foldCascadingCases());
   } else {
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
     kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
   }
 };
 
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// The merging process continues until a `cir.case` with a different body
+/// (e.g., containing `cir.break` or compound stmt) is encountered, which
+/// breaks the chain.
+///
+/// Example:
+///
+/// Before:
+///   cir.case equal, [#cir.int<0> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<1> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<2> : !s32i] {
+///     cir.break
+///   }
+///
+/// After applying SimplifySwitch:
+///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+///   !s32i] {
+///     cir.break
+///   }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const override {
+
+    LogicalResult changed = mlir::failure();
+    llvm::SmallVector<CaseOp, 8> cases;
+    SmallVector<CaseOp, 4> cascadingCases;
+    SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+    op.collectCases(cases);
+    if (cases.empty())
+      return mlir::failure();
+
+    auto flushMergedOps = [&]() {
+      for (CaseOp &c : cascadingCases) {
+        rewriter.eraseOp(c);
+      }
+      cascadingCases.clear();
+      cascadingCaseValues.clear();
+    };
+
+    auto mergeCascadingInto = [&](CaseOp &target) {
+      rewriter.modifyOpInPlace(target, [&]() {
+        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+        target.setKind(CaseOpKind::Anyof);
+      });
+      changed = mlir::success();
+    };
+
+    for (CaseOp c : cases) {
+      cir::CaseOpKind kind = c.getKind();
+      if (kind == cir::CaseOpKind::Equal &&
+          isa<YieldOp>(c.getCaseRegion().front().front())) {
+        // If the case contains only a YieldOp, collect it for cascading merge
+        cascadingCases.push_back(c);
+        cascadingCaseValues.push_back(c.getValue()[0]);
+
+      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+        // merge previously collected cascading cases
+        cascadingCaseValues.push_back(c.getValue()[0]);
+        mergeCascadingInto(c);
+        flushMergedOps();
+      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+        // If a Default, Anyof or Range case is found and there are previous
+        // cascading cases, merge all of them into the last cascading case.
+        CaseOp lastCascadingCase = cascadingCases.back();
+        mergeCascadingInto(lastCascadingCase);
+        cascadingCases.pop_back();
+        flushMergedOps();
+      } else {
+        cascadingCases.clear();
+        cascadingCaseValues.clear();
+      }
+    }
+
+    // Edge case: all cases are simple cascading cases
+    if (cascadingCases.size() == cases.size()) {
+      CaseOp lastCascadingCase = cascadingCases.back();
+      mergeCascadingInto(lastCascadingCase);
+      cascadingCases.pop_back();
+      flushMergedOps();
+    }
+    // We don't currently fold case range statements with other case statements.
+    assert(!cir::MissingFeatures::foldRangeCase());
+    return changed;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CIRSimplifyPass
 //===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
     SimplifyTernary,
-    SimplifySelect
+    SimplifySelect,
+    SimplifySwitch
   >(patterns.getContext());
   // clang-format on
 }
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
   // Collect operations to apply patterns.
   llvm::SmallVector<Operation *, 16> ops;
   getOperation()->walk([&](Operation *op) {
-    if (isa<TernaryOp, SelectOp>(op))
+    if (isa<TernaryOp, SelectOp, SwitchOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+    cir.func @foldCascade(%arg0: !s32i) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          %2 = cir.const #cir.int<2> : !s32i
+          cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK-NEXT:     cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+  //CHECK-NEXT:       %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+  //CHECK-NEXT:       cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+  //CHECK-NEXT:       cir.break
+  //CHECK-NEXT:     }
+  //CHECK-NEXT:     cir.yield
+  //CHECK-NEXT:   }
+
+    cir.func @foldCascade2(%arg0: !s32i) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.break
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: @foldCascade2
+  //CHECK:   cir.switch (%[[COND2:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+  //CHECK:       cir.break
+  //cehck:     }
+  //CHECK:     cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+  cir.func @foldCascade3(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+      %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%2 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade3
+  //CHECK:   cir.switch (%[[COND3:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:    }
+  //CHECK:    cir.yield
+  //CHECK:   }
+  cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.break
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.case(default, []) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<6> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<7> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascadeWithDefault
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:      cir.case(equal, [#cir.int<3> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(default, []) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.yield
+  //CHECK:   }
+  cir.func @foldAllCascade(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldAllCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.yield
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+}

@llvmbot
Copy link
Member

llvmbot commented May 20, 2025

@llvm/pr-subscribers-clang

Author: None (Andres-Salamanca)

Changes

This PR introduces a new CIR simplify for switch cases, which folds multiple cascading Equal cases (that contain only a YieldOp) into a single CaseOp of kind AnyOf.

This logic is based on the suggestion from this discussion:
#138003 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/140649.diff

4 Files Affected:

  • (modified) clang/include/clang/CIR/MissingFeatures.h (-1)
  • (modified) clang/lib/CIR/CodeGen/CIRGenStmt.cpp (-6)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (+104-2)
  • (added) clang/test/CIR/Transforms/switch-fold.cir (+196)
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
   static bool opUnaryPromotionType() { return false; }
 
   // SwitchOp handling
-  static bool foldCascadingCases() { return false; }
   static bool foldRangeCase() { return false; }
 
   // Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
                                   cir::IntAttr::get(condType, endVal)});
     kind = cir::CaseOpKind::Range;
-
-    // We don't currently fold case range statements with other case statements.
-    // TODO(cir): Add this capability. Folding these cases is going to be
-    // implemented in CIRSimplify when it is upstreamed.
-    assert(!cir::MissingFeatures::foldRangeCase());
-    assert(!cir::MissingFeatures::foldCascadingCases());
   } else {
     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
     kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
   }
 };
 
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// The merging process continues until a `cir.case` with a different body
+/// (e.g., containing `cir.break` or compound stmt) is encountered, which
+/// breaks the chain.
+///
+/// Example:
+///
+/// Before:
+///   cir.case equal, [#cir.int<0> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<1> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<2> : !s32i] {
+///     cir.break
+///   }
+///
+/// After applying SimplifySwitch:
+///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+///   !s32i] {
+///     cir.break
+///   }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const override {
+
+    LogicalResult changed = mlir::failure();
+    llvm::SmallVector<CaseOp, 8> cases;
+    SmallVector<CaseOp, 4> cascadingCases;
+    SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+    op.collectCases(cases);
+    if (cases.empty())
+      return mlir::failure();
+
+    auto flushMergedOps = [&]() {
+      for (CaseOp &c : cascadingCases) {
+        rewriter.eraseOp(c);
+      }
+      cascadingCases.clear();
+      cascadingCaseValues.clear();
+    };
+
+    auto mergeCascadingInto = [&](CaseOp &target) {
+      rewriter.modifyOpInPlace(target, [&]() {
+        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+        target.setKind(CaseOpKind::Anyof);
+      });
+      changed = mlir::success();
+    };
+
+    for (CaseOp c : cases) {
+      cir::CaseOpKind kind = c.getKind();
+      if (kind == cir::CaseOpKind::Equal &&
+          isa<YieldOp>(c.getCaseRegion().front().front())) {
+        // If the case contains only a YieldOp, collect it for cascading merge
+        cascadingCases.push_back(c);
+        cascadingCaseValues.push_back(c.getValue()[0]);
+
+      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+        // merge previously collected cascading cases
+        cascadingCaseValues.push_back(c.getValue()[0]);
+        mergeCascadingInto(c);
+        flushMergedOps();
+      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+        // If a Default, Anyof or Range case is found and there are previous
+        // cascading cases, merge all of them into the last cascading case.
+        CaseOp lastCascadingCase = cascadingCases.back();
+        mergeCascadingInto(lastCascadingCase);
+        cascadingCases.pop_back();
+        flushMergedOps();
+      } else {
+        cascadingCases.clear();
+        cascadingCaseValues.clear();
+      }
+    }
+
+    // Edge case: all cases are simple cascading cases
+    if (cascadingCases.size() == cases.size()) {
+      CaseOp lastCascadingCase = cascadingCases.back();
+      mergeCascadingInto(lastCascadingCase);
+      cascadingCases.pop_back();
+      flushMergedOps();
+    }
+    // We don't currently fold case range statements with other case statements.
+    assert(!cir::MissingFeatures::foldRangeCase());
+    return changed;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CIRSimplifyPass
 //===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
     SimplifyTernary,
-    SimplifySelect
+    SimplifySelect,
+    SimplifySwitch
   >(patterns.getContext());
   // clang-format on
 }
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
   // Collect operations to apply patterns.
   llvm::SmallVector<Operation *, 16> ops;
   getOperation()->walk([&](Operation *op) {
-    if (isa<TernaryOp, SelectOp>(op))
+    if (isa<TernaryOp, SelectOp, SwitchOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+    cir.func @foldCascade(%arg0: !s32i) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          %2 = cir.const #cir.int<2> : !s32i
+          cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK-NEXT:     cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+  //CHECK-NEXT:       %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+  //CHECK-NEXT:       cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+  //CHECK-NEXT:       cir.break
+  //CHECK-NEXT:     }
+  //CHECK-NEXT:     cir.yield
+  //CHECK-NEXT:   }
+
+    cir.func @foldCascade2(%arg0: !s32i) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.break
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: @foldCascade2
+  //CHECK:   cir.switch (%[[COND2:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+  //CHECK:       cir.break
+  //cehck:     }
+  //CHECK:     cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+  cir.func @foldCascade3(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+      %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%2 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascade3
+  //CHECK:   cir.switch (%[[COND3:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.break
+  //CHECK:    }
+  //CHECK:    cir.yield
+  //CHECK:   }
+  cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.break
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.case(default, []) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<6> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<7> : !s32i]) {
+          cir.break
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldCascadeWithDefault
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:      cir.case(equal, [#cir.int<3> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(default, []) {
+  //CHECK:        cir.yield
+  //CHECK:      }
+  //CHECK:      cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+  //CHECK:        cir.break
+  //CHECK:      }
+  //CHECK:      cir.yield
+  //CHECK:   }
+  cir.func @foldAllCascade(%arg0: !s32i ) {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    cir.scope {
+      %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.switch (%1 : !s32i) {
+        cir.case(equal, [#cir.int<0> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<1> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<2> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<3> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<4> : !s32i]) {
+          cir.yield
+        }
+        cir.case(equal, [#cir.int<5> : !s32i]) {
+          cir.yield
+        }
+        cir.yield
+      }
+    }
+    cir.return
+  }
+  //CHECK: cir.func @foldAllCascade
+  //CHECK:   cir.switch (%[[COND:.*]] : !s32i) {
+  //CHECK:     cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+  //CHECK:       cir.yield
+  //CHECK:     }
+  //CHECK:     cir.yield
+  //CHECK:   }
+}

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, some inline comments.

auto flushMergedOps = [&]() {
for (CaseOp &c : cascadingCases) {
rewriter.eraseOp(c);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curly braces not necessary here.

@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<TernaryOp, SelectOp>(op))
if (isa<TernaryOp, SelectOp, SwitchOp>(op))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you double check if the test passes without issues if -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON is used while building clang? We currently having incubator issues with this and probably best to make sure we don't introduce them here if possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have built Clang with the -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON flag, and this pull request passes successfully. However, I have identified two failing tests:


********************
Failed Tests (2):
  Clang :: CIR/CodeGen/loop.cpp
  Clang :: CIR/Transforms/switch.cir
********************

The CIR/Transforms/switch.cir test fails when applying the -cir-flatten-cfg pass. I'm going to check if this test is also failing in the incubator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked and it also fails in the incubator.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for walking the extra leg here, can you create an issue for this so we can take a look later? (cc @xlauko which pointed this recently).

@bcardosolopes bcardosolopes changed the title Implement CIR switch case simplify with appropriate tests [CIR] Implement switch case simplify May 20, 2025
@Andres-Salamanca
Copy link
Contributor Author

@andykaylor @erichkeane @bcardosolopes
While implementing this simplification, I noticed a difference compared to the previous behavior where the folding happened during codegen.

Consider the following test case:

void sw7(int a) {
  switch (a) {
  case 0:
  case 1:
  case 2:
    int x;
  case 3:
  case 4:
  case 5:
    break;
  }
}

In the previous implementation, the codegen phase performed the folding. Variable x was hoisted outside the switch and the generated CIR looked like this:

 cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
   cir.yield
 }
 cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
   cir.break
 }

Now, without codegen folding and before applying the new pass, the output is:

%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64} // x hoisted out as before
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%2 : !s32i) {
  cir.case(equal, [#cir.int<0> : !s32i]) {
    cir.yield
  }
  cir.case(equal, [#cir.int<1> : !s32i]) {
    cir.yield
  }
  cir.case(equal, [#cir.int<2> : !s32i]) {
    cir.yield
  }
  cir.case(equal, [#cir.int<3> : !s32i]) {
    cir.yield
  }
  cir.case(equal, [#cir.int<4> : !s32i]) {
    cir.yield
  }
  cir.case(equal, [#cir.int<5> : !s32i]) {
    cir.break
  }
  cir.yield
}

After applying the new CIR simplify pass, the output is:

cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
   cir.break
}

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. I have just a few minor suggestions.

clang/test/CIR/Transforms/switch-fold.cir Show resolved Hide resolved
clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp Outdated Show resolved Hide resolved
clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp Outdated Show resolved Hide resolved
@andykaylor
Copy link
Contributor

@andykaylor @erichkeane @bcardosolopes While implementing this simplification, I noticed a difference compared to the previous behavior where the folding happened during codegen.
...
After applying the new CIR simplify pass, the output is:

cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
   cir.break
}

That's semantically equivalent, and I would say it's an improvement.

@bcardosolopes
Copy link
Member

That's semantically equivalent, and I would say it's an improvement.

This is awesome, big +1.

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once all the other comments are addressed.

@andykaylor andykaylor merged commit 3eb9e77 into llvm:main May 22, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
This PR introduces a new **CIR simplify for `switch` cases**, which
folds multiple **cascading `Equal` cases** (that contain only a
`YieldOp`) into a single `CaseOp` of kind `AnyOf`.

This logic is based on the suggestion from this discussion:
llvm#138003 (comment)
ajaden-codes pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 6, 2025
This PR introduces a new **CIR simplify for `switch` cases**, which
folds multiple **cascading `Equal` cases** (that contain only a
`YieldOp`) into a single `CaseOp` of kind `AnyOf`.

This logic is based on the suggestion from this discussion:
llvm#138003 (comment)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.