-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Md Abdullah Shahneous Bari (mshahneo) Changes…er_as_index to SPIR-V Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option. Full diff: https://github.com/llvm/llvm-project/pull/86750.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81b9f55cac80f7..0ec3ad700fe807 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
}
};
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -922,6 +933,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type indexType = typeConverter.getIndexType();
+ rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+ adaptor.getSource());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -929,10 +954,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
- ReinterpretCastPattern, CastPattern>(typeConverter,
- patterns.getContext());
+ patterns
+ .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+ typeConverter, patterns.getContext());
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 10c03a270005f1..bc2af8b6edadcc 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
@@ -414,6 +415,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+
// -----
// Check nontemporal attribute
|
Just a friendly ping, @kuhar, @antiagainst :) |
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> { | ||
} | ||
}; | ||
|
||
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. | ||
class ExtractAlignedPointerAsIndexOpPattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class ExtractAlignedPointerAsIndexOpPattern | |
class ExtractAlignedPointerAsIndexOpPattern final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32 | ||
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64 | ||
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index | ||
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you group CHECK and CHECK64 together? It's a bit hard to read when interleaved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32 | ||
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64 | ||
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index | ||
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
// ----- | ||
|
||
module attributes { | ||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ConvertPtrToU
requires the PhysicalStorageBufferAddresses
capability, no?
@@ -1,4 +1,5 @@ | ||
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s | ||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s | ||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s | |
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \ | |
// RUN: | FileCheck --check-prefix=CHECK64 %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done.
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…er_as_index to SPIR-V Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.
ca371a8
to
f120ca2
Compare
I am so sorry, Jakub (@kuhar). For some reason, I completely forgot about this. Updated the PR to address your comments. |
Hi @kuhar , just a friendly ping :) |
…er_as_index to SPIR-V
Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.