diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c90ebe4487ca4..7ca7d760d9165 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -718,6 +718,54 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
+/// Given two operands coming from a loop iter arg, 'src' and 'dst', return
+/// true if the operand 'src' is equal to 'dst' or equal to an iter arg
+/// present in an outer loop. To determine the second condition, this function
+/// iterates over the enclosing loops using a worklist, trying to find 'src'
+/// in any of the parent loops' iter args.
+static bool sameOrEquivalentIterArg(Value src, Value dst) {
+  // Stack-like vector containing the iter arg candidates. The first one is
+  // 'dst', and we traverse the IR from there.
+  SmallVector<Value> destWorklist;
+  destWorklist.push_back(dst);
+
+  while (!destWorklist.empty()) {
+    Value currentDst = destWorklist.pop_back_val();
+
+    // We have found the same operand in some iter arg in the loop structure,
+    // so src and dst are equivalent.
+    if (src == currentDst)
+      return true;
+
+    // The operands are not equivalent; look for enclosing loops over
+    // currentDst.
+    auto bbArg = dyn_cast<BlockArgument>(currentDst);
+    if (!bbArg)
+      continue;
+
+    Block *parentBlock = bbArg.getOwner();
+    assert(parentBlock && "unlinked block argument");
+
+    Operation *parentOp = parentBlock->getParentOp();
+    assert(parentOp && "expected block argument with parent operation");
+
+    // Check if the parent is loop-like. If it is not, do not add it to the
+    // worklist.
+    auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
+    if (!parentLoop)
+      continue;
+
+    for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
+      // No need to check for null as innerIterArg is tied to parentLoop.
+      OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
+      Value loopBlockArgument =
+          parentLoop->getOperand(operand->getOperandNumber());
+      destWorklist.push_back(loopBlockArgument);
+    }
+  }
+
+  return false;
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -755,6 +803,40 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(sliceOpToTile);
 
+  // Clone the producer inside the consumer and try to update the producer
+  // init operands using the loop bbArgs if applicable. More precisely, if
+  // the bbArg of the container loop points to a value that is used by the
+  // consumer op, then, instead of using such a value in the consumer, use
+  // the value coming from the bbArg instead. This allows reusing the output
+  // tensor of the container (instead of creating a new one) when both the
+  // producer and the container write to the same output.
+  if (LoopLikeOpInterface containerLoop =
+          dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
+    Operation *clone = rewriter.clone(*producerOp);
+    rewriter.modifyOpInPlace(clone, [&]() {
+      // Iterate over the outputs of the producer and over the loop bbArgs,
+      // and check if any bbArg points to the same value as the producer
+      // output. In such a case, make the producer output point to the bbArg
+      // directly.
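+      // For example, in the fuse_tileable_op_through_bbarg_inout test below,
+      // the producer 'linalg.fill ... outs(%arg1)' is cloned into an
+      // 'scf.forall ... shared_outs(%o = %arg1)': both inits trace back to
+      // %arg1, so the clone's init is rewritten to use the bbArg %o and the
+      // fused tiles write into the loop's shared output.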
+      for (OpOperand &initOperandPtr :
+           cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+        Value producerOperand =
+            clone->getOperand(initOperandPtr.getOperandNumber());
+        for (BlockArgument containerIterArg :
+             containerLoop.getRegionIterArgs()) {
+          OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
+          Value consumerOperand =
+              containerLoop->getOperand(bbArg->getOperandNumber());
+          // The producer has the same init as the loop bbArg; use it.
+          if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
+            initOperandPtr.set(containerIterArg);
+          }
+        }
+      }
+    });
+
+    tileableProducer = dyn_cast<TilingInterface>(clone);
+  }
+
   // Tile the producer.
   int64_t resultNumber =
       cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
@@ -797,6 +879,10 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
       resultNumber, offsets, sizes);
 
+  // Cleanup clone.
+  if (dyn_cast<LoopLikeOpInterface>(containingOp))
+    rewriter.eraseOp(tileableProducer);
+
   return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c..572a2ae70e0a4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -206,6 +206,106 @@ module {
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
 
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_through_bbarg_inout(%arg0: index, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map0()[%d0, %arg0]
+
+    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
+    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg1) -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%d0, %arg0]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+      // linalg.fill is tileable. The op is tiled and fused.
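+      // The op returns two handles: the fused (tiled) producer inside the
+      // forall, and the containing op, which is replaced when fusion has to
+      // add new outputs to the loop.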
+      transform.structured.fuse_into_containing_op %0 into %1
+        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
+  // CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
+  // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
+  func.func @fuse_tileable_op_through_bbarg_inout_nested(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
+    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
+    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
+    // CHECK: scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
+    // CHECK: scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
+    // CHECK: scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
+    %1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x?xf32>) {
+      %2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
+        %3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
+          // CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          // CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          %extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          %extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
+          %4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+          %inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
+          scf.yield %inserted_slice : tensor<?x?x?xf32>
+        }
+        scf.yield %3 : tensor<?x?x?xf32>
+      }
+      scf.yield %2 : tensor<?x?x?xf32>
+    }
+    return %1 : tensor<?x?x?xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      %3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+      transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
 module {
   // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
   // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index