-
Notifications
You must be signed in to change notification settings - Fork 14.4k
Fix support for complex types in transform.structured.pad
#139841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix support for complex types in transform.structured.pad
#139841
Conversation
Fixes verification of the pad element attribute when the operand(s) have element type `complex<...>`.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Christopher Bate (christopherbate) ChangesFixes verification of the pad element attribute when the operand(s) have element type Full diff: https://github.com/llvm/llvm-project/pull/139841.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index fbe7593420102..ea02886c1b65a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Attribute> paddingValues;
for (auto const &it :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
- if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ Attribute attr = std::get<0>(it);
+ if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) {
+ emitOpError("expects padding values to be typed attributes or array "
+ "attributes (for complex numbers)");
return DiagnosedSilenceableFailure::definiteFailure();
}
Type elementType = getElementTypeOrSelf(std::get<1>(it));
@@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
continue;
}
// Otherwise, add the attribute directly.
- if (attr.getType() != elementType) {
+ if (isa<TypedAttr>(attr) &&
+ cast<TypedAttr>(attr).getType() != elementType) {
+ auto diag = this->emitOpError("expects a padding value of type ")
+ << elementType << ", got " << attr;
+ diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) {
auto diag = this->emitOpError("expects a padding value of type ")
<< elementType << ", got " << attr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..c838713f368a3 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+!type = tensor<10x10xcomplex<f32>>
+// CHECK-LABEL: @pad_matmul
+func.func @pad_matmul(%arg0: !type,
+ %arg1: !type,
+ %arg2: !type
+ ) -> !type {
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: complex.constant [{{.*}}] : complex<f32>
+ // CHECK: tensor.pad
+ // CHECK: tensor.yield
+ // CHECK: linalg.matmul
+ %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type
+ func.return %0 : !type
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] {
+ padding_values=[
+ [0.1 : f32, 0.2 : f32],
+ [0.3 : f32, 0.4 : f32],
+ [0.5 : f32, 0.6 : f32]
+ ],
+ padding_dimensions = [0, 1, 2]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo comment.
if (isa<TypedAttr>(attr) && | ||
cast<TypedAttr>(attr).getType() != elementType) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isa
followed by cast
is an explicit anti-pattern in LLVM, please use dyn_cast
combined with c++17 feature of declaring a value inside if
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than adding a completely new test case, could you re-use or modify one of the existing ones? Otherwise, it’s harder to see what makes the new test "special" or "unique" - modifying an existing test often makes it easier to highlight the impact of a change.
That said, I don’t see any truly “basic” cases for padding linalg.matmul, so it’s reasonable if you need to create your own. Alternatively, you might be able to re-use one of the pad_to_multiple_of
tests. I'm not very familiar with transform.structured.pad
, but do you actually rely on pad_to_multiple_of
to verify this change?
Separately, please make sure your test is self-documenting - for example, consider renaming @pad_matmul
to something more descriptive. You can find relevant testing guidelines here:
Thanks!
Fixes verification of the pad element attribute when the operand(s) have element type
complex<...>
.