[mlir][Transform] Reuse bbArgs in FuseIntoContainingOp #135066


Merged: 3 commits, May 15, 2025
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp: 86 additions, 0 deletions
@@ -718,6 +718,54 @@ static Operation *replaceForAllWithNewSignature(
return newforallOp;
}

/// Given two operands 'src' and 'dst' coming from a loop iter arg, return true
/// if 'src' is equal to 'dst' or to an init value reachable from 'dst' by
/// walking outward through enclosing loops. To determine the second condition,
/// this function iterates with a worklist over the enclosing loops, trying to
/// find 'src' among the init values tied to their iter args.
static bool sameOrEquivalentIterArg(Value src, Value dst) {
// Stack-like vector containing the candidate values to compare against. The
// first one is dst, and we traverse the IR outward from there.
SmallVector<Value> destWorklist;
destWorklist.push_back(dst);

while (!destWorklist.empty()) {
Value currentDst = destWorklist.pop_back_val();

// We have found the same operand in some iter arg in the loop structure,
// so src and dst are equivalent.
if (src == currentDst)
return true;

// The operands are not equal; if currentDst is an iter arg of an enclosing
// loop, keep walking outward through that loop.
auto bbArg = dyn_cast<BlockArgument>(currentDst);
if (!bbArg)
continue;

Block *parentBlock = bbArg.getOwner();
assert(parentBlock && "unlinked block argument");

Operation *parentOp = parentBlock->getParentOp();
assert(parentOp && "expected block argument with parent operation");

// Check if parent is loop-like. If it's not, do not add it to the worklist.
auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
if (!parentLoop)
continue;

for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
// No need to check for null as innerIterArg is tied to parentLoop.
OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
Value loopInitValue =
parentLoop->getOperand(operand->getOperandNumber());
destWorklist.push_back(loopInitValue);
}
}

return false;
}
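For intuition, here is a minimal sketch (an illustration, not part of the patch; the value names and bounds are made up) of the chain this helper walks when the loops are nested, mirroring the nested test added below. Starting from dst, the init values of each enclosing loop's iter args become new candidates until src is found or no loop-like parent remains.

%outer = scf.for %i = %c0 to %ub step %c1 iter_args(%a = %t) -> (tensor<?xf32>) {
  %inner = scf.for %j = %c0 to %ub step %c1 iter_args(%b = %a) -> (tensor<?xf32>) {
    // When fusing a producer whose init is %t into the inner loop, dst is the
    // inner loop's init %a; %a is the outer loop's iter arg tied to %t, so
    // sameOrEquivalentIterArg(%t, %a) returns true.
    scf.yield %b : tensor<?xf32>
  }
  scf.yield %inner : tensor<?xf32>
}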

/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
@@ -755,6 +803,40 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);

// Clone the producer inside the consumer and try to update the producer init
// operands using the loop bbArgs if applicable. More precisely, if a bbArg of
// the container loop is tied to an init value that the producer also uses as
// an init, rewire the cloned producer to use that bbArg instead of the
// original value. This allows reusing the output tensor of the container
// (instead of creating a new one) when both producer and container write to
// the same output.
if (LoopLikeOpInterface containerLoop =
dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
Operation *clone = rewriter.clone(*producerOp);
rewriter.modifyOpInPlace(clone, [&]() {
// Iterate over the outputs of the producer and over the loop bbArgs, and
// check if any bbArg is tied to the same value as a producer output. In
// that case, make the producer output point to the bbArg directly.
for (OpOperand &initOperandPtr :
cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
Value producerOperand =
clone->getOperand(initOperandPtr.getOperandNumber());
for (BlockArgument containerIterArg :
containerLoop.getRegionIterArgs()) {
OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
Value consumerOperand =
containerLoop->getOperand(bbArg->getOperandNumber());
// The producer has the same init as the loop bbArg; use the bbArg instead.
if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
initOperandPtr.set(containerIterArg);
}
}
}
});

tileableProducer = dyn_cast<TilingInterface>(clone);
}
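To make the rewiring concrete, here is a rough before/after sketch based on the first test added below (names follow that test; this is an illustration, not pass output). Inside the scf.forall with shared_outs(%o = %arg1), the cloned producer initially still uses %arg1 as its init; since the forall's init for %o is that same %arg1, the init is rewired to the bbArg %o, and the subsequent tiling then extracts its destination slice from %o rather than from a separate tensor.

// Cloned producer before rewiring:
%fill_before = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
// After rewiring the DPS init to the container's bbArg:
%fill_after = linalg.fill ins(%cst : f32) outs(%o : tensor<?xf32>) -> tensor<?xf32>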

// Tile the producer.
int64_t resultNumber =
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
@@ -797,6 +879,10 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
resultNumber, offsets, sizes);

// Clean up the temporary producer clone now that it has been tiled.
if (isa<LoopLikeOpInterface>(containingOp))
rewriter.eraseOp(tileableProducer);

return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir: 100 additions, 0 deletions
@@ -206,6 +206,106 @@ module {
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
// CHECK-SAME: %[[INOUT:[0-9a-z]+]]: tensor<?xf32>
func.func @fuse_tileable_op_through_bbarg_inout(%arg0: index, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
%d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
%1 = affine.apply #map0()[%d0, %arg0]

// CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[INOUT]]) -> (tensor<?xf32>) {
%2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg1) -> (tensor<?xf32>) {
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%d0, %arg0]
%5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

// CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
// CHECK: %[[T1:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
// CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

// CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
%7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
}
// CHECK: }
func.return %2 : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
}

// -----

module {
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x?xf32>
func.func @fuse_tileable_op_through_bbarg_inout_nested(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
// CHECK: scf.for {{.*}} iter_args(%[[BBARG0:.*]] = %[[ARG1]]) -> (tensor<?x?x?xf32>) {
// CHECK: scf.for {{.*}} iter_args(%[[BBARG1:.*]] = %[[BBARG0]]) -> (tensor<?x?x?xf32>) {
// CHECK: scf.for {{.*}} iter_args(%[[BBARG2:.*]] = %[[BBARG1]]) -> (tensor<?x?x?xf32>) {
%1 = scf.for %arg2 = %c0 to %dim step %c1 iter_args(%arg3 = %arg1) -> (tensor<?x?x?xf32>) {
%2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
%3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
// CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
// CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
// CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
%extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
%extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
%4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
%inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
scf.yield %inserted_slice : tensor<?x?x?xf32>
}
scf.yield %3 : tensor<?x?x?xf32>
}
scf.yield %2 : tensor<?x?x?xf32>
}
return %1 : tensor<?x?x?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.structured.fuse_into_containing_op %2#0 into %3#2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
// CHECK-LABEL: func.func @fuse_tileable_multi_output_op
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
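Note: the RUN line of transform-op-fuse-into-containing.mlir sits above the shown hunk and is not visible here. A typical RUN line for a transform-interpreter test of this shape (an assumption, not copied from the file) would look like:

// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s

with the // ----- separators delimiting the individual test cases.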