[AArch64][SVE2] Lower read-after-write mask to whilerw #114028

Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Sam Tebbs (SamTebbs33)

Changes

This patch extends the whilewr matching to also match a read-after-write mask and lower it to a whilerw.

Full diff: https://github.com/llvm/llvm-project/pull/114028.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bf2f0674b5b65e..a2517761afc0c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14189,7 +14189,16 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
return SDValue();
SDValue Diff = Cmp.getOperand(0);
- if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
+ SDValue NonAbsDiff = Diff;
+ bool WriteAfterRead = true;
+ // A read-after-write will have an abs call on the diff
+ if (Diff.getOpcode() == ISD::ABS) {
+ NonAbsDiff = Diff.getOperand(0);
+ WriteAfterRead = false;
+ }
+
+ if (NonAbsDiff.getOpcode() != ISD::SUB ||
+ NonAbsDiff.getValueType() != MVT::i64)
return SDValue();
if (!isNullConstant(LaneMask.getOperand(1)) ||
@@ -14210,8 +14219,13 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
// it's positive, otherwise the difference plus the element size if it's
// negative: pos_diff = diff < 0 ? (diff + 7) : diff
SDValue Select = DiffDiv.getOperand(0);
+ SDValue SelectOp3 = Select.getOperand(3);
+ // Check for an abs in the case of a read-after-write
+ if (!WriteAfterRead && SelectOp3.getOpcode() == ISD::ABS)
+ SelectOp3 = SelectOp3.getOperand(0);
+
// Make sure the difference is being compared by the select
- if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
+ if (Select.getOpcode() != ISD::SELECT_CC || SelectOp3 != NonAbsDiff)
return SDValue();
// Make sure it's checking if the difference is less than 0
if (!isNullConstant(Select.getOperand(1)) ||
@@ -14243,22 +14257,26 @@ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
} else if (LaneMask.getOperand(2) != Diff)
return SDValue();
- SDValue StorePtr = Diff.getOperand(0);
- SDValue ReadPtr = Diff.getOperand(1);
+ SDValue StorePtr = NonAbsDiff.getOperand(0);
+ SDValue ReadPtr = NonAbsDiff.getOperand(1);
unsigned IntrinsicID = 0;
switch (EltSize) {
case 1:
- IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
+ IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_b
+ : Intrinsic::aarch64_sve_whilerw_b;
break;
case 2:
- IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
+ IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_h
+ : Intrinsic::aarch64_sve_whilerw_h;
break;
case 4:
- IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
+ IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_s
+ : Intrinsic::aarch64_sve_whilerw_s;
break;
case 8:
- IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
+ IntrinsicID = WriteAfterRead ? Intrinsic::aarch64_sve_whilewr_d
+ : Intrinsic::aarch64_sve_whilerw_d;
break;
default:
return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/whilewr.ll b/llvm/test/CodeGen/AArch64/whilewr.ll
index 9f1ea850792384..bdea64296455eb 100644
--- a/llvm/test/CodeGen/AArch64/whilewr.ll
+++ b/llvm/test/CodeGen/AArch64/whilewr.ll
@@ -30,6 +30,36 @@ entry:
ret <vscale x 16 x i1> %active.lane.mask.alias
}
+define <vscale x 16 x i1> @whilerw_8(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_8:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: whilerw p0.b, x2, x1
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_8:
+; CHECK-NOSVE2: // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT: subs x8, x2, x1
+; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT: cmp x8, #1
+; CHECK-NOSVE2-NEXT: cset w9, lt
+; CHECK-NOSVE2-NEXT: whilelo p0.b, xzr, x8
+; CHECK-NOSVE2-NEXT: sbfx x8, x9, #0, #1
+; CHECK-NOSVE2-NEXT: whilelo p1.b, xzr, x8
+; CHECK-NOSVE2-NEXT: sel p0.b, p0, p0.b, p1.b
+; CHECK-NOSVE2-NEXT: ret
+entry:
+ %b24 = ptrtoint ptr %b to i64
+ %c25 = ptrtoint ptr %c to i64
+ %sub.diff = sub i64 %c25, %b24
+ %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+ %neg.compare = icmp slt i64 %0, 1
+ %.splatinsert = insertelement <vscale x 16 x i1> poison, i1 %neg.compare, i64 0
+ %.splat = shufflevector <vscale x 16 x i1> %.splatinsert, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
+ %ptr.diff.lane.mask = tail call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %0)
+ %active.lane.mask.alias = or <vscale x 16 x i1> %ptr.diff.lane.mask, %.splat
+ ret <vscale x 16 x i1> %active.lane.mask.alias
+}
+
define <vscale x 16 x i1> @whilewr_commutative(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
; CHECK-LABEL: whilewr_commutative:
; CHECK: // %bb.0: // %entry
@@ -89,6 +119,39 @@ entry:
ret <vscale x 8 x i1> %active.lane.mask.alias
}
+define <vscale x 8 x i1> @whilerw_16(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_16:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: whilerw p0.h, x2, x1
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_16:
+; CHECK-NOSVE2: // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT: subs x8, x2, x1
+; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT: cmp x8, #2
+; CHECK-NOSVE2-NEXT: add x8, x8, x8, lsr #63
+; CHECK-NOSVE2-NEXT: cset w9, lt
+; CHECK-NOSVE2-NEXT: sbfx x9, x9, #0, #1
+; CHECK-NOSVE2-NEXT: asr x8, x8, #1
+; CHECK-NOSVE2-NEXT: whilelo p0.h, xzr, x9
+; CHECK-NOSVE2-NEXT: whilelo p1.h, xzr, x8
+; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT: ret
+entry:
+ %b24 = ptrtoint ptr %b to i64
+ %c25 = ptrtoint ptr %c to i64
+ %sub.diff = sub i64 %c25, %b24
+ %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+ %diff = sdiv i64 %0, 2
+ %neg.compare = icmp slt i64 %0, 2
+ %.splatinsert = insertelement <vscale x 8 x i1> poison, i1 %neg.compare, i64 0
+ %.splat = shufflevector <vscale x 8 x i1> %.splatinsert, <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
+ %ptr.diff.lane.mask = tail call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 0, i64 %diff)
+ %active.lane.mask.alias = or <vscale x 8 x i1> %ptr.diff.lane.mask, %.splat
+ ret <vscale x 8 x i1> %active.lane.mask.alias
+}
+
define <vscale x 4 x i1> @whilewr_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
; CHECK-LABEL: whilewr_32:
; CHECK: // %bb.0: // %entry
@@ -122,6 +185,41 @@ entry:
ret <vscale x 4 x i1> %active.lane.mask.alias
}
+define <vscale x 4 x i1> @whilerw_32(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_32:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: whilerw p0.s, x2, x1
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_32:
+; CHECK-NOSVE2: // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT: subs x8, x2, x1
+; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT: add x9, x8, #3
+; CHECK-NOSVE2-NEXT: cmp x8, #0
+; CHECK-NOSVE2-NEXT: csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT: cmp x8, #4
+; CHECK-NOSVE2-NEXT: cset w8, lt
+; CHECK-NOSVE2-NEXT: asr x9, x9, #2
+; CHECK-NOSVE2-NEXT: sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT: whilelo p1.s, xzr, x9
+; CHECK-NOSVE2-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT: ret
+entry:
+ %b24 = ptrtoint ptr %b to i64
+ %c25 = ptrtoint ptr %c to i64
+ %sub.diff = sub i64 %c25, %b24
+ %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+ %diff = sdiv i64 %0, 4
+ %neg.compare = icmp slt i64 %0, 4
+ %.splatinsert = insertelement <vscale x 4 x i1> poison, i1 %neg.compare, i64 0
+ %.splat = shufflevector <vscale x 4 x i1> %.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+ %ptr.diff.lane.mask = tail call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %diff)
+ %active.lane.mask.alias = or <vscale x 4 x i1> %ptr.diff.lane.mask, %.splat
+ ret <vscale x 4 x i1> %active.lane.mask.alias
+}
+
define <vscale x 2 x i1> @whilewr_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
; CHECK-LABEL: whilewr_64:
; CHECK: // %bb.0: // %entry
@@ -155,6 +253,41 @@ entry:
ret <vscale x 2 x i1> %active.lane.mask.alias
}
+define <vscale x 2 x i1> @whilerw_64(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
+; CHECK-SVE2-LABEL: whilerw_64:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: whilerw p0.d, x2, x1
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-NOSVE2-LABEL: whilerw_64:
+; CHECK-NOSVE2: // %bb.0: // %entry
+; CHECK-NOSVE2-NEXT: subs x8, x2, x1
+; CHECK-NOSVE2-NEXT: cneg x8, x8, mi
+; CHECK-NOSVE2-NEXT: add x9, x8, #7
+; CHECK-NOSVE2-NEXT: cmp x8, #0
+; CHECK-NOSVE2-NEXT: csel x9, x9, x8, lt
+; CHECK-NOSVE2-NEXT: cmp x8, #8
+; CHECK-NOSVE2-NEXT: cset w8, lt
+; CHECK-NOSVE2-NEXT: asr x9, x9, #3
+; CHECK-NOSVE2-NEXT: sbfx x8, x8, #0, #1
+; CHECK-NOSVE2-NEXT: whilelo p1.d, xzr, x9
+; CHECK-NOSVE2-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NOSVE2-NEXT: mov p0.b, p1/m, p1.b
+; CHECK-NOSVE2-NEXT: ret
+entry:
+ %b24 = ptrtoint ptr %b to i64
+ %c25 = ptrtoint ptr %c to i64
+ %sub.diff = sub i64 %c25, %b24
+ %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
+ %diff = sdiv i64 %0, 8
+ %neg.compare = icmp slt i64 %0, 8
+ %.splatinsert = insertelement <vscale x 2 x i1> poison, i1 %neg.compare, i64 0
+ %.splat = shufflevector <vscale x 2 x i1> %.splatinsert, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+ %ptr.diff.lane.mask = tail call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 %diff)
+ %active.lane.mask.alias = or <vscale x 2 x i1> %ptr.diff.lane.mask, %.splat
+ ret <vscale x 2 x i1> %active.lane.mask.alias
+}
+
define <vscale x 1 x i1> @no_whilewr_128(ptr noalias %a, ptr %b, ptr %c, i32 %n) {
; CHECK-LABEL: no_whilewr_128:
; CHECK: // %bb.0: // %entry
This patch extends the whilewr matching to also match a read-after-write mask and lower it to a whilerw.
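As a rough model of what the new tests compute (a sketch following the `whilerw_8` test IR from the diff above, not the lowering itself; the helper name and the scalar loop are illustrative only):

```c++
#include <cstdint>
#include <cstdlib>
#include <vector>

// Models the mask built by the whilerw_8 test IR: all lanes are active
// when abs(c - b) < 1, otherwise lane i is active iff i < abs(c - b).
// (Ignores the INT64_MIN edge case raised in the review below.)
std::vector<bool> read_after_write_mask(uint64_t b, uint64_t c,
                                        unsigned lanes) {
  int64_t diff = std::llabs((int64_t)(c - b)); // @llvm.abs.i64(%sub.diff, false)
  bool splat = diff < 1;                       // %neg.compare = icmp slt i64 %0, 1
  std::vector<bool> mask(lanes);
  for (unsigned i = 0; i < lanes; ++i)         // ALM(0, diff) | splat
    mask[i] = splat || (int64_t)i < diff;
  return mask;
}
```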
llvm/test/CodeGen/AArch64/whilewr.ll
Outdated

  %c25 = ptrtoint ptr %c to i64
  %sub.diff = sub i64 %c25, %b24
  %0 = tail call i64 @llvm.abs.i64(i64 %sub.diff, i1 false)
  %neg.compare = icmp slt i64 %0, 0
This comparison is incorrect -- you shouldn't get a negative value after calling abs, and the ref manual indicates it should just be an equality comparison against 0 to combine with the e < diff comparison. I'm not sure what to do in the case of the diff being the signed minimum, though it seems unlikely to occur in practice.

I think we got the original code for whilewr wrong too -- should be sle against 0 instead of slt.
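For what it's worth, the signed minimum is the one input where the slt-after-abs can still fire, because llvm.abs with int_min_poison = false wraps. A minimal standalone illustration (using a wrapping negation to mirror the intrinsic, since std::abs(INT64_MIN) is UB in C++):

```c++
#include <cstdint>
#include <cstdio>

int main() {
  // llvm.abs.i64(%x, i1 false) negates with wrapping, so
  // abs(INT64_MIN) == INT64_MIN: the only value for which
  // "icmp slt (abs %diff), 0" is still true.
  int64_t diff = INT64_MIN;
  int64_t abs_diff = diff < 0 ? (int64_t)(0 - (uint64_t)diff) : diff;
  printf("abs(%lld) = %lld\n", (long long)diff, (long long)abs_diff);
  // Prints: abs(-9223372036854775808) = -9223372036854775808
  return 0;
}
```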
I think you are right - the original whilewr testing was based on some bad testing from me (the tests returning a predicate confused things). It goes a bit further than what you said, though: I had looked at these intrinsics a while ago and come to the conclusion that they were difficult to match. Mathematically they do:

diff = zext(a) - zext(b)
elem[i] = splat(diff <= 0) | ALM(0, diff)

The zext is important if the top bits can be 1. So with inputs like 0 and 0xfffffffffffffff4 we might produce the wrong results. That likely won't be the case for pointers, but as the inputs are just integers we might need to take that into account.

I think the operands of the whilewr are also the wrong way around.
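A reference model of that formula, assuming ALM(base, n) activates lane i iff base + i < n, and using a compiler with __int128 so the zero-extended subtraction cannot wrap:

```c++
#include <cstdint>
#include <vector>

// Sketch of the comment's formula for a 1-byte element size.
// zext(a) - zext(b) fits in 65 bits, so signed 128-bit arithmetic
// on the zero-extended inputs models it exactly.
std::vector<bool> whilewr_model(uint64_t a, uint64_t b, unsigned lanes) {
  __int128 diff = (__int128)a - (__int128)b; // diff = zext(a) - zext(b)
  bool splat = diff <= 0;                    // splat(diff <= 0)
  std::vector<bool> mask(lanes);
  for (unsigned i = 0; i < lanes; ++i)
    mask[i] = splat || (__int128)i < diff;   // | ALM(0, diff)
  return mask;
}
```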
I think this might be difficult mathematically to make fit correctly. I believe that if you pass this a very high value and a very low one (say 0xfffffffffffffff8 and 0), the whilewr will zext both operands, so the "is negative" part will be different than if you had done it all with i64 arithmetic: the result of 0xfffffffffffffff8 - 0 is negative in i64, but zext(0xfffffffffffffff8) - zext(0) isn't.

That might not be very important for pointers, as it is only the very top bits + very bottom bits that would be negative and not outside of a vector width. I don't think we can just rule it out though, as we are matching from any i64s.
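Concretely, for that example pair the two interpretations disagree on the splat term (again leaning on __int128 to stand in for the zero-extended arithmetic):

```c++
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t a = 0xfffffffffffffff8ULL, b = 0;
  int64_t i64_diff = (int64_t)(a - b);            // i64 view: -8
  __int128 zext_diff = (__int128)a - (__int128)b; // zext view: huge positive
  printf("i64:  diff <= 0 -> %d\n", (int)(i64_diff <= 0));  // prints 1
  printf("zext: diff <= 0 -> %d\n", (int)(zext_diff <= 0)); // prints 0
  return 0;
}
```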
;
; CHECK-SVE2-LABEL: whilewr_8:
; CHECK-SVE2: // %bb.0: // %entry
; CHECK-SVE2-NEXT: whilewr p0.b, x1, x2
I think this should be whilewr p0.b, x2, x1?
Abandoned in favour of #117007