[AArch64][SVE2] Lower read-after-write mask to whilerw #114028


Closed
wants to merge 4 commits

Conversation

SamTebbs33 (Collaborator)

This patch extends the whilewr matching to also match a read-after-write mask and lower it to a whilerw.

@llvmbot (Member)

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Sam Tebbs (SamTebbs33)

Changes

This patch extends the whilewr matching to also match a read-after-write mask and lower it to a whilerw.
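
To make the difference concrete (a condensed sketch; the full IR is in the tests below, and the function names here are illustrative): the existing whilewr matching in tryWhileWRFromOR looks for the raw pointer difference, while the new whilerw matching accepts the same pattern with the difference wrapped in llvm.abs.

    ; Write-after-read (existing whilewr matching): the raw difference.
    define i64 @war_diff(i64 %a, i64 %b) {
      %diff = sub i64 %a, %b
      ret i64 %diff
    }

    ; Read-after-write (new whilerw matching): the absolute difference.
    define i64 @raw_diff(i64 %a, i64 %b) {
      %diff = sub i64 %a, %b
      %abs.diff = tail call i64 @llvm.abs.i64(i64 %diff, i1 false)
      ret i64 %abs.diff
    }

    declare i64 @llvm.abs.i64(i64, i1)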


Full diff: https://github.com/llvm/llvm-project/pull/114028.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+26-8)
  • (modified) llvm/test/CodeGen/AArch64/whilewr.ll (+133)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bf2f0674b5b65e..a2517761afc0c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14189,7 +14189,16 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
     return SDValue();
 
   SDValue Diff = Cmp.getOperand(0);
-  if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+  SDValue NonAbsDiff = Diff;
+  bool WriteAfterRead = true;
+  // A read-after-write will have an abs call on the diff
+  if (Diff.getOpcode() == ISD::ABS) {
+    NonAbsDiff = Diff.getOperand(0);
+    WriteAfterRead = false;
+  }
+
+  if (NonAbsDiff.getOpcode() != ISD::SUB ||
+      NonAbsDiff.getValueType() != MVT::i64)
     return SDValue();
 
   if (!isNullConstant(LaneMask.getOperand(1)) ||
@@ -14210,8 +14219,13 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
       // it's positive, otherwise the difference plus the element size if it's
       // negative: pos_diff = diff < 0 ? (diff + 7) : diff
       SDValue Select = DiffDiv.getOperand(0);
+      SDValue SelectOp3 = Select.getOperand(3);
+      // Check for an abs in the case of a read-after-write
+      if (!WriteAfterRead && SelectOp3.getOpcode() == ISD::ABS)
+        SelectOp3 = SelectOp3.getOperand(0);
+
       // Make sure the difference is being compared by the select
-      if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
+      if (Select.getOpcode() != ISD::SELECT_CC || SelectOp3 != NonAbsDiff)
         return SDValue();
       // Make sure it's checking if the difference is less than 0
       if (!isNullConstant(Select.getOperand(1)) ||
@@ -14243,22 +14257,26 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
   } else if (LaneMask.getOperand(2) != Diff)
     return SDValue();
 
-  SDValue StorePtr = Diff.getOperand(0);
-  SDValue ReadPtr = Diff.getOperand(1);
+  SDValue StorePtr = NonAbsDiff.getOperand(0);
+  SDValue ReadPtr = NonAbsDiff.getOperand(1);
 
   unsigned IntrinsicID = 0;
   switch (EltSize) {
   case 1:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_b
+                                 : Intrinsic::aarch64_sve_whilerw_b;
     break;
   case 2:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_h
+                                 : Intrinsic::aarch64_sve_whilerw_h;
     break;
   case 4:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_s
+                                 : Intrinsic::aarch64_sve_whilerw_s;
     break;
   case 8:
-    IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
+    IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_d
+                                 : Intrinsic::aarch64_sve_whilerw_d;
     break;
   default:
     return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/whilewr.ll b/llvm/test/CodeGen/AArch64/whilewr.ll
index 9f1ea850792384..bdea64296455eb 100644
--- a/llvm/test/CodeGen/AArch64/whilewr.ll
+++ b/llvm/test/CodeGen/AArch64/whilewr.ll
@@ -30,6 +30,36 @@ entry:
   ret <vscale x 16 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 16 x i1> @whilerw_8(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_8:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.b, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_8:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    cmp x8, #1
+; CHECK-NOSVE2-NEXT:    cset w9, lt
+; CHECK-NOSVE2-NEXT:    whilelo p0.b, xzr, x8
+; CHECK-NOSVE2-NEXT:    sbfx x8, x9, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.b, xzr, x8
+; CHECK-NOSVE2-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %neg.compare = icmp slt i64 %0, 1
+  %.splatinsert = insertelement <vscale x 16 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 16 x i1> %.splatinsert, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %0)
+  %active.lane.mask.alias = or <vscale x 16 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 16 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 16 x i1> @whilewr_commutative(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_commutative:
 ; CHECK:       // %bb.0: // %entry
@@ -89,6 +119,39 @@ entry:
   ret <vscale x 8 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 8 x i1> @whilerw_16(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_16:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.h, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_16:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    cmp x8, #2
+; CHECK-NOSVE2-NEXT:    add x8, x8, x8, lsr #63
+; CHECK-NOSVE2-NEXT:    cset w9, lt
+; CHECK-NOSVE2-NEXT:    sbfx x9, x9, #0, #1
+; CHECK-NOSVE2-NEXT:    asr x8, x8, #1
+; CHECK-NOSVE2-NEXT:    whilelo p0.h, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p1.h, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 2
+  %neg.compare = icmp slt i64 %0, 2
+  %.splatinsert = insertelement <vscale x 8 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 8 x i1> %.splatinsert, <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 8 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 8 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 4 x i1> @whilewr_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_32:
 ; CHECK:       // %bb.0: // %entry
@@ -122,6 +185,41 @@ entry:
   ret <vscale x 4 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 4 x i1> @whilerw_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_32:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.s, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_32:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    add x9, x8, #3
+; CHECK-NOSVE2-NEXT:    cmp x8, #0
+; CHECK-NOSVE2-NEXT:    csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT:    cmp x8, #4
+; CHECK-NOSVE2-NEXT:    cset w8, lt
+; CHECK-NOSVE2-NEXT:    asr x9, x9, #2
+; CHECK-NOSVE2-NEXT:    sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.s, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 4
+  %neg.compare = icmp slt i64 %0, 4
+  %.splatinsert = insertelement <vscale x 4 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 4 x i1> %.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 4 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 4 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 2 x i1> @whilewr_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: whilewr_64:
 ; CHECK:       // %bb.0: // %entry
@@ -155,6 +253,41 @@ entry:
   ret <vscale x 2 x i1> %active.lane.mask.alias
 }
 
+define <vscale x 2 x i1> @whilerw_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_64:
+; CHECK-SVE2:       // %bb.0: // %entry
+; CHECK-SVE2-NEXT:    whilerw p0.d, x2, x1
+; CHECK-SVE2-NEXT:    ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_64:
+; CHECK-NOSVE2:       // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT:    subs x8, x2, x1
+; CHECK-NOSVE2-NEXT:    cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT:    add x9, x8, #7
+; CHECK-NOSVE2-NEXT:    cmp x8, #0
+; CHECK-NOSVE2-NEXT:    csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT:    cmp x8, #8
+; CHECK-NOSVE2-NEXT:    cset w8, lt
+; CHECK-NOSVE2-NEXT:    asr x9, x9, #3
+; CHECK-NOSVE2-NEXT:    sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT:    whilelo p1.d, xzr, x9
+; CHECK-NOSVE2-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NOSVE2-NEXT:    mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT:    ret
+entry:
+  %b24 = ptrtoint ptr %b to i64
+  %c25 = ptrtoint ptr %c to i64
+  %sub.diff = sub i64 %c25, %b24
+  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+  %diff = sdiv i64 %0, 8
+  %neg.compare = icmp slt i64 %0, 8
+  %.splatinsert = insertelement <vscale x 2 x i1> poison, i1 %neg.compare, i64 0
+  %.splat = shufflevector <vscale x 2 x i1> %.splatinsert, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %ptr.diff.lane.mask = tail call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 %diff)
+  %active.lane.mask.alias = or <vscale x 2 x i1> %ptr.diff.lane.mask, %.splat
+  ret <vscale x 2 x i1> %active.lane.mask.alias
+}
+
 define <vscale x 1 x i1> @no_whilewr_128(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
 ; CHECK-LABEL: no_whilewr_128:
 ; CHECK:       // %bb.0: // %entry

Quoted from the patch:

    %c25 = ptrtoint ptr %c to i64
    %sub.diff = sub i64 %c25, %b24
    %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
    %neg.compare = icmp slt i64 %0, 0
Collaborator
This comparison is incorrect -- you shouldn't get a negative value after calling abs, and the reference manual indicates it should just be an equality comparison against 0 to combine with the e < diff comparison. I'm not sure what to do in the case of the diff being the signed minimum (where abs stays negative), though that seems unlikely to occur in practice.

I think we got the original code for whilewr wrong too -- it should be sle against 0 instead of slt.
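
As a sketch of the suggested comparison (illustrative only, not the patch's code): after the abs the difference is non-negative, so the splat term reduces to an equality test against zero.

    ; Hypothetical splat term for the whilerw pattern after the suggested fix.
    define i1 @splat_term(i64 %sub.diff) {
      %abs = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
      ; rather than: icmp slt i64 %abs, 0
      ; (for the whilewr pattern, with no abs, the suggestion is sle 0 instead)
      %zero.compare = icmp eq i64 %abs, 0
      ret i1 %zero.compare
    }

    declare i64 @llvm.abs.i64(i64, i1)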

Collaborator
I think you are right - the original whilewr testing was based on some bad testing from me (the tests returning a predicate confused things). It goes a bit further than what you said, though: I had looked at these intrinsics a while ago and concluded that they were difficult to match. Mathematically they do:

    diff = zext(a) - zext(b)
    elem[i] = splat(diff <= 0) | ALM(0, diff)

The zext is important if the top bits can be 1, so with inputs like 0 and 0xfffffffffffffff4 we might produce the wrong results. That likely won't be the case for pointers, but as the inputs are just i64 integers we might need to take that into account.

I think the operands of the whilewr are also the wrong way around.
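
A minimal IR model of that formula, with the zero-extension made explicit by widening to i128 so the subtract cannot wrap (an illustrative sketch, assuming get.active.lane.mask's overloads accept i128):

    define <vscale x 16 x i1> @whilewr_model(i64 %a, i64 %b) {
      ; diff = zext(a) - zext(b), computed without wrapping
      %wa = zext i64 %a to i128
      %wb = zext i64 %b to i128
      %diff = sub i128 %wa, %wb
      ; splat(diff <= 0)
      %le0 = icmp sle i128 %diff, 0
      %ins = insertelement <vscale x 16 x i1> poison, i1 %le0, i64 0
      %splat = shufflevector <vscale x 16 x i1> %ins, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
      ; ALM(0, diff)
      %alm = tail call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i128(i128 0, i128 %diff)
      %mask = or <vscale x 16 x i1> %alm, %splat
      ret <vscale x 16 x i1> %mask
    }

    declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i128(i128, i128)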

@davemgreen (Collaborator) left a comment

I think this might be mathematically difficult to make fit correctly. I believe that if you pass this a very high value and a very low one (say 0xfffffffffffffff8 and 0), the whilewr will zext both operands, so the "is negative" part will differ from doing it all in i64 arithmetic: 0xfffffffffffffff8 - 0 is negative in i64, but zext(0xfffffffffffffff8) - zext(0) isn't.

That might not matter much for pointers, as it is only differences between the very top and very bottom of the address space that would go negative without being within a vector width. I don't think we can just rule it out, though, as we are matching from arbitrary i64s.
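
The discrepancy can be checked directly (a self-contained sketch with illustrative names): with %a = 0xfffffffffffffff8 and %b = 0, %narrow.neg is true but %wide.neg is false.

    define i1 @sign_mismatch(i64 %a, i64 %b) {
      ; i64 view: the subtract wraps, 0xfffffffffffffff8 - 0 = -8
      %narrow = sub i64 %a, %b
      %narrow.neg = icmp slt i64 %narrow, 0
      ; zext view: widened operands can only give a negative difference
      ; when b > a as unsigned values
      %wa = zext i64 %a to i128
      %wb = zext i64 %b to i128
      %wide = sub i128 %wa, %wb
      %wide.neg = icmp slt i128 %wide, 0
      ; true when the two interpretations disagree
      %differ = xor i1 %narrow.neg, %wide.neg
      ret i1 %differ
    }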

Quoted from whilewr.ll:

    ;
    ; CHECK-SVE2-LABEL: whilewr_8:
    ; CHECK-SVE2:       // %bb.0: // %entry
    ; CHECK-SVE2-NEXT:    whilewr p0.b, x1, x2
Collaborator
I think this should be whilewr p0.b, x2, x1?

@SamTebbs33 (Collaborator, Author)

Abandoned in favour of #117007

@SamTebbs33 SamTebbs33 closed this Dec 17, 2024