[mlir][affine] Set overflow flags when lowering [de]linearize_index #139612

Merged 2 commits on May 13, 2025
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.

As with other affine operations, lowerings of `delinearize_index` may assume
that the underlying computations do not overflow the index type in a signed
sense; that is, the product of all basis elements is positive as an `index` as well.
}];

let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
In addition, `disjoint` is an assertion that all basis elements are non-negative.

Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.

As with other affine ops, undefined behavior occurs if the linearization
computation overflows in the signed sense.

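For orientation, here is a minimal, illustrative sketch (not part of this diff) of the two ops whose documentation is amended above; the basis values, function name, and SSA names are made up for the example:

```mlir
func.func @delinearize_then_linearize(%lin: index) -> index {
  // Split %lin over the basis (2, 3, 5). Lowerings may assume the basis
  // product (30 here) does not overflow `index` in the signed sense.
  %i, %j, %k = affine.delinearize_index %lin into (2, 3, 5) : index, index, index
  // Recombine the pieces. `disjoint` is justified here because the results
  // of affine.delinearize_index are disjoint by definition.
  %out = affine.linearize_index disjoint [%i, %j, %k] by (2, 3, 5) : index
  return %out : index
}
```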
38 changes: 28 additions & 10 deletions mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,10 +35,13 @@ using namespace mlir::affine;
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array.
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
ValueRange dynamicBasis,
ArrayRef<int64_t> staticBasis) {
ArrayRef<int64_t> staticBasis,
bool knownNonNegative) {
if (staticBasis.empty())
return {};

@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
size_t dynamicIndex = dynamicBasis.size();
Value dynamicPart = nullptr;
int64_t staticPart = 1;
// The products of the strides can't overflow by definition of
// affine.*_index.
arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
if (knownNonNegative)
ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
for (int64_t elem : llvm::reverse(staticBasis)) {
if (ShapedType::isDynamic(elem)) {
// Note: basis elements and their products are, definitionally,
// non-negative, so `nuw` is justified.
if (dynamicPart)
dynamicPart = rewriter.create<arith::MulIOp>(
loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
Value stride =
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
stride =
rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
SmallVector<Value> results;
results.reserve(numResults);
SmallVector<Value> strides =
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
/*knownNonNegative=*/true);

Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
Value remainderNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, remainder, zero);
Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
// If the correction is relevant, this term is <= stride, which is known
// to be positive in `index`. Otherwise, while 2 * stride might overflow,
// this branch won't be taken, so the risk of `poison` is fine.
Value corrected = rewriter.create<arith::AddIOp>(
loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
corrected, remainder);
return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
staticBasis = staticBasis.drop_front();

SmallVector<Value> strides =
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
/*knownNonNegative=*/op.getDisjoint());
SmallVector<std::pair<Value, int64_t>> scaledValues;
scaledValues.reserve(numIndexes);

@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
Value scaledIdx =
rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
Value scaledIdx = rewriter.create<arith::MulIOp>(
loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
for (auto [scaledValue, numHoistableLoops] :
llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
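To make the effect of the new flags concrete, the following is an illustrative sketch (not part of this diff, with made-up function and value names) of roughly what a `disjoint` linearization over a dynamic basis expands to after this pass; it mirrors the @linearize_dynamic_disjoint test added below:

```mlir
func.func @expanded_linearize_disjoint(%i: index, %j: index, %k: index,
                                       %m: index, %n: index) -> index {
  // Expansion of: affine.linearize_index disjoint [%i, %j, %k] by (%m, %n) : index
  // The stride product also carries `nuw` because `disjoint` asserts the
  // basis elements are non-negative; every operation carries `nsw` because
  // affine index arithmetic is assumed not to overflow in the signed sense.
  %stride = arith.muli %n, %m overflow<nsw, nuw> : index
  %si = arith.muli %i, %stride overflow<nsw> : index
  %sj = arith.muli %j, %n overflow<nsw> : index
  %partial = arith.addi %si, %sj overflow<nsw> : index
  %out = arith.addi %partial, %k overflow<nsw> : index
  return %out : index
}
```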
54 changes: 35 additions & 19 deletions mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -8,12 +8,12 @@
// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]]
// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]]
// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]]
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]] overflow<nsw>
// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]]
// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]]
// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]]
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]] overflow<nsw>
// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
// CHECK: return %[[N]], %[[P]], %[[Q]]
func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
@@ -30,16 +30,16 @@ func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] :
// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] :
// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]]
// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]] overflow<nsw, nuw>
// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]]
// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]]
// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]]
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]]
// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] overflow<nsw>
// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]]
// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]]
// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]]
// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]]
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]]
// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] overflow<nsw>
// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]]
// CHECK: return %[[N]], %[[P]], %[[Q]]
func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
@@ -58,10 +58,10 @@ func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (index, index, index) {
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]]
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]]
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]] overflow<nsw>
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]] overflow<nsw>
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
// CHECK: return %[[val_1]]
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
@@ -72,11 +72,11 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {

// CHECK-LABEL: @linearize_dynamic
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]]
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]]
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]]
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]]
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]]
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] overflow<nsw>
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw>
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw>
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
// CHECK: return %[[val_1]]
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
// Note: no outer bounds
@@ -86,17 +86,33 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {

// -----

// CHECK-LABEL: @linearize_dynamic_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] overflow<nsw, nuw>
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] overflow<nsw>
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] overflow<nsw>
// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] overflow<nsw>
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] overflow<nsw>
// CHECK: return %[[val_1]]
func.func @linearize_dynamic_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
// Note: no outer bounds
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
func.return %0 : index
}

// -----

// CHECK-LABEL: @linearize_sort_adds
// CHECK-SAME: (%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: scf.for %[[arg3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} {
// CHECK: scf.for %[[arg4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} {
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]]
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]]
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]]
// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] overflow<nsw, nuw>
// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] overflow<nsw>
// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] overflow<nsw>
// Note: even though %arg3 has a lower stride, we add it first
// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]]
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]]
// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] overflow<nsw>
// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] overflow<nsw>
// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]]
func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0 = arith.constant 0 : index