diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt index 1c679bcd049b8..3de3ec3f3a0e8 100644 --- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt @@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon) set(LLVM_TARGET_DEFINITIONS ArmNeon.td) mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRArmNeonConversionsIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h new file mode 100644 index 0000000000000..ad57e6e1f8ff4 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H +#define MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// ArmNeon Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_neon { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace arm_neon +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td new file mode 100644 index 0000000000000..2dd9cf466fd3f --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td @@ -0,0 +1,27 @@ +//===- ArmNeonVectorTransformOps.td - Arm Neon TD ops ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ARM_NEON_VECTOR_TRANSFORM_OPS +#define ARM_NEON_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +def ApplyArmNeonContractionToI8MMPatternsOp + : Op]> { + let description = [{ + Indicates that vector.contract operations should be lowered to + finer-grained vector primitives from the ArmNeon dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // ARM_NEON_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..b8bc72a2bb734 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td) +mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen) + +add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 37e4904cb48ed..619ac88ad76d3 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -34,6 +34,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); + arm_neon::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp new file mode 100644 index 0000000000000..e81fc6a8b5980 --- /dev/null +++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp @@ -0,0 +1,54 @@ +//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" + +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmNeon/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class ArmNeonVectorTransformDialectExtension + : public transform::TransformDialectExtension< + ArmNeonVectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ArmNeonVectorTransformDialectExtension) + + ArmNeonVectorTransformDialectExtension() { + declareGeneratedDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc" + +void mlir::arm_neon::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..69d2143ad4e1f --- /dev/null +++ b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(MLIRArmNeonVectorTransformOps + ArmNeonVectorTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps + + DEPENDS + MLIRArmNeonVectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRTransformDialect + MLIRArmNeonDialect + MLIRArmNeonTransforms + ) diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir index 297be91e77283..8cc494df044f5 100644 --- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir +++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter %s | FileCheck %s // CHECK-LABEL: vector_arm_neon_mixed_types // CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32> @@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2 %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32> return %res : vector<1x2xi32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %func { + transform.apply_patterns.vector.arm_neon.contraction_to_i8mm + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt deleted file mode 100644 index 460842d238533..0000000000000 --- a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -# Exclude tests from libMLIR.so -add_mlir_library(MLIRArmNeonTestPasses - TestLowerToArmNeon.cpp - - EXCLUDE_FROM_LIBMLIR - ) -mlir_target_link_libraries(MLIRArmNeonTestPasses PUBLIC - MLIRArmNeonDialect - MLIRArmNeonTransforms - MLIRIR - MLIRPass - MLIRTransforms - ) diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp deleted file mode 100644 index 03c80b601a347..0000000000000 --- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp +++ /dev/null @@ -1,60 +0,0 @@ -//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass for testing the lowering to ArmNeon as a -// generally usable sink pass. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" -#include "mlir/Dialect/ArmNeon/Transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#define PASS_NAME "test-lower-to-arm-neon" - -using namespace mlir; -using namespace mlir::arm_neon; - -namespace { -struct TestLowerToArmNeon - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon) - - StringRef getArgument() const final { return PASS_NAME; } - StringRef getDescription() const final { return "Tests lower to arm Neon."; } - TestLowerToArmNeon() = default; - TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override; -}; - -} // namespace - -void TestLowerToArmNeon::runOnOperation() { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - populateLowerContractionToSMMLAPatternPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); -} - -namespace mlir { -namespace test { - -void registerTestLowerToArmNeon() { PassRegistration(); } - -} // namespace test -} // namespace mlir diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index a8fd70e6397a5..5614237d80f02 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(Affine) add_subdirectory(Arith) -add_subdirectory(ArmNeon) add_subdirectory(ArmSME) add_subdirectory(Bufferization) add_subdirectory(ControlFlow) diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 3220dca282eac..5256cf7ae90d7 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS) MLIRTestFuncToLLVM MLIRAffineTransformsTestPasses MLIRArithTestPasses - MLIRArmNeonTestPasses MLIRArmSMETestPasses MLIRBufferizationTestPasses MLIRControlFlowTestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index cdcf59b2add13..aa9c33dd9150c 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass(); void registerTestLoopFusion(); void registerTestLoopMappingPass(); void registerTestLoopUnrollingPass(); -void registerTestLowerToArmNeon(); void registerTestLowerToArmSME(); void registerTestLowerToLLVM(); void registerTestMakeIsolatedFromAbovePass(); @@ -264,7 +263,6 @@ void registerTestPasses() { mlir::test::registerTestLoopFusion(); mlir::test::registerTestLoopMappingPass(); mlir::test::registerTestLoopUnrollingPass(); - mlir::test::registerTestLowerToArmNeon(); mlir::test::registerTestLowerToArmSME(); mlir::test::registerTestLowerToLLVM(); mlir::test::registerTestMakeIsolatedFromAbovePass();