Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
Loading
from

Conversation

mshahneo
Copy link
Contributor

…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.


Full diff: https://github.com/llvm/llvm-project/pull/86750.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+31-5)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+39-1)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81b9f55cac80f7..0ec3ad700fe807 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
   }
 };
 
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -922,6 +933,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type indexType = typeConverter.getIndexType();
+  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+                                                      adaptor.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -929,10 +954,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
-               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
-               ReinterpretCastPattern, CastPattern>(typeConverter,
-                                                    patterns.getContext());
+  patterns
+      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+          typeConverter, patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 10c03a270005f1..bc2af8b6edadcc 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
@@ -414,6 +415,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 
 }
 
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+
+
 // -----
 
 // Check nontemporal attribute

@mshahneo
Copy link
Contributor Author

Just a friendly ping, @kuhar, @antiagainst :)

@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
}
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ExtractAlignedPointerAsIndexOpPattern
class ExtractAlignedPointerAsIndexOpPattern final

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

Comment on lines 426 to 436
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you group CHECK and CHECK64 together? It's a bit hard to read when interleaved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 444 to 454
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConvertPtrToU requires the PhysicalStorageBufferAddresses capability, no?

@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \
// RUN: | FileCheck --check-prefix=CHECK64 %s

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
mshahneo added 2 commits May 6, 2025 16:50
…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
Index conversion is done based on 'use-64bit-index' option.
@mshahneo mshahneo force-pushed the extract-aligned-pointer-as-index-to-spirv branch from ca371a8 to f120ca2 Compare May 6, 2025 17:17
@mshahneo
Copy link
Contributor Author

mshahneo commented May 6, 2025

I am so sorry, Jakub (@kuhar). For some reason, I completely forgot about this. Updated the PR to address your comments.
Again, I am so so sorry.

@mshahneo
Copy link
Contributor Author

Hi @kuhar , just a friendly ping :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.