[mli][vector] canonicalize vector.from_elements from ascending extracts #139819

Open
wants to merge 7 commits into
base: main

newling
Contributor

@newling newling commented May 14, 2025

Example:

%0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>

becomes

%2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>

It was decided that canonicalization tests should be split into new files (see discussion). In view of this, I added the new tests to a new file specifically for canonicalization of from_elements. To keep the location of the tests consistent, I moved the existing tests extract_scalar_from_from_elements, extract_1d_from_from_elements, extract_2d_from_from_elements and from_elements_to_splat from canonicalize.mlir to canonicalize/vector-from-elements.mlir. In addition to moving them, I changed the LIT variables to all be upper-case for consistency.

@newling newling changed the title from "[mli][vector] vector.from_elements canonicalizer when elements ascending extracts" to "[mli][vector] canonicalize vector.from_elements from ascending extracts" May 14, 2025
@newling newling force-pushed the from_from_elements_to_shape_cast branch from 394b6be to 28fcceb Compare May 14, 2025 19:39
@llvmbot
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Example:

%0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>

becomes

%2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>

Full diff: https://github.com/llvm/llvm-project/pull/139819.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+90)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-69)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+155)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..1080263ed3eb6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) the elements are extracted in ascending order.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisfied, or in some cases as a
+/// single vector.extract_strided_slice if (ii) and (iii) are not satisfied;
+/// this is left for future consideration.
+class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElements,
+                                PatternRewriter &rewriter) const override {
+
+    mlir::OperandRange elements = fromElements.getElements();
+    assert(!elements.empty() && "must be at least 1 element");
+    Value firstElement = elements.front();
+
+    ExtractOp extractOp =
+        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
+    if (!extractOp) {
+      return rewriter.notifyMatchFailure(
+          fromElements, "first element not from vector.extract");
+    }
+    VectorType sourceType = extractOp.getSourceVectorType();
+    Value source = extractOp.getVector();
+
+    // Check condition (ii).
+    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
+      return rewriter.notifyMatchFailure(fromElements,
+                                         "number of elements differ");
+    }
+
+    for (auto [indexMinusOne, element] :
+         llvm::enumerate(elements.drop_front(1))) {
+
+      extractOp =
+          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element not from vector.extract");
+      }
+      Value currentSource = extractOp.getVector();
+      // Check condition (i).
+      if (currentSource != source) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element from different vector");
+      }
+
+      ArrayRef<int64_t> position = extractOp.getStaticPosition();
+      assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
+             "scalar extract must have full rank position");
+      int64_t stride{1};
+      int64_t offset{0};
+      for (auto [pos, size] : llvm::zip(llvm::reverse(position),
+                                        llvm::reverse(sourceType.getShape()))) {
+        if (pos == ShapedType::kDynamic) {
+          return rewriter.notifyMatchFailure(
+              fromElements, "elements not in ascending order (dynamic order)");
+        }
+        offset += pos * stride;
+        stride *= size;
+      }
+      // Check condition (iii).
+      if (offset != static_cast<int64_t>(indexMinusOne + 1)) {
+        return rewriter.notifyMatchFailure(
+            fromElements, "elements not in ascending order (static order)");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
+                                             fromElements.getType(), source);
+    return success();
+  }
+};
+
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
+  results.add<FromElementsToShapCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..6af517d988360 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
 
 // -----
 
-// CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
-  // Extract from 0D.
-  %0 = vector.from_elements %a : vector<f32>
-  %1 = vector.extract %0[] : f32 from vector<f32>
-
-  // Extract from 1D.
-  %2 = vector.from_elements %a : vector<1xf32>
-  %3 = vector.extract %2[0] : f32 from vector<1xf32>
-  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
-  %5 = vector.extract %4[4] : f32 from vector<5xf32>
-
-  // Extract from 2D.
-  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
-  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
-  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
-
-  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
-  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
-  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
-  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<3xf32>, vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
-  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
-  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
-  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
-  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
-  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
-  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
-  %2 = vector.from_elements %a : vector<f32>
-  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
-  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
-}
-
-// -----
-
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
new file mode 100644
index 0000000000000..14bf5d9df4783
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements
+
+///===----------------------------------------------===//
+///  Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----
+
+///===----------------------------------------------===//
+///  Tests of `FromElementsToShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+//  CHECK-SAME:       %[[a:.*]]: vector<1x2xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2xi8>
+func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %4 = vector.from_elements %0, %1 : vector<2xi8>
+  return %4 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
+//  CHECK-SAME:       %[[a:.*]]: vector<8xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2x2x2xi8>
+func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
+  %2 = vector.extract %arg0[2] : i8 from vector<8xi8>
+  %3 = vector.extract %arg0[3] : i8 from vector<8xi8>
+  %4 = vector.extract %arg0[4] : i8 from vector<8xi8>
+  %5 = vector.extract %arg0[5] : i8 from vector<8xi8>
+  %6 = vector.extract %arg0[6] : i8 from vector<8xi8>
+  %7 = vector.extract %arg0[7] : i8 from vector<8xi8>
+  %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
+  return %8 : vector<2x2x2xi8>
+}
+
+// -----
+
+// The extracted elements are recombined into a single vector, but in a new order.
+// CHECK-LABEL: func @negative_nonascending_order(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_nonstatic_extract(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_different_sources(
+//   CHECK-NOT: shape_cast
+func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_large(
+//   CHECK-NOT: shape_cast
+func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
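
The ascending-order test in the pattern above comes down to flattening each static extract position into a row-major offset and comparing it with the operand's index. The following is a minimal standalone sketch of that arithmetic on plain vectors, not the MLIR implementation: `linearOffset`, `isAscendingCover` and the sentinel `kDynamicPos` are hypothetical names introduced here for illustration, and the sketch checks only the element-count and ordering requirements (the same-source requirement is left out).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical sentinel marking a dynamic (non-constant) index. It only plays
// the role that ShapedType::kDynamic plays in the patch.
constexpr int64_t kDynamicPos = std::numeric_limits<int64_t>::min();

// Row-major linear offset of `position` inside `shape`, or std::nullopt if any
// index is dynamic. Dimensions are visited innermost first, as in the patch.
std::optional<int64_t> linearOffset(const std::vector<int64_t> &position,
                                    const std::vector<int64_t> &shape) {
  assert(position.size() == shape.size() && "position must have full rank");
  int64_t stride = 1;
  int64_t offset = 0;
  for (std::size_t i = position.size(); i-- > 0;) {
    if (position[i] == kDynamicPos)
      return std::nullopt;
    offset += position[i] * stride;
    stride *= shape[i];
  }
  return offset;
}

// True if there are exactly as many positions as elements in `shape` and the
// k-th position flattens to offset k for every k.
bool isAscendingCover(const std::vector<std::vector<int64_t>> &positions,
                      const std::vector<int64_t> &shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  if (static_cast<int64_t>(positions.size()) != numElements)
    return false;
  for (std::size_t k = 0; k < positions.size(); ++k) {
    std::optional<int64_t> offset = linearOffset(positions[k], shape);
    if (!offset || *offset != static_cast<int64_t>(k))
      return false;
  }
  return true;
}
```

For the `vector<1x2xi8>` example, positions `[0, 0]` and `[0, 1]` flatten to offsets 0 and 1, so the sketch accepts them; swapping the two positions, or extracting from a `1x3` source, makes it reject the sequence, which corresponds to the negative tests in the new test file.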

Contributor

@banach-space banach-space left a comment

Neat, thank you!

The transformation makes sense to me - I've left a couple of comments, but these are minor.

I am unsure about making this a canonicalization and hopefully somebody else could chime in as well. Basically, the IR becomes much simpler (+1), but "data movement" Ops are replaced with an Op for which:

"It is currently assumed that this operation does not require moving data"

This makes it a bit tricky for me. In my view, as long as this pattern composes nicely with other transformations (e.g. we don't re-materialise the data movement Ops to "lower" vector.shape_cast), it should be safe/fine to include it as a canonicalization.

Comment on lines +5 to +7
///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
///===----------------------------------------------===//
Contributor

This section was copied, right? Could you add a note in the summary so that it's easy to track the history?

Contributor Author

Yes, copied. I added a comment to the PR summary, I assume that's where you meant?

Comment on lines 83 to 85
// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>)
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[shape_cast]] : vector<2xi8>
Contributor

[nit] Could you use caps for LIT variables? That's much more common. And, IMHO, easier to parse 😅

I suspect that you wanted to maintain consistency with the tests for rewriteFromElementsAsSplat? I would just update those as well (fortunately, there aren't that many LIT vars there)

Contributor Author

Easier, but still hard IMO 😆
Done

Comment on lines 2419 to 2424
    ExtractOp extractOp =
        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
    if (!extractOp) {
      return rewriter.notifyMatchFailure(
          fromElements, "first element not from vector.extract");
    }
Contributor

Why do we check the first element separately?

Contributor Author

I flip-flopped between having a conditional "if (index == 0) { // do the one-off check }" inside the loop and doing it before the loop. I've now gone back to doing it in the loop.

Comment on lines 2428 to 2432
    // Check condition (ii).
    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
      return rewriter.notifyMatchFailure(fromElements,
                                         "number of elements differ");
    }
Contributor

[nit] I would "rebrand" this as "Condition (i)" (it's the first condition to be checked) and move it all the way to the top - it feels like a fairly high level condition that deserves a special place :)

Comment on lines 2437 to 2448
      extractOp =
          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }
      Value currentSource = extractOp.getVector();
      // Check condition (i).
      if (currentSource != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }
Contributor

[nit] To me all of this is checking "condition (i)" and everything that's left is checking "condition (ii)". I would just move the comments and basically split the loop body into 2 blocks.
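
The inline suggestions above (handle the first element inside the loop, check the element count up front, split the loop body into a "same source" block and an "ascending" block) could be combined roughly as follows. This is a standalone sketch on plain data, not the code that was committed: `Operand`, `Match` and `matchShapeCast` are hypothetical names, the source shape is passed in as a parameter rather than read from the first extract, and dynamic positions are left out. As far as the quoted diff shows, its loop starts at the second element, so the first element's position is never compared with offset 0; iterating from index 0 covers that case as well.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one from_elements operand. `sourceId` identifies
// the vector it was extracted from (negative: not produced by an extract);
// `position` holds static indices only.
struct Operand {
  int sourceId;
  std::vector<int64_t> position;
};

enum class Match {
  Ok,
  CountMismatch,
  NotFromExtract,
  DifferentSource,
  NotAscending
};

Match matchShapeCast(const std::vector<Operand> &operands,
                     const std::vector<int64_t> &sourceShape) {
  // High-level check first: operand count equals the source element count.
  int64_t numElements = 1;
  for (int64_t dim : sourceShape)
    numElements *= dim;
  if (static_cast<int64_t>(operands.size()) != numElements)
    return Match::CountMismatch;

  for (std::size_t index = 0; index < operands.size(); ++index) {
    const Operand &operand = operands[index];

    // Block 1: every operand is an extract from the same source.
    if (operand.sourceId < 0)
      return Match::NotFromExtract;
    if (operand.sourceId != operands.front().sourceId)
      return Match::DifferentSource;

    // Block 2: the position flattens (row-major) to the operand's index.
    assert(operand.position.size() == sourceShape.size() &&
           "scalar extract must have full rank position");
    int64_t stride = 1;
    int64_t offset = 0;
    for (std::size_t d = sourceShape.size(); d-- > 0;) {
      offset += operand.position[d] * stride;
      stride *= sourceShape[d];
    }
    if (offset != static_cast<int64_t>(index))
      return Match::NotAscending;
  }
  return Match::Ok;
}
```

With this structure, two operands that both extract `[0, 1]` from a `1x2` source are rejected at index 0, because the first operand flattens to offset 1 rather than 0.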

@newling
Contributor Author

newling commented May 15, 2025

@banach-space thanks for the thorough review! I'll implement your suggestions soon.

I am unsure about making this a canonicalization and hopefully somebody else could chime in as well. Basically, the IR becomes much simpler (+1), but "data movement" Ops are replaced with an Op for which:

"It is currently assumed that this operation does not require moving data"

This makes it a bit tricky for me. In my view, as long as this pattern composes nicely with other transformations (e.g. we don't re-materialise the data movement Ops to "lower" vector.shape_cast), it should be safe/fine to include it as a canonicalization.

I share this concern. Specifically, if the lowering of vector.shape_cast were to create a vector.from_elements, then canonicalization would un-lower shape_cast. It currently doesn't lower like this, but it doesn't seem unreasonable.

I could take the conservative approach of making this a folder along the lines of #135841 (comment), so effectively converting to vector.shape_cast only if we know it will be absorbed immediately into another vector.shape_cast. One concern here is that this pattern is O(num elements in vector) and so not the fastest, and folding generally happens more often than canonicalizing. Another concern is that it goes against the idea suggested in #138777 of canonicalizing transpose/broadcast/extract to shape_cast where possible.

My personal preference would be to keep creating shape_cast rather than avoid it, and to reconsider its lowering, as you suggested here. Perhaps shape_cast should be treated as the bottom rung on the lowering ladder, an op which gets converted directly to llvm/other (in the case where not all shape_casts have cancelled).

Contributor

@banach-space banach-space left a comment

Thanks for the updates! I've left a few minor suggestions, otherwise LGTM

@@ -2385,9 +2386,105 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
return success();
}

/// Rewrite vector.from_elements as vector.shape_cast, if possible.
Contributor

[ultra nit] "if possible" is implicit and "less is more" :)

Suggested change
- /// Rewrite vector.from_elements as vector.shape_cast, if possible.
+ /// Rewrite vector.from_elements as vector.shape_cast.

/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
/// i) source and from_elements result have the same number of elements,
Contributor

[nit] Otherwise it's not clear what source and from_elements are. Perhaps there's a better way to clarify 🤔

Suggested change
- /// i) source and from_elements result have the same number of elements,
+ /// i) vector.extract and vector.from_elements result have the same number of elements,

Comment on lines +2404 to +2407
/// It might be possible to rewrite vector.from_elements as a single
/// vector.extract if (i) is not satisfied, or in some cases as a
/// single vector.extract_strided_slice if (i) and (iii) are not satisfied;
/// this is left for future consideration.
Contributor

I know we already have quite a few TODOs/FIXMEs that basically mean "let's look at this later." But the phrasing "It might be possible…" feels particularly vague here - I'd suggest omitting it unless we can be more specific.

If we do want to leave a note, maybe something like:

“Consider extending to use a single vector.extract when (i) does not hold.”

Also, just a general thought: extending this pattern could quickly become quite complex. If we're seeing bad code that would benefit from such a complicated rewrite, it might be worth checking whether the producer of that code could be improved instead.

@banach-space
Contributor

I share this concern. Specifically, if the lowering of vector.shape_cast were to create a vector.from_elements, then canonicalization would un-lower shape_cast. It currently doesn't lower like this, but it doesn't seem unreasonable.

Thanks for sharing your points - I agree with pretty much everything you've said.

Perhaps shape_cast should be treated as the bottom rung on the lowering ladder, an op which gets converted directly to llvm/other (in the case where not all shape_casts have cancelled).

Personally, I think we should aim for a design where vector.shape_cast is effectively a no-op. In that model, lowering it to LLVM would be considered an error. We're not quite there yet, and in the meantime, we need to continue supporting all current users who may rely on the existing behaviour. So I'm not suggesting any drastic changes just yet. :)

That said, I fully acknowledge that we might ultimately discover that "shape_cast as a no-op" isn't viable - and that would be a completely acceptable outcome. But with your recent patches, we're already converging in that direction, so I’m optimistic.

@dcaballe, any thoughts?

@dcaballe
Contributor

This canonicalization makes sense to me. Thanks! I added some thoughts to #138777 (comment) that hopefully make sense to you! I'll take a look at the actual changes later.
