[AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT #131327
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

Add lowering for PARTIAL_REDUCE_U/SMLA nodes to USDOT instructions. This applies when the second operand of the ISD node is a MUL and the extends on the MUL's operands have different signedness.

@JamesChesterman is the original author.

Full diff: https://github.com/llvm/llvm-project/pull/131327.diff

4 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 83fade45d1892..1af60d6896e6d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
/// illegal ResNo in that case.
bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
// See if the target wants to custom lower this node.
- if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
- return false;
+ unsigned Opcode = N->getOpcode();
+ bool IsPRMLAOpcode =
+ Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
+
+ if (IsPRMLAOpcode) {
+ if (TLI.getPartialReduceMLAAction(N->getValueType(0),
+ N->getOperand(1).getValueType()) !=
+ TargetLowering::Custom)
+ return false;
+ } else {
+ if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+ return false;
+ }
SmallVector<SDValue, 8> Results;
if (LegalizeResult)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 447794cc2b744..66ab66063614c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7756,6 +7756,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFLDEXP(Op, DAG);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -27560,6 +27563,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Results.push_back(LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG));
+ return;
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29506,6 +29513,80 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), SEXT(MulOpRHS)), Splat 1)
+// to USDOT(Acc, MulOpLHS, MulOpRHS)
+// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), ZEXT(MulOpRHS)), Splat 1)
+// to USDOT(Acc, MulOpRHS, MulOpLHS)
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+ SelectionDAG &DAG) const {
+ bool Scalable = Op.getValueType().isScalableVector();
+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+ if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
+ return SDValue();
+ if (!Scalable && (!Subtarget.isNeonAvailable() || !Subtarget.hasDotProd()))
+ return SDValue();
+ if (!Subtarget.hasMatMulInt8())
+ return SDValue();
+ SDLoc DL(Op);
+
+ if (Op.getOperand(1).getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue Acc = Op.getOperand(0);
+ SDValue Mul = Op.getOperand(1);
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Op.getOperand(2).getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ SDValue ExtMulOpLHS = Mul.getOperand(0);
+ SDValue ExtMulOpRHS = Mul.getOperand(1);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS.getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS.getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
+
+ SDValue MulOpLHS = ExtMulOpLHS.getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS.getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
+
+ bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+ if (LHSIsSigned == RHSIsSigned)
+ return SDValue();
+
+ EVT AccVT = Acc.getValueType();
+ // There is no nxv2i64 version of usdot
+ if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
+ return SDValue();
+
+ // USDOT expects the signed operand to be last
+ if (!RHSIsSigned)
+ std::swap(MulOpLHS, MulOpRHS);
+
+ unsigned Opcode = AArch64ISD::USDOT;
+ // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+ // product followed by a zero / sign extension
+ // Don't want this to be split because there is no nxv2i64 version of usdot
+ if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
+ (AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
+ EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+ SDValue DotI32 =
+ DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
+ MulOpLHS, MulOpRHS);
+ SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
+ return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
+ }
+
+ return DAG.getNode(Opcode, DL, AccVT, Acc, MulOpLHS, MulOpRHS);
+}
+
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d9b535b910b80..9d8d1c22258be 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..f0c35b191c0a4 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -106,25 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: usdot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: usdot z0.s, z1.b, z2.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -165,25 +147,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: sudot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: usdot z0.s, z2.b, z1.b
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -415,59 +379,12 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
;
; CHECK-NEWLOWERING-LABEL: usdot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
-; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
-; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
-; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
-; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
-; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
-; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
-; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-NEXT: usdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -548,59 +465,12 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
;
; CHECK-NEWLOWERING-LABEL: sudot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
-; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
-; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
-; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z3.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z5.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z26.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z6.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z30.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z31.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z24.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z8.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z25.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z9.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
-; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
-; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
-; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
-; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
-; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
-; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
-; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
-; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-NEXT: usdot z4.s, z3.b, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
unsigned Opcode = N->getOpcode();
bool IsPRMLAOpcode =
    Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;

if (IsPRMLAOpcode) {
  if (TLI.getPartialReduceMLAAction(N->getValueType(0),
                                    N->getOperand(1).getValueType()) !=
      TargetLowering::Custom)
    return false;
} else {
  if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
    return false;
}
Can't LegalizeVectorOps handle this? getPartialReduceMLAAction() is already hooked up there and should be able to call into the custom lowering?
Is this done to bypass type legalization for the usdot_8to64 case? Could we handle that instead by adding a combine that reduces accumulators of <vscale x 4 x i64> to <vscale x 4 x i32> followed by an extend (for i8 inputs)?
@NickGuy-Arm I suspect you did this to work around type legalisation? At the point of doing Custom lowering, all the types must be legal. If the extends had the same signedness, then as @MacDue says this would be handled in LegalizeVectorOps. It's just that the operands (to be sign/zero-extended) have not been folded into the operation yet, because UMLA/SMLA doesn't support mixed extends, which is why the types can't be legalised the normal way.

The way to handle this case is to either:
(1) Implement this mapping to an AArch64ISD node with an AArch64 DAG combine that runs before type legalisation, or
(2) Create a separate PARTIAL_REDUCE_USMLA node, which would go through the regular flow of type legalisation.

The downside of (1) is that we don't get any type legalisation, so any unsupported types would need to be handled in that particular DAG combine, basically requiring it to do type legalisation itself. I think (2) can piggy-back on most of the type legalisation added for UMLA/SMLA, with some small changes.
The type coming in via VT is the type of the operand of the partial_reduce_umla node, which is an extend, so it effectively hides the actual operand type at this stage. We need to use the pre-extended type to figure out whether USDOT is valid to emit, and the type legalization step obscures this type by splitting it across multiple sets of partial_reduce_umla and extract_subvector nodes, meaning we'd have to check significantly more nodes/paths to verify the validity.

I don't think the pre-legalization DAG combine would work for the reasons you pointed out, but in trying to implement the separate node, I encountered the exact same issues as we hit without the above call to getPartialReduceMLAAction.

I've added an operation action for ISD::PARTIAL_REDUCE_UMLA with nxv16i32, which is the post-extended type of nxv16i8, and we can have the existing validation within LowerPARTIAL_REDUCE_MLAToUSDOT decide whether it can actually be lowered to USDOT (falling back to unpacks and mlas if not).
I've reimplemented this check, as I believe it is the simplest solution to this problem. For USDOT lowering to function, it needs to happen pre-legalization, because it deals with illegal intermediate types (which are then flattened out by replacing the nodes with the USDOT ISD node). As the PartialReduceMLA legalize actions are handled differently from the standard operation actions, we need to check the relevant action to take.

This check is simply the required plumbing to have the legalizer respect when a target says that it has custom lowering for a given partial reduction. If we try to pack the information into the operation actions, we lose the ability to filter based on what the partial reduction is reducing from. And if we move the check to post-legalization, we lose direct access to the pre-extend type, as the nodes required to legalize the type obscure it behind multiple extends or AArch64ISD unpack nodes.
> I don't think the pre-legalization DAG combine would work for the reasons you pointed out

What reasons were you referring to here? I would expect this pre-type-legalization DAG combine to recognise the pattern partial.reduce.add(a, mul(sext(b), zext(c)), splat(1)) -> AArch64ISD::sudot(a, b, c). At this point, there shouldn't be any uunpklo/hi instructions yet.

> but in trying to implement the separate node, I encountered the exact same issues as we hit without the above call to getPartialReduceMLAAction.

Are you talking about option (2), creating a new ISD::PARTIAL_REDUCE_USMLA node? If so, can you elaborate on the issues you encountered? (I'd expect it to function roughly the same as the PARTIAL_REDUCE_UMLA node, for example.)
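For illustration, a minimal sketch of the shape of the pre-type-legalisation combine being suggested above, assuming it would live in AArch64ISelLowering.cpp; the helper name and hook point are made up, not part of this patch, and a real version would also need to verify the input/accumulator types and subtarget features before emitting the node:

```cpp
// Rough sketch only: before type legalisation, while the extends are still
// visible, match
//   PARTIAL_REDUCE_*MLA(Acc, MUL(zext(A), sext(B)), splat(1))
// and rewrite it to AArch64ISD::USDOT(Acc, A, B). Type and feature checks
// are deliberately omitted here.
static SDValue tryCombinePartialReduceMLAToUSDOT(SDNode *N, SelectionDAG &DAG) {
  SDValue Acc = N->getOperand(0);
  SDValue Mul = N->getOperand(1);

  APInt SplatVal;
  if (Mul.getOpcode() != ISD::MUL ||
      !ISD::isConstantSplatVector(N->getOperand(2).getNode(), SplatVal) ||
      !SplatVal.isOne())
    return SDValue();

  SDValue LHS = Mul.getOperand(0), RHS = Mul.getOperand(1);
  if (!ISD::isExtOpcode(LHS.getOpcode()) || !ISD::isExtOpcode(RHS.getOpcode()))
    return SDValue();

  bool LHSIsSigned = LHS.getOpcode() == ISD::SIGN_EXTEND;
  bool RHSIsSigned = RHS.getOpcode() == ISD::SIGN_EXTEND;
  if (LHSIsSigned == RHSIsSigned) // only mixed signedness maps to USDOT
    return SDValue();

  SDValue A = LHS.getOperand(0), B = RHS.getOperand(0);
  if (A.getValueType() != B.getValueType())
    return SDValue();

  // USDOT takes the unsigned operand first and the signed operand second.
  if (LHSIsSigned)
    std::swap(A, B);

  SDLoc DL(N);
  return DAG.getNode(AArch64ISD::USDOT, DL, Acc.getValueType(), Acc, A, B);
}
```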
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
So this is called for the types marked "Custom" in #130935?
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);

setOperationAction(ISD::PARTIAL_REDUCE_UMLA, MVT::nxv16i32, Custom);
This looks like you're trying to work around a problem, because it's now using two separate mechanisms to determine whether a PARTIAL_REDUCE_* node needs custom lowering.
What issues are you running into when implementing this as a DAGcombine that runs before type legalisation?
The issue presents itself as the extend type being obscured by numerous (depending on the illegal types) pairs of AArch64ISD::UUNPKLO/HI. While it would technically be possible to traverse these to find the pre-extended type, we would have to traverse the entire tree of a potentially unknown size to verify that it's all valid.
By running this lowering before type legalization, we can make our adjustments and insert the node hierarchy representing USDOT, which itself can then benefit from type legalization.
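To make the traversal problem concrete, here is a hypothetical sketch (the helper name is made up) of the simplest piece of such a walk: peeking through a single chain of unpack nodes to recover the pre-extended value and its signedness. The real check would have to repeat this for every leaf of the split tree and confirm that all paths agree, which is what makes the post-legalization approach unattractive:

```cpp
// Hypothetical illustration only. After type legalisation, a single
// sext/zext has been rewritten into chains of AArch64ISD::[S|U]UNPKLO/HI
// nodes, so recovering the original (pre-extend) value means walking an
// arbitrarily long chain -- and doing so for every leaf of the split tree.
static SDValue peekThroughUnpacks(SDValue V, bool &IsSigned) {
  while (true) {
    switch (V.getOpcode()) {
    case AArch64ISD::SUNPKLO:
    case AArch64ISD::SUNPKHI:
      IsSigned = true;
      V = V.getOperand(0);
      break;
    case AArch64ISD::UUNPKLO:
    case AArch64ISD::UUNPKHI:
      IsSigned = false;
      V = V.getOperand(0);
      break;
    default:
      return V; // no more unpacks; may or may not be the original value
    }
  }
}
```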
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
I'd expect the cases you've handled above (where all types are legal) to be sufficient. Why did you have to add this?
These include the "illegal but internally legalizable" types that were needed for the usdot_8to64 cases; these have now been removed from this PR along with the support for the 8to64 cases.
Converting this to a draft, as an acceptable solution will likely require more effort than I can afford right now. I'll revisit it in the future.