Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Fix support for complex types in transform.structured.pad #139841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
Loading
from

Conversation

christopherbate
Copy link
Contributor

Fixes verification of the pad element attribute when the operand(s) have element type complex<...>.

Fixes verification of the pad element attribute when the operand(s)
have element type `complex<...>`.
@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

Fixes verification of the pad element attribute when the operand(s) have element type complex&lt;...&gt;.


Full diff: https://github.com/llvm/llvm-project/pull/139841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-4)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+37)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..ea02886c1b65a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     SmallVector<Attribute> paddingValues;
     for (auto const &it :
          llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
-      if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+      Attribute attr = std::get<0>(it);
+      if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) {
+        emitOpError("expects padding values to be typed attributes or array "
+                    "attributes (for complex numbers)");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
       Type elementType = getElementTypeOrSelf(std::get<1>(it));
@@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
         continue;
       }
       // Otherwise, add the attribute directly.
-      if (attr.getType() != elementType) {
+      if (isa<TypedAttr>(attr) &&
+          cast<TypedAttr>(attr).getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) {
         auto diag = this->emitOpError("expects a padding value of type ")
                     << elementType << ", got " << attr;
         diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..c838713f368a3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!type = tensor<10x10xcomplex<f32>>
+// CHECK-LABEL: @pad_matmul
+func.func @pad_matmul(%arg0: !type,
+                           %arg1: !type,
+                           %arg2: !type
+                           ) -> !type {  
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type
+  func.return %0 : !type
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
+      padding_values=[
+        [0.1 : f32, 0.2 : f32],
+        [0.3 : f32, 0.4 : f32],
+        [0.5 : f32, 0.6 : f32]
+      ],      
+      padding_dimensions = [0, 1, 2]      
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo comment.

Comment on lines +1926 to +1927
if (isa<TypedAttr>(attr) &&
cast<TypedAttr>(attr).getType() != elementType) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isa followed by cast is an explicit anti-pattern in LLVM, please use dyn_cast combined with c++17 feature of declaring a value inside if.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding a completely new test case, could you re-use or modify one of the existing ones? Otherwise, it’s harder to see what makes the new test "special" or "unique" - modifying an existing test often makes it easier to highlight the impact of a change.

That said, I don’t see any truly “basic” cases for padding linalg.matmul, so it’s reasonable if you need to create your own. Alternatively, you might be able to re-use one of the pad_to_multiple_of tests. I'm not very familiar with transform.structured.pad, but do you actually rely on pad_to_multiple_of to verify this change?

Separately, please make sure your test is self-documenting - for example, consider renaming @pad_matmul to something more descriptive. You can find relevant testing guidelines here:

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.