Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[MLIR] Add apply_patterns.vector.arm_neon.lower_contraction TD Op #140251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
Loading
from

Conversation

momchil-velikov
Copy link
Collaborator

This patch wraps populateLowerContractionToSMMLAPatternPatterns into a new TD Op apply_patterns.vector.arm_neon.lower_contraction.

It also removes the "test-lower-to-arm-neon" pass.

This patch wraps `populateLowerContractionToSMMLAPatternPatterns` into
a new TD Op `apply_patterns.vector.arm_neon.lower_contraction`.

It also removes the "test-lower-to-arm-neon" pass.
@llvmbot
Copy link
Member

llvmbot commented May 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Momchil Velikov (momchil-velikov)

Changes

This patch wraps populateLowerContractionToSMMLAPatternPatterns into a new TD Op apply_patterns.vector.arm_neon.lower_contraction.

It also removes the "test-lower-to-arm-neon" pass.


Full diff: https://github.com/llvm/llvm-project/pull/140251.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt (+2)
  • (added) mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h (+31)
  • (added) mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td (+26)
  • (added) mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt (+6)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Dialect/ArmNeon/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp (+54)
  • (added) mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt (+18)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+13-1)
  • (removed) mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt (-13)
  • (removed) mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp (-60)
  • (modified) mlir/test/lib/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (-1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 1c679bcd049b8..3de3ec3f3a0e8 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
 set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
 mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
new file mode 100644
index 0000000000000..5bc03535a86c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmNeon Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_neon {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace arm_neon
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
new file mode 100644
index 0000000000000..f863ccaea3765
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmNeonTransformOps.td - Arm Neon transform ops------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMNEON_TRANSFORM_OPS
+#define ARMNEON_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmNeonLowerContractionPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.vector.arm_neon.lower_contraction",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contraction-like operations should be lowered to
+    finer-grained vector primitives using the ArmNeon dialect.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMNEON_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..b8bc72a2bb734
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
+mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..619ac88ad76d3 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  arm_neon::registerTransformDialectExtension(registry);
 
   // Translation extensions need to be registered by calling
   // `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b096c2cbc503f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmNeonLowerContractionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmNeonVectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          ArmNeonVectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      ArmNeonVectorTransformDialectExtension)
+
+  ArmNeonVectorTransformDialectExtension() {
+    declareGeneratedDialect<arm_neon::ArmNeonDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+
+void mlir::arm_neon::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<ArmNeonVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..69d2143ad4e1f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRArmNeonVectorTransformOps
+  ArmNeonVectorTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps
+
+  DEPENDS
+  MLIRArmNeonVectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRTransformDialect
+  MLIRArmNeonDialect
+  MLIRArmNeonTransforms
+  )
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 297be91e77283..ccad307e89dfb 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -transform-interpreter  %s | FileCheck %s
 
 // CHECK-LABEL: vector_arm_neon_mixed_types
 // CHECK-SAME:    %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
@@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
   return %res : vector<1x2xi32>
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.arm_neon.lower_contraction
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
deleted file mode 100644
index 460842d238533..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-# Exclude tests from libMLIR.so
-add_mlir_library(MLIRArmNeonTestPasses
-  TestLowerToArmNeon.cpp
-
-  EXCLUDE_FROM_LIBMLIR
-  )
-mlir_target_link_libraries(MLIRArmNeonTestPasses PUBLIC
-  MLIRArmNeonDialect
-  MLIRArmNeonTransforms
-  MLIRIR
-  MLIRPass
-  MLIRTransforms
-  )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
deleted file mode 100644
index 03c80b601a347..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass for testing the lowering to ArmNeon as a
-// generally usable sink pass.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmNeon/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define PASS_NAME "test-lower-to-arm-neon"
-
-using namespace mlir;
-using namespace mlir::arm_neon;
-
-namespace {
-struct TestLowerToArmNeon
-    : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
-
-  StringRef getArgument() const final { return PASS_NAME; }
-  StringRef getDescription() const final { return "Tests lower to arm Neon."; }
-  TestLowerToArmNeon() = default;
-  TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arm_neon::ArmNeonDialect>();
-  }
-
-  void runOnOperation() override;
-};
-
-} // namespace
-
-void TestLowerToArmNeon::runOnOperation() {
-  MLIRContext *context = &getContext();
-  RewritePatternSet patterns(context);
-  populateLowerContractionToSMMLAPatternPatterns(patterns);
-  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
-    return signalPassFailure();
-}
-
-namespace mlir {
-namespace test {
-
-void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
-
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index a8fd70e6397a5..5614237d80f02 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_subdirectory(Affine)
 add_subdirectory(Arith)
-add_subdirectory(ArmNeon)
 add_subdirectory(ArmSME)
 add_subdirectory(Bufferization)
 add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 3220dca282eac..5256cf7ae90d7 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestFuncToLLVM
     MLIRAffineTransformsTestPasses
     MLIRArithTestPasses
-    MLIRArmNeonTestPasses
     MLIRArmSMETestPasses
     MLIRBufferizationTestPasses
     MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index cdcf59b2add13..aa9c33dd9150c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
-void registerTestLowerToArmNeon();
 void registerTestLowerToArmSME();
 void registerTestLowerToLLVM();
 void registerTestMakeIsolatedFromAbovePass();
@@ -264,7 +263,6 @@ void registerTestPasses() {
   mlir::test::registerTestLoopFusion();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
-  mlir::test::registerTestLowerToArmNeon();
   mlir::test::registerTestLowerToArmSME();
   mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMakeIsolatedFromAbovePass();

@KoolJBlack KoolJBlack removed their request for review May 16, 2025 13:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:neon mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.