diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 52bb0eb992b69..70aecfcfa3ec7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2558,7 +2558,6 @@ def Tosa_IfOp : Tosa_Op<"cond_if", SizedRegion<1>:$else_graph ); - let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 371c6dc27b428..2d7c80cbf7848 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3518,65 +3518,6 @@ std::optional> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } -// parse and print of IfOp refer to the implementation of SCF dialect. -ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { - // Create the regions for 'then'. - result.regions.reserve(2); - Region *thenRegion = result.addRegion(); - Region *elseRegion = result.addRegion(); - - auto &builder = parser.getBuilder(); - OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) - return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) - return failure(); - // Parse the 'then' region. - if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); - - // If we find an 'else' keyword then parse the 'else' region. - if (!parser.parseOptionalKeyword("else")) { - if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); - } - - // Parse the optional attribute list. - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - return success(); -} - -void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; - } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); - - // Print the 'else' regions if it exists and has a block. - auto &elseRegion = getElseGraph(); - if (!elseRegion.empty()) { - p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); - } - - p.printOptionalAttrDict((*this)->getAttrs()); -} - LogicalResult IfOp::verify() { if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(), "'then_graph' arguments", getInputList(), diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index fa7a91cda0a47..78f5040eab97a 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -36,20 +36,15 @@ func.func @while_test(%arg0 : tensor) -> (tensor) { func.func @if_test(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> (tensor) { // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]] // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor) { - %0 = tosa.cond_if %arg2 -> (tensor) { - - // CHECK: scf.yield [[ARG0]] - tosa.yield %arg0 : tensor - - // CHECK: } else { - } else { - - // CHECK: scf.yield [[ARG1]] - tosa.yield %arg1 : tensor - - // CHECK: } - // CHECK: return [[IF]] - } + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: scf.yield [[ARG0]] + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: scf.yield [[ARG1]] + tosa.yield %arg4 : tensor + }) : (tensor, tensor, tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 75126a11ac504..5381d6c533d01 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -645,13 +645,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: tosa.cond_if profiles: [ ] // CHECK: tosa.cond_if extensions: [ [controlflow] ] - %0 = tosa.cond_if %arg2 -> (tensor) { - %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 2364985442e43..c688b6592ed9f 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -337,13 +337,15 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 // ----- func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { + }, { + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index d24c1fa57883d..5b11aa782637a 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1503,40 +1503,87 @@ func.func @test_while_tensor_list_size(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: // ----- -func.func @test_cond_if_max_nested_depth(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { - %0 = tosa.cond_if %arg2 -> (tensor) { - %1 = tosa.cond_if %arg3 -> (tensor) { - %2 = tosa.cond_if %arg2 -> (tensor) { - %3 = tosa.cond_if %arg3 -> (tensor) { - %4 = tosa.cond_if %arg2 -> (tensor) { +func.func @test_cond_if_max_nested_depth(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + + // COM: then graph of IF-1 + ^bb1(%a1: tensor, %b1: tensor): + %cond1 = tosa.equal %a1, %b1 : (tensor, tensor) -> tensor + %1 = "tosa.cond_if"(%cond1, %a1, %b1) ({ + + // COM: then graph of IF-2 + ^bb2(%a2: tensor, %b2: tensor): + %cond2 = tosa.equal %a2, %b2 : (tensor, tensor) -> tensor + %2 = "tosa.cond_if"(%cond2, %a2, %b2) ({ + + // COM: then graph of IF-3 + ^bb3(%a3: tensor, %b3: tensor): + %cond3 = tosa.equal %a3, %b3 : (tensor, tensor) -> tensor + %3 = "tosa.cond_if"(%cond3, %a3, %b3) ({ + + // COM: then graph of IF-4 + ^bb4(%a4: tensor, %b4: tensor): + %cond4 = tosa.equal %a4, %b4 : (tensor, tensor) -> tensor + %4 = "tosa.cond_if"(%cond4, %a4, %b4) ({ + + // COM: then graph of IF-5 + ^bb5(%a5: tensor, %b5: tensor): + %cond5 = tosa.equal %a5, %b5 : (tensor, tensor) -> tensor // expected-error@+1 {{'tosa.cond_if' op failed level check: 6 >= MAX_NESTING}} - %5 = tosa.cond_if %arg3 -> (tensor) { + %5 = "tosa.cond_if"(%cond5, %a5, %b5) ({ + + // COM: then graph of IF-6 + ^bb6(%a6: tensor, %b6: tensor): %res = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %res : tensor - } else { + }, { + + // COM: else graph of IF-6 + ^bb6(%a6: tensor, %b6: tensor): tosa.yield %arg0 : tensor - } + }) : (tensor, tensor, tensor) -> tensor + tosa.yield %5 : tensor - } else { + }, { + + // COM: else graph of IF-5 + ^bb5(%a5: tensor, %b5: tensor): + tosa.yield %arg0 : tensor + }) : (tensor, tensor, tensor) -> tensor + + tosa.yield %4 : tensor + }, { + + // COM: else graph of IF-4 + ^bb4(%a4: tensor, %b4: tensor): tosa.yield %arg0 : tensor - } - tosa.yield %4 : tensor - } else { - tosa.yield %arg0 : tensor - } - tosa.yield %3 : tensor - } else { + }) : (tensor, tensor, tensor) -> tensor + + tosa.yield %3 : tensor + }, { + + // COM: else graph of IF-3 + ^bb3(%a3: tensor, %b3: tensor): tosa.yield %arg0 : tensor - } - tosa.yield %2 : tensor - } else { - tosa.yield %arg0 : tensor - } - tosa.yield %1 : tensor - } else { - %res = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + }) : (tensor, tensor, tensor) -> tensor + + tosa.yield %2 : tensor + }, { + + // COM: else graph of IF-2 + ^bb2(%a2: tensor, %b2: tensor): + tosa.yield %arg0 : tensor + }) : (tensor, tensor, tensor) -> tensor + + tosa.yield %1 : tensor + }, { + + // COM: else graph of IF-1 + ^bb1(%a1: tensor, %b1: tensor): + %res = tosa.sub %a1, %b1 : (tensor, tensor) -> tensor tosa.yield %res : tensor - } + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index e327ed900f45f..e3036cf07171f 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -781,13 +781,15 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { + }, { + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 1ad1e6c76c294..981e3cc7fc129 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -1121,12 +1121,14 @@ func.func @if_test_simple(%arg0 : tensor, %arg1 : tensor, %arg2 : tens %b = tosa.log %arg1 : (tensor) -> tensor // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { - tosa.yield %a : tensor - } else { - tosa.yield %b : tensor - } + // CHECK: -> tensor + %0 = "tosa.cond_if"(%arg2, %a, %b) ({ + ^bb0(%a1: tensor, %b1: tensor): + tosa.yield %a1 : tensor + }, { + ^bb0(%a1: tensor, %b1: tensor): + tosa.yield %b1 : tensor + }) : (tensor, tensor, tensor) -> tensor return } @@ -1135,12 +1137,14 @@ func.func @if_test_simple(%arg0 : tensor, %arg1 : tensor, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { - tosa.yield %arg0 : tensor<2xf32> - } else { - tosa.yield %arg1 : tensor<3xf32> - } + // CHECK: -> tensor + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>): + tosa.yield %a : tensor<2xf32> + }, { + ^bb0(%a: tensor<2xf32>, %b: tensor<3xf32>): + tosa.yield %b : tensor<3xf32> + }) : (tensor, tensor<2xf32>, tensor<3xf32>) -> tensor return } @@ -1149,12 +1153,14 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor, %arg1 : tensor<3xf32>, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { - tosa.yield %arg0 : tensor - } else { - tosa.yield %arg1 : tensor<3xf32> - } + // CHECK: -> tensor<*xf32> + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%a: tensor, %b: tensor<3xf32>): + tosa.yield %a : tensor + }, { + ^bb0(%a: tensor, %b: tensor<3xf32>): + tosa.yield %b : tensor<3xf32> + }) : (tensor, tensor, tensor<3xf32>) -> tensor<*xf32> return } @@ -1163,14 +1169,16 @@ func.func @if_test_unranked(%arg0 : tensor, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor) - %0 = tosa.cond_if %arg2 -> (tensor) { - %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: -> tensor + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%a: tensor, %b: tensor): + %1 = tosa.add %a, %b : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + }, { + ^bb0(%a: tensor, %b: tensor): + %1 = tosa.sub %a, %b : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> tensor return } diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 990e0d954f54e..e99608dfbeff4 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -502,14 +502,17 @@ func.func @test_cond_if_input_list_mismatch_else_block_2(%arg0: tensor, %ar func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor tosa.yield %1, %2 : tensor, tensor - } else { + }, { + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor } @@ -517,13 +520,15 @@ func.func @test_cond_if_output_list_mismatch_then_block(%arg0: tensor, %arg func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'then_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { + %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { + }, { + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> (tensor, tensor) return %0 : tensor } @@ -531,14 +536,16 @@ func.func @test_cond_if_output_list_mismatch_then_block_2(%arg0: tensor, %a func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (2) and 'output_list' (1)}} - %0 = tosa.cond_if %arg2 -> (tensor) { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } else { - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor - %2 = tosa.add %1, %arg1 : (tensor, tensor) -> tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor + %2 = tosa.add %1, %arg3 : (tensor, tensor) -> tensor tosa.yield %1, %2 : tensor, tensor - } + }) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -546,14 +553,16 @@ func.func @test_cond_if_output_list_mismatch_else_block(%arg0: tensor, %arg func.func @test_cond_if_output_list_mismatch_else_block_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // expected-error@+1 {{'tosa.cond_if' op require same number of values in 'else_graph' results (1) and 'output_list' (2)}} - %0, %2 = tosa.cond_if %arg2 -> (tensor, tensor) { - %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor - %2 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + %0, %2 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor + %2 = tosa.sub %arg3, %arg4 : (tensor, tensor) -> tensor tosa.yield %1, %2 : tensor, tensor - } else { - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = tosa.add %arg3, %arg4 : (tensor, tensor) -> tensor tosa.yield %1 : tensor - } + }) : (tensor, tensor, tensor) -> (tensor, tensor) return %0 : tensor }