-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR] Add apply_patterns.vector.arm_neon.lower_contraction TD Op #140251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
momchil-velikov
wants to merge
1
commit into
llvm:main
Choose a base branch
from
momchil-velikov:neon-contract-pattern-transform
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[MLIR] Add apply_patterns.vector.arm_neon.lower_contraction TD Op #140251
momchil-velikov
wants to merge
1
commit into
llvm:main
from
momchil-velikov:neon-contract-pattern-transform
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This patch wraps `populateLowerContractionToSMMLAPatternPatterns` into a new TD Op `apply_patterns.vector.arm_neon.lower_contraction`. It also removes the "test-lower-to-arm-neon" pass.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Momchil Velikov (momchil-velikov) ChangesThis patch wraps It also removes the "test-lower-to-arm-neon" pass. Full diff: https://github.com/llvm/llvm-project/pull/140251.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 1c679bcd049b8..3de3ec3f3a0e8 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
new file mode 100644
index 0000000000000..5bc03535a86c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmNeon Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_neon {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace arm_neon
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
new file mode 100644
index 0000000000000..f863ccaea3765
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmNeonTransformOps.td - Arm Neon transform ops------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMNEON_TRANSFORM_OPS
+#define ARMNEON_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmNeonLowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_neon.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmNeon dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMNEON_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..b8bc72a2bb734
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
+mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..619ac88ad76d3 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b096c2cbc503f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmNeonLowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmNeonVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmNeonVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmNeonVectorTransformDialectExtension)
+
+ ArmNeonVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_neon::ArmNeonDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+
+void mlir::arm_neon::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<ArmNeonVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..69d2143ad4e1f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRArmNeonVectorTransformOps
+ ArmNeonVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps
+
+ DEPENDS
+ MLIRArmNeonVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmNeonDialect
+ MLIRArmNeonTransforms
+ )
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 297be91e77283..ccad307e89dfb 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -transform-interpreter %s | FileCheck %s
// CHECK-LABEL: vector_arm_neon_mixed_types
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
@@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
return %res : vector<1x2xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_neon.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
deleted file mode 100644
index 460842d238533..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-# Exclude tests from libMLIR.so
-add_mlir_library(MLIRArmNeonTestPasses
- TestLowerToArmNeon.cpp
-
- EXCLUDE_FROM_LIBMLIR
- )
-mlir_target_link_libraries(MLIRArmNeonTestPasses PUBLIC
- MLIRArmNeonDialect
- MLIRArmNeonTransforms
- MLIRIR
- MLIRPass
- MLIRTransforms
- )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
deleted file mode 100644
index 03c80b601a347..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass for testing the lowering to ArmNeon as a
-// generally usable sink pass.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmNeon/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define PASS_NAME "test-lower-to-arm-neon"
-
-using namespace mlir;
-using namespace mlir::arm_neon;
-
-namespace {
-struct TestLowerToArmNeon
- : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
-
- StringRef getArgument() const final { return PASS_NAME; }
- StringRef getDescription() const final { return "Tests lower to arm Neon."; }
- TestLowerToArmNeon() = default;
- TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arm_neon::ArmNeonDialect>();
- }
-
- void runOnOperation() override;
-};
-
-} // namespace
-
-void TestLowerToArmNeon::runOnOperation() {
- MLIRContext *context = &getContext();
- RewritePatternSet patterns(context);
- populateLowerContractionToSMMLAPatternPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
- return signalPassFailure();
-}
-
-namespace mlir {
-namespace test {
-
-void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
-
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index a8fd70e6397a5..5614237d80f02 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
-add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)
add_subdirectory(Bufferization)
add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 3220dca282eac..5256cf7ae90d7 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
- MLIRArmNeonTestPasses
MLIRArmSMETestPasses
MLIRBufferizationTestPasses
MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index cdcf59b2add13..aa9c33dd9150c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
-void registerTestLowerToArmNeon();
void registerTestLowerToArmSME();
void registerTestLowerToLLVM();
void registerTestMakeIsolatedFromAbovePass();
@@ -264,7 +263,6 @@ void registerTestPasses() {
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
- mlir::test::registerTestLowerToArmNeon();
mlir::test::registerTestLowerToArmSME();
mlir::test::registerTestLowerToLLVM();
mlir::test::registerTestMakeIsolatedFromAbovePass();
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This patch wraps
populateLowerContractionToSMMLAPatternPatterns
into a new TD Opapply_patterns.vector.arm_neon.lower_contraction
.It also removes the "test-lower-to-arm-neon" pass.