Commit 8c74dc1

[MLIR][Affine] Fix affine.apply verifier and add functionality to demote invalid symbols to dims (#128289)
Fixes: #120189, #128403

Fix the affine.apply verifier to reject operands in symbol positions that are valid dims, but not valid symbols, for affine purposes. This doesn't affect users in other contexts where the operands are neither valid dims nor valid symbols (e.g., in scf.for or other region ops). Previously, `-canonicalize` could generate invalid IR when such affine.apply ops were composed.

Introduce a method to demote a symbolic operand to a dimensional one (the inverse of the current canonicalizePromotedSymbols). It demotes operands that could/should have been valid affine dimensional values (affine loop IVs or their functions) from symbols to dims. This is a general method that can be used to legalize a map and its operands after construction. Use it during `canonicalizeMapOrSetAndOperands` so that pattern rewriter-based passes can generate valid IR after folding. Users outside of affine analyses/dialects remain unaffected.

In some cases, this change also leads to better-simplified operands, with duplicates eliminated, as shown in one of the test cases where the same operand appeared both as a symbol and as a dim.

This commit also fixes test cases where dimensional positions should ideally have been used with affine.apply (for affine loop IVs, for example).
1 parent 8dbf92e commit 8c74dc1

5 files changed: +105 -26 lines changed
‎mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

+8 -4 (8 additions & 4 deletions)
@@ -40,21 +40,25 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
   let description = [{
     The `affine.apply` operation applies an [affine mapping](#affine-maps)
     to a list of SSA values, yielding a single SSA value. The number of
-    dimension and symbol arguments to `affine.apply` must be equal to the
+    dimension and symbol operands to `affine.apply` must be equal to the
     respective number of dimensional and symbolic inputs to the affine mapping;
     the affine mapping has to be one-dimensional, and so the `affine.apply`
     operation always returns one value. The input operands and result must all
     have ‘index’ type.
 
+    An operand that is a valid dimension as per the [rules on valid affine
+    dimensions and symbols](#restrictions-on-dimensions-and-symbols)
+    cannot be used as a symbolic operand.
+
     Example:
 
     ```mlir
-    #map10 = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
+    #map = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
     ...
-    %1 = affine.apply #map10 (%s, %t)
+    %1 = affine.apply #map (%s, %t)
 
     // Inline example.
-    %2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
+    %2 = affine.apply affine_map<(i)[s0] -> (i + s0)> (%42)[%n]
     ```
   }];
   let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
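
The operand rule added to the description can be modelled without MLIR. The following is a minimal, self-contained C++ sketch, not the MLIR implementation: `ToyOperand` and `symbolOperandsAreLegal` are hypothetical names, and the two flags stand in for what MLIR's `isValidDim` and `isValidSymbol` would report for an operand. It follows the condition used by the verifier change in this commit, which rejects an operand in a symbol position only when it is a valid dim and not also a valid symbol.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an SSA operand. The flags are supplied by hand;
// in MLIR they would come from isValidDim / isValidSymbol.
struct ToyOperand {
  bool validDim;
  bool validSymbol;
};

// Returns true if no operand in a symbol position (index >= numDims) is a
// valid dim that is not also a valid symbol. Operands that are neither
// (e.g. values defined outside affine ops) are deliberately accepted.
bool symbolOperandsAreLegal(const std::vector<ToyOperand> &operands,
                            std::size_t numDims) {
  assert(numDims <= operands.size() &&
         "number of dims cannot exceed number of operands");
  for (std::size_t i = numDims; i < operands.size(); ++i)
    if (operands[i].validDim && !operands[i].validSymbol)
      return false;
  return true;
}
```

Under this model, an affine loop IV (`{true, false}`) is rejected in a symbol position but accepted in a dim position, while a top-level value (`{true, true}`) is accepted in either.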

‎mlir/lib/Dialect/Affine/IR/AffineOps.cpp

+63 -2 (63 additions & 2 deletions)
@@ -578,6 +578,15 @@ LogicalResult AffineApplyOp::verify() {
   if (affineMap.getNumResults() != 1)
     return emitOpError("mapping must produce one value");
 
+  // Do not allow valid dims to be used in symbol positions. We do allow
+  // affine.apply to use operands for values that may neither qualify as affine
+  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
+  Region *region = getAffineScope(*this);
+  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
+    if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
+      return emitError("dimensional operand cannot be used as a symbol");
+  }
+
   return success();
 }
 
@@ -1359,13 +1368,64 @@ static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
 
   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
   *operands = resultOperands;
-  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
-                                              oldNumSyms + nextSym);
+  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
+      dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
 
   assert(mapOrSet->getNumInputs() == operands->size() &&
          "map/set inputs must match number of operands");
 }
 
+/// A valid affine dimension may appear as a symbol in affine.apply operations.
+/// Given an application of `operands` to an affine map or integer set
+/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
+/// dims, but not valid symbols into actual dims. Without such a legalization,
+/// the affine.apply will be invalid. This method is the exact inverse of
+/// canonicalizePromotedSymbols.
+template <class MapOrSet>
+static void legalizeDemotedDims(MapOrSet &mapOrSet,
+                                SmallVectorImpl<Value> &operands) {
+  if (!mapOrSet || operands.empty())
+    return;
+
+  unsigned numOperands = operands.size();
+
+  assert(mapOrSet->getNumInputs() == numOperands &&
+         "map/set inputs must match number of operands");
+
+  auto *context = mapOrSet.getContext();
+  SmallVector<Value, 8> resultOperands;
+  resultOperands.reserve(numOperands);
+  SmallVector<Value, 8> remappedDims;
+  remappedDims.reserve(numOperands);
+  SmallVector<Value, 8> symOperands;
+  symOperands.reserve(mapOrSet.getNumSymbols());
+  unsigned nextSym = 0;
+  unsigned nextDim = 0;
+  unsigned oldNumDims = mapOrSet.getNumDims();
+  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
+  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
+  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
+    if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
+      // This is a valid dim that appears as a symbol, legalize it.
+      symRemapping[i - oldNumDims] =
+          getAffineDimExpr(oldNumDims + nextDim++, context);
+      remappedDims.push_back(operands[i]);
+    } else {
+      symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
+      symOperands.push_back(operands[i]);
+    }
+  }
+
+  append_range(resultOperands, remappedDims);
+  append_range(resultOperands, symOperands);
+  operands = resultOperands;
+  mapOrSet = mapOrSet.replaceDimsAndSymbols(
+      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
+
+  assert(mapOrSet->getNumInputs() == operands.size() &&
+         "map/set inputs must match number of operands");
+}
+
 // Works for either an affine map or an integer set.
 template <class MapOrSet>
 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
@@ -1380,6 +1440,7 @@ static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
          "map/set inputs must match number of operands");
 
   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
+  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
 
   // Check to see what dims are used.
   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
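
The operand reordering performed by `legalizeDemotedDims` can be illustrated with a toy model. This is only a sketch under stated assumptions and does not use MLIR: `ToyValue`, `Remap`, `Legalized`, and `demoteSymbolsToDims` are hypothetical names, the validity flags are supplied by hand instead of being computed by `isValidDim`/`isValidSymbol`, and the map itself is represented only by the remapping of its old symbol positions. As in the function above, the result keeps the original dims first, then the demoted operands as new trailing dims, then the remaining symbols.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value with hand-supplied validity flags.
struct ToyValue {
  std::string name;
  bool validDim;
  bool validSymbol;
};

// Where an old symbol position ends up after legalization.
struct Remap {
  bool isDim;        // true: demoted to a dim; false: stays a symbol
  std::size_t index; // new dim position or new symbol position
};

struct Legalized {
  std::vector<ToyValue> operands;  // old dims, demoted dims, then symbols
  std::size_t numDims = 0;
  std::size_t numSymbols = 0;
  std::vector<Remap> symRemapping; // one entry per old symbol position
};

// Demotes every symbol-position operand that is a valid dim but not a valid
// symbol to a new dim appended after the existing dims.
Legalized demoteSymbolsToDims(const std::vector<ToyValue> &operands,
                              std::size_t oldNumDims) {
  assert(oldNumDims <= operands.size() &&
         "number of dims cannot exceed number of operands");
  Legalized out;
  out.operands.assign(operands.begin(),
                      operands.begin() +
                          static_cast<std::ptrdiff_t>(oldNumDims));
  std::vector<ToyValue> demoted, symbols;
  for (std::size_t i = oldNumDims; i < operands.size(); ++i) {
    const ToyValue &v = operands[i];
    if (v.validDim && !v.validSymbol) {
      out.symRemapping.push_back({true, oldNumDims + demoted.size()});
      demoted.push_back(v);
    } else {
      out.symRemapping.push_back({false, symbols.size()});
      symbols.push_back(v);
    }
  }
  out.operands.insert(out.operands.end(), demoted.begin(), demoted.end());
  out.operands.insert(out.operands.end(), symbols.begin(), symbols.end());
  out.numDims = oldNumDims + demoted.size();
  out.numSymbols = symbols.size();
  return out;
}
```

For example, modelling a map `()[s0, s1]` applied to a top-level value `%arg2` and a loop IV `%arg4` (flags `{true, true}` and `{true, false}`), the sketch yields one dim and one symbol with operands ordered `(%arg4)[%arg2]`, which has the same shape as the updated `(d0)[s0] -> (d0 + s0)` expectation in `fold-memref-alias-ops.mlir`.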

‎mlir/test/Dialect/Affine/canonicalize.mlir

+2 -2 (2 additions & 2 deletions)
@@ -1460,8 +1460,8 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
 func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
   // CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
   affine.for %arg3 = 0 to 8 {
-    %1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
-    // CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
+    %1 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+    // CHECK: affine.prefetch [[PARAM_0_]][[[I_0_]] * 64], read, locality<3>, data : memref<512xf32>
     affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
   }
   return

‎mlir/test/Dialect/Affine/invalid.mlir

+14 -0 (14 additions & 0 deletions)
@@ -563,3 +563,17 @@ func.func @no_upper_bound() {
   }
   return
 }
+
+// -----
+
+func.func @invalid_symbol() {
+  affine.for %arg1 = 0 to 1 {
+    affine.for %arg2 = 0 to 26 {
+      affine.for %arg3 = 0 to 23 {
+        affine.apply affine_map<()[s0, s1] -> (s0 * 23 + s1)>()[%arg1, %arg3]
+        // expected-error@above {{dimensional operand cannot be used as a symbol}}
+      }
+    }
+  }
+  return
+}

‎mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

+18 -18 (18 additions & 18 deletions)
@@ -496,8 +496,8 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -518,16 +518,16 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
 // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
-// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
-// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
 // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
-// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])[%[[ARG2]]]
 // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -549,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
-// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -578,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
-// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -608,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
-// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
 // CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
 
 // -----
@@ -678,7 +678,7 @@ func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
 // -----
 
 // CHECK-LABEL: func @fold_store_keep_nontemporal(
-// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
+// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32>
 func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
     memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
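
The duplicate elimination mentioned in the commit message shows up in the `..._when_access_index_is_an_expression` test above: the old expectation `(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)` was applied with the same value, `%[[ARG3]]`, bound to both `d0` and `s0`, and the new expectation is `(d0, d1) -> (d0 * 1025 + d1)`. Substituting `s0 = d0` gives `d0 + d1 + d0 * 1024 = d0 * 1025 + d1`. The sketch below is only a numeric sanity check of that identity in plain C++; the function names are hypothetical and nothing here uses MLIR.

```cpp
#include <cassert>
#include <cstdint>

// Old expectation: (d0, d1)[s0] -> (d0 + d1 + s0 * 1024).
std::int64_t oldMap(std::int64_t d0, std::int64_t d1, std::int64_t s0) {
  return d0 + d1 + s0 * 1024;
}

// New expectation: (d0, d1) -> (d0 * 1025 + d1).
std::int64_t newMap(std::int64_t d0, std::int64_t d1) {
  return d0 * 1025 + d1;
}

// Checks that both maps agree on [0, bound) x [0, bound) when the old map's
// symbol is bound to the same value as its first dim.
bool mapsAgreeOnRange(std::int64_t bound) {
  assert(bound >= 0 && "bound must be non-negative");
  for (std::int64_t i = 0; i < bound; ++i)
    for (std::int64_t j = 0; j < bound; ++j)
      if (oldMap(i, j, /*s0=*/i) != newMap(i, j))
        return false;
  return true;
}
```

For instance, with `d0 = 3` and `d1 = 4`, both forms evaluate to 3079.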
