-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[CIR] Implement switch case simplify #140649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clangir Author: None (Andres-Salamanca) ChangesThis PR introduces a new CIR simplify for This logic is based on the suggestion from this discussion: Full diff: https://github.com/llvm/llvm-project/pull/140649.diff 4 Files Affected:
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
static bool opUnaryPromotionType() { return false; }
// SwitchOp handling
- static bool foldCascadingCases() { return false; }
static bool foldRangeCase() { return false; }
// Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
cir::IntAttr::get(condType, endVal)});
kind = cir::CaseOpKind::Range;
-
- // We don't currently fold case range statements with other case statements.
- // TODO(cir): Add this capability. Folding these cases is going to be
- // implemented in CIRSimplify when it is upstreamed.
- assert(!cir::MissingFeatures::foldRangeCase());
- assert(!cir::MissingFeatures::foldCascadingCases());
} else {
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
}
};
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// The merging process continues until a `cir.case` with a different body
+/// (e.g., containing `cir.break` or compound stmt) is encountered, which
+/// breaks the chain.
+///
+/// Example:
+///
+/// Before:
+/// cir.case equal, [#cir.int<0> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<1> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<2> : !s32i] {
+/// cir.break
+/// }
+///
+/// After applying SimplifySwitch:
+/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+/// !s32i] {
+/// cir.break
+/// }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const override {
+
+ LogicalResult changed = mlir::failure();
+ llvm::SmallVector<CaseOp, 8> cases;
+ SmallVector<CaseOp, 4> cascadingCases;
+ SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+ op.collectCases(cases);
+ if (cases.empty())
+ return mlir::failure();
+
+ auto flushMergedOps = [&]() {
+ for (CaseOp &c : cascadingCases) {
+ rewriter.eraseOp(c);
+ }
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ };
+
+ auto mergeCascadingInto = [&](CaseOp &target) {
+ rewriter.modifyOpInPlace(target, [&]() {
+ target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+ target.setKind(CaseOpKind::Anyof);
+ });
+ changed = mlir::success();
+ };
+
+ for (CaseOp c : cases) {
+ cir::CaseOpKind kind = c.getKind();
+ if (kind == cir::CaseOpKind::Equal &&
+ isa<YieldOp>(c.getCaseRegion().front().front())) {
+ // If the case contains only a YieldOp, collect it for cascading merge
+ cascadingCases.push_back(c);
+ cascadingCaseValues.push_back(c.getValue()[0]);
+
+ } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+ // merge previously collected cascading cases
+ cascadingCaseValues.push_back(c.getValue()[0]);
+ mergeCascadingInto(c);
+ flushMergedOps();
+ } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+ // If a Default, Anyof or Range case is found and there are previous
+ // cascading cases, merge all of them into the last cascading case.
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ } else {
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ }
+ }
+
+ // Edge case: all cases are simple cascading cases
+ if (cascadingCases.size() == cases.size()) {
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ }
+ // We don't currently fold case range statements with other case statements.
+ assert(!cir::MissingFeatures::foldRangeCase());
+ return changed;
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
SimplifyTernary,
- SimplifySelect
+ SimplifySelect,
+ SimplifySwitch
>(patterns.getContext());
// clang-format on
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
- if (isa<TernaryOp, SelectOp>(op))
+ if (isa<TernaryOp, SelectOp, SwitchOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @foldCascade(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ %2 = cir.const #cir.int<2> : !s32i
+ cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+ //CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+ //CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+ //CHECK-NEXT: cir.break
+ //CHECK-NEXT: }
+ //CHECK-NEXT: cir.yield
+ //CHECK-NEXT: }
+
+ cir.func @foldCascade2(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: @foldCascade2
+ //CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+ //CHECK: cir.break
+ //cehck: }
+ //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascade3(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+ %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%2 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade3
+ //CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.case(default, []) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<6> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<7> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascadeWithDefault
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(default, []) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldAllCascade(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldAllCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+}
|
@llvm/pr-subscribers-clang Author: None (Andres-Salamanca) ChangesThis PR introduces a new CIR simplify for This logic is based on the suggestion from this discussion: Full diff: https://github.com/llvm/llvm-project/pull/140649.diff 4 Files Affected:
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
static bool opUnaryPromotionType() { return false; }
// SwitchOp handling
- static bool foldCascadingCases() { return false; }
static bool foldRangeCase() { return false; }
// Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
cir::IntAttr::get(condType, endVal)});
kind = cir::CaseOpKind::Range;
-
- // We don't currently fold case range statements with other case statements.
- // TODO(cir): Add this capability. Folding these cases is going to be
- // implemented in CIRSimplify when it is upstreamed.
- assert(!cir::MissingFeatures::foldRangeCase());
- assert(!cir::MissingFeatures::foldCascadingCases());
} else {
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
}
};
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// The merging process continues until a `cir.case` with a different body
+/// (e.g., containing `cir.break` or compound stmt) is encountered, which
+/// breaks the chain.
+///
+/// Example:
+///
+/// Before:
+/// cir.case equal, [#cir.int<0> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<1> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<2> : !s32i] {
+/// cir.break
+/// }
+///
+/// After applying SimplifySwitch:
+/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+/// !s32i] {
+/// cir.break
+/// }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const override {
+
+ LogicalResult changed = mlir::failure();
+ llvm::SmallVector<CaseOp, 8> cases;
+ SmallVector<CaseOp, 4> cascadingCases;
+ SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+ op.collectCases(cases);
+ if (cases.empty())
+ return mlir::failure();
+
+ auto flushMergedOps = [&]() {
+ for (CaseOp &c : cascadingCases) {
+ rewriter.eraseOp(c);
+ }
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ };
+
+ auto mergeCascadingInto = [&](CaseOp &target) {
+ rewriter.modifyOpInPlace(target, [&]() {
+ target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+ target.setKind(CaseOpKind::Anyof);
+ });
+ changed = mlir::success();
+ };
+
+ for (CaseOp c : cases) {
+ cir::CaseOpKind kind = c.getKind();
+ if (kind == cir::CaseOpKind::Equal &&
+ isa<YieldOp>(c.getCaseRegion().front().front())) {
+ // If the case contains only a YieldOp, collect it for cascading merge
+ cascadingCases.push_back(c);
+ cascadingCaseValues.push_back(c.getValue()[0]);
+
+ } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+ // merge previously collected cascading cases
+ cascadingCaseValues.push_back(c.getValue()[0]);
+ mergeCascadingInto(c);
+ flushMergedOps();
+ } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+ // If a Default, Anyof or Range case is found and there are previous
+ // cascading cases, merge all of them into the last cascading case.
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ } else {
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ }
+ }
+
+ // Edge case: all cases are simple cascading cases
+ if (cascadingCases.size() == cases.size()) {
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ }
+ // We don't currently fold case range statements with other case statements.
+ assert(!cir::MissingFeatures::foldRangeCase());
+ return changed;
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
SimplifyTernary,
- SimplifySelect
+ SimplifySelect,
+ SimplifySwitch
>(patterns.getContext());
// clang-format on
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
- if (isa<TernaryOp, SelectOp>(op))
+ if (isa<TernaryOp, SelectOp, SwitchOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @foldCascade(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ %2 = cir.const #cir.int<2> : !s32i
+ cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+ //CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+ //CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+ //CHECK-NEXT: cir.break
+ //CHECK-NEXT: }
+ //CHECK-NEXT: cir.yield
+ //CHECK-NEXT: }
+
+ cir.func @foldCascade2(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: @foldCascade2
+ //CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+ //CHECK: cir.break
+ //cehck: }
+ //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascade3(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+ %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%2 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade3
+ //CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.case(default, []) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<6> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<7> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascadeWithDefault
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(default, []) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldAllCascade(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldAllCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, some inline comments.
auto flushMergedOps = [&]() { | ||
for (CaseOp &c : cascadingCases) { | ||
rewriter.eraseOp(c); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curly braces not necessary here.
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() { | ||
// Collect operations to apply patterns. | ||
llvm::SmallVector<Operation *, 16> ops; | ||
getOperation()->walk([&](Operation *op) { | ||
if (isa<TernaryOp, SelectOp>(op)) | ||
if (isa<TernaryOp, SelectOp, SwitchOp>(op)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you double check if the test passes without issues if -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON
is used while building clang? We currently having incubator issues with this and probably best to make sure we don't introduce them here if possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have built Clang with the -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON
flag, and this pull request passes successfully. However, I have identified two failing tests:
********************
Failed Tests (2):
Clang :: CIR/CodeGen/loop.cpp
Clang :: CIR/Transforms/switch.cir
********************
The CIR/Transforms/switch.cir
test fails when applying the -cir-flatten-cfg
pass. I'm going to check if this test is also failing in the incubator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checked and it also fails in the incubator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for walking the extra leg here, can you create an issue for this so we can take a look later? (cc @xlauko which pointed this recently).
@andykaylor @erichkeane @bcardosolopes Consider the following test case: void sw7(int a) {
switch (a) {
case 0:
case 1:
case 2:
int x;
case 3:
case 4:
case 5:
break;
}
} In the previous implementation, the codegen phase performed the folding. Variable
Now, without codegen folding and before applying the new pass, the output is: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64} // x hoisted out as before
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.switch (%2 : !s32i) {
cir.case(equal, [#cir.int<0> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<3> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<4> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
cir.break
}
cir.yield
} After applying the new CIR simplify pass, the output is: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
cir.break
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good. I have just a few minor suggestions.
That's semantically equivalent, and I would say it's an improvement. |
This is awesome, big +1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once all the other comments are addressed.
This PR introduces a new **CIR simplify for `switch` cases**, which folds multiple **cascading `Equal` cases** (that contain only a `YieldOp`) into a single `CaseOp` of kind `AnyOf`. This logic is based on the suggestion from this discussion: llvm#138003 (comment)
This PR introduces a new **CIR simplify for `switch` cases**, which folds multiple **cascading `Equal` cases** (that contain only a `YieldOp`) into a single `CaseOp` of kind `AnyOf`. This logic is based on the suggestion from this discussion: llvm#138003 (comment)
This PR introduces a new CIR simplify for
switch
cases, which folds multiple cascadingEqual
cases (that contain only aYieldOp
) into a singleCaseOp
of kindAnyOf
.This logic is based on the suggestion from this discussion:
#138003 (comment)