Fix support for complex types in transform.structured.pad #139841

Status: Open — wants to merge 1 commit into main
Conversation

christopherbate
Contributor

Fixes verification of the pad element attribute when the operand(s) have element type `complex<...>`.

@llvmbot
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

Fixes verification of the pad element attribute when the operand(s) have element type `complex<...>`.
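For reference, the array-attribute form accepted after this change looks like the following; this is a sketch condensed from the test case in the diff, where each padding value for a `complex<f32>` operand is an array attribute holding the real and imaginary parts:

```mlir
// Sketch (condensed from the new test in this PR): for operands with
// element type complex<f32>, padding_values entries are array attributes
// of [real, imaginary] parts instead of single typed attributes.
%padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
  padding_values = [
    [0.1 : f32, 0.2 : f32],
    [0.3 : f32, 0.4 : f32],
    [0.5 : f32, 0.6 : f32]
  ],
  padding_dimensions = [0, 1, 2]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
```

Each array entry is materialized as a `complex.constant [re, im] : complex<f32>` that feeds the corresponding `tensor.pad`'s `tensor.yield`, as checked by the CHECK lines in the test.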


Full diff: https://github.com/llvm/llvm-project/pull/139841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-4)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad.mlir (+37)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..ea02886c1b65a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     SmallVector<Attribute> paddingValues;
     for (auto const &it :
          llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
-      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
-      if (!attr) {
-        emitOpError("expects padding values to be typed attributes");
+      Attribute attr = std::get<0>(it);
+      if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) {
+        emitOpError("expects padding values to be typed attributes or array "
+                    "attributes (for complex numbers)");
         return DiagnosedSilenceableFailure::definiteFailure();
       }
       Type elementType = getElementTypeOrSelf(std::get<1>(it));
@@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
         continue;
       }
       // Otherwise, add the attribute directly.
-      if (attr.getType() != elementType) {
+      if (isa<TypedAttr>(attr) &&
+          cast<TypedAttr>(attr).getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) {
         auto diag = this->emitOpError("expects a padding value of type ")
                     << elementType << ", got " << attr;
         diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..c838713f368a3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!type = tensor<10x10xcomplex<f32>>
+// CHECK-LABEL: @pad_matmul
+func.func @pad_matmul(%arg0: !type,
+                           %arg1: !type,
+                           %arg2: !type
+                           ) -> !type {  
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: complex.constant [{{.*}}] : complex<f32>
+  // CHECK: tensor.pad
+  // CHECK: tensor.yield
+  // CHECK: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type
+  func.return %0 : !type
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
+      padding_values=[
+        [0.1 : f32, 0.2 : f32],
+        [0.3 : f32, 0.4 : f32],
+        [0.5 : f32, 0.6 : f32]
+      ],      
+      padding_dimensions = [0, 1, 2]      
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
