-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes #140075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes #140075
Conversation
…REDUCE_*MLA ISD nodes Lowering for fixed width vectors added to tablegen. There is also custom lowering to ensure that the USDOT patterns are still lowered for fixed width vectors. It also ensures that the v16i8 -> v4i64 partial reduction case is lowered here instead of being split (as there is not a v2i64 dot product instruction).
@llvm/pr-subscribers-backend-aarch64 Author: Nicholas Guy (NickGuy-Arm) ChangesLowering for fixed width vectors added to tablegen. @JamesChesterman is the original author. Patch is 37.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140075.diff 3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13fb6a32233fe..f1354bf1147dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1872,6 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
}
+ if (EnablePartialReduceNodes && Subtarget->hasNEON() &&
+ Subtarget->hasDotProd()) {
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Legal);
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+ }
+
// Handle operations that are only available in non-streaming SVE mode.
if (Subtarget->isSVEAvailable()) {
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
@@ -7743,8 +7752,11 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
case ISD::PARTIAL_REDUCE_SMLA:
- case ISD::PARTIAL_REDUCE_UMLA:
- return LowerPARTIAL_REDUCE_MLA(Op, DAG);
+ case ISD::PARTIAL_REDUCE_UMLA: {
+ if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG))
+ return Result;
+ return expandPartialReduceMLA(Op.getNode(), DAG);
+ }
}
}
@@ -27569,6 +27581,14 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA: {
+ if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
+ Results.push_back(Res);
+ else
+ Results.push_back(expandPartialReduceMLA(N, DAG));
+ return;
+ }
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29518,37 +29538,64 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
+/// of v2i64/v16i8, we cannot directly lower it to a (u|s)dot. We can
/// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+/// accumulating over two steps: v16i8 -> v4i32 -> v2i64.
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
+ bool Scalable = Op.getValueType().isScalableVector();
+ if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+ if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+ return SDValue();
+
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
- assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
- SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
- DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+ assert((Scalable && ResultVT == MVT::nxv2i64 &&
+ LHS.getValueType() == MVT::nxv16i8) ||
+ (!Scalable && ResultVT == MVT::v2i64 &&
+ LHS.getValueType() == MVT::v16i8));
+
+ EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+ SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
+ DAG.getConstant(0, DL, DotVT), LHS, RHS);
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
- if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+ if (Scalable &&
+ (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
}
- unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
- unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
- auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
- auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
- auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
- return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+ if (Scalable) {
+ unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+ unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+ auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+ auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+ auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+ return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+ }
+
+ // Fold v4i32 into v2i64
+ // SDValues
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+ if (IsUnsigned) {
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+ } else {
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+ }
+ auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo);
+ return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi);
}
SDValue
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index b02a907f7439f..5cc6a38d55977 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
}
+let Predicates = [HasNEON, HasDotProd] in {
+ def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
+ (v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
+ def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
+ (v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
+ def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
+ (v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
+ def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
+ (v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
+} // End HasNEON, HasDotProd
+
// ARMv8.6-A BFloat
let Predicates = [HasNEON, HasBF16] in {
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index ab9813aa796e3..47a4796d0f9a1 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,7 +2,8 @@
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM,CHECK-NEWLOWERING-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM,CHECK-NEWLOWERING-NOI8MM
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -174,10 +175,17 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: usdot:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
+; CHECK-NEWLOWERING-I8MM: // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -209,21 +217,28 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
; CHECK-NOI8MM-NEXT: // %bb.2: // %end
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: usdot_in_loop:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: movi v1.2d, #0000000000000000
-; CHECK-I8MM-NEXT: mov x8, xzr
-; CHECK-I8MM-NEXT: .LBB6_1: // %vector.body
-; CHECK-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECK-I8MM-NEXT: ldr q2, [x0, x8]
-; CHECK-I8MM-NEXT: ldr q3, [x1, x8]
-; CHECK-I8MM-NEXT: mov v0.16b, v1.16b
-; CHECK-I8MM-NEXT: add x8, x8, #16
-; CHECK-I8MM-NEXT: usdot v1.4s, v3.16b, v2.16b
-; CHECK-I8MM-NEXT: cmp x8, #16
-; CHECK-I8MM-NEXT: b.ne .LBB6_1
-; CHECK-I8MM-NEXT: // %bb.2: // %end
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot_in_loop:
+; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: mov x8, xzr
+; CHECK-NEWLOWERING-I8MM-NEXT: .LBB6_1: // %vector.body
+; CHECK-NEWLOWERING-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-I8MM-NEXT: ldr q2, [x0, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB6_1
+; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
br label %vector.body
@@ -264,10 +279,22 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: usdot_narrow:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
+; CHECK-NEWLOWERING-I8MM: // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -288,10 +315,17 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
; CHECK-NOI8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: sudot:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
+; CHECK-NEWLOWERING-I8MM: // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
%s.wide = sext <16 x i8> %u to <16 x i32>
%u.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %u.wide, %s.wide
@@ -323,21 +357,28 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-NOI8MM-NEXT: // %bb.2: // %end
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: sudot_in_loop:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: movi v1.2d, #0000000000000000
-; CHECK-I8MM-NEXT: mov x8, xzr
-; CHECK-I8MM-NEXT: .LBB9_1: // %vector.body
-; CHECK-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECK-I8MM-NEXT: ldr q2, [x0, x8]
-; CHECK-I8MM-NEXT: ldr q3, [x1, x8]
-; CHECK-I8MM-NEXT: mov v0.16b, v1.16b
-; CHECK-I8MM-NEXT: add x8, x8, #16
-; CHECK-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
-; CHECK-I8MM-NEXT: cmp x8, #16
-; CHECK-I8MM-NEXT: b.ne .LBB9_1
-; CHECK-I8MM-NEXT: // %bb.2: // %end
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot_in_loop:
+; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: mov x8, xzr
+; CHECK-NEWLOWERING-I8MM-NEXT: .LBB9_1: // %vector.body
+; CHECK-NEWLOWERING-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-I8MM-NEXT: ldr q2, [x0, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB9_1
+; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
br label %vector.body
@@ -378,10 +419,22 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-NOI8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NOI8MM-NEXT: ret
;
-; CHECK-I8MM-LABEL: sudot_narrow:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
-; CHECK-I8MM-NEXT: ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
+; CHECK-NEWLOWERING-I8MM: // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
+; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -390,14 +443,6 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
}
define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
-; CHECK-DOT-LABEL: udot_8to64:
-; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
-; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
-; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
-; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
-; CHECK-DOT-NEXT: ret
-;
; CHECK-NODOT-LABEL: udot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
@@ -415,6 +460,22 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: uaddw2 v0.2d, v0.2d, v4.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-I8MM-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: uaddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-NOI8MM-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-NOI8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NOI8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NOI8MM-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NOI8MM-NEXT: uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NOI8MM-NEXT: uaddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-NOI8MM-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
@@ -425,14 +486,6 @@ entry:
}
define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
-; CHECK-DOT-LABEL: sdot_8to64:
-; CHECK-DOT: // %bb.0: // %entry
-; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
-; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
-; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
-; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
-; CHECK-DOT-NEXT: ret
-;
; CHECK-NODOT-LABEL: sdot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
@@ -450,6 +503,22 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: saddw2 v0.2d, v0.2d, v4.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-I8MM-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-NOI8MM-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-NOI8MM: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NOI8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NOI8MM-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NOI8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NOI8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-NOI8MM-NEXT: ret
...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a couple of suggestions.
/// however still make use of the dot product instruction by instead | ||
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64. | ||
/// accumulating over two steps: v16i8 -> v4i32 -> v2i64. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to have a comment about the pattern that it produces.
; CHECK-I8MM: // %bb.0: | ||
; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b | ||
; CHECK-I8MM-NEXT: ret | ||
; CHECK-NEWLOWERING-I8MM-LABEL: usdot: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we keep the I8MM check statements and add these NEWLOWERING ones as new ones?
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG)) | ||
Results.push_back(Res); | ||
else | ||
Results.push_back(expandPartialReduceMLA(N, DAG)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No test failures from removing this expansion (and also seems not to be done for other nodes). IIRC a custom lowering failing just means the node is assumed to be legal.
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG)) | |
Results.push_back(Res); | |
else | |
Results.push_back(expandPartialReduceMLA(N, DAG)); | |
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG)) | |
Results.push_back(Res); |
if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG)) | ||
return Result; | ||
return expandPartialReduceMLA(Op.getNode(), DAG); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto here:
if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG)) | |
return Result; | |
return expandPartialReduceMLA(Op.getNode(), DAG); | |
return LowerPARTIAL_REDUCE_MLA(Op, DAG); |
@@ -29518,37 +29538,64 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, | ||
} | ||
|
||
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing | ||
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can | ||
/// of v2i64/v16i8, we cannot directly lower it to a (u|s)dot. We can |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Maybe but "nx" in brackets e.g. (nx)v2i64
, since this applies to both scalable and fixed vectors.
if (Scalable) { | ||
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO; | ||
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI; | ||
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode); | ||
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode); | ||
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi); | ||
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended); | ||
} | ||
|
||
// Fold v4i32 into v2i64 | ||
// SDValues | ||
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL); | ||
if (IsUnsigned) { | ||
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64); | ||
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64); | ||
} else { | ||
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64); | ||
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64); | ||
} | ||
auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo); | ||
return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this would work for both Neon and SVE. Any reason this is not just?:
if (Scalable) { | |
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO; | |
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI; | |
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode); | |
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode); | |
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi); | |
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended); | |
} | |
// Fold v4i32 into v2i64 | |
// SDValues | |
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL); | |
if (IsUnsigned) { | |
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64); | |
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64); | |
} else { | |
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64); | |
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64); | |
} | |
auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo); | |
return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi); | |
// Fold (nx)v4i32 into (nx)v2i64 | |
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL); | |
if (IsUnsigned) { | |
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT); | |
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT); | |
} else { | |
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT); | |
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT); | |
} | |
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, DotNodeLo, DotNodeHi); | |
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended); |
Lowering for fixed width vectors added to tablegen.
There is also custom lowering to ensure that the USDOT patterns are
still lowered for fixed width vectors. It also ensures that the
v16i8 -> v4i64 partial reduction case is lowered here instead of
being split (as there is not a v2i64 dot product instruction).
@JamesChesterman is the original author.