diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index fbe7593420102..ea02886c1b65a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1901,9 +1901,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, SmallVector<Attribute> paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { - auto attr = dyn_cast<TypedAttr>(std::get<0>(it)); - if (!attr) { - emitOpError("expects padding values to be typed attributes"); + Attribute attr = std::get<0>(it); + if (!llvm::isa<TypedAttr, ArrayAttr>(attr)) { + emitOpError("expects padding values to be typed attributes or array " + "attributes (for complex numbers)"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); @@ -1922,7 +1923,14 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, continue; } // Otherwise, add the attribute directly. 
- if (attr.getType() != elementType) { + if (isa<TypedAttr>(attr) && + cast<TypedAttr>(attr).getType() != elementType) { + auto diag = this->emitOpError("expects a padding value of type ") + << elementType << ", got " << attr; + diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + if (isa<ArrayAttr>(attr) && !isa<ComplexType>(elementType)) { auto diag = this->emitOpError("expects a padding value of type ") << elementType << ", got " << attr; diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir index ab2711545405e..c838713f368a3 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -419,3 +419,40 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +!type = tensor<10x10xcomplex<f32>> +// CHECK-LABEL: @pad_matmul +func.func @pad_matmul(%arg0: !type, + %arg1: !type, + %arg2: !type + ) -> !type { + // CHECK: complex.constant [{{.*}}] : complex<f32> + // CHECK: tensor.pad + // CHECK: tensor.yield + // CHECK: complex.constant [{{.*}}] : complex<f32> + // CHECK: tensor.pad + // CHECK: tensor.yield + // CHECK: complex.constant [{{.*}}] : complex<f32> + // CHECK: tensor.pad + // CHECK: tensor.yield + // CHECK: linalg.matmul + %0 = linalg.matmul ins(%arg0, %arg1 : !type, !type) outs(%arg2 : !type) -> !type + func.return %0 : !type +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [3, 3, 3] { + padding_values=[ + [0.1 : f32, 0.2 : f32], + [0.3 : f32, 0.4 : f32], + [0.5 : f32, 0.6 : f32] + ], + padding_dimensions = [0, 1, 2] + } : (!transform.any_op) -> 
(!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +}