diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 83fade45d1892..4bb61261526f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
 /// illegal ResNo in that case.
 bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   // See if the target wants to custom lower this node.
-  if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
-    return false;
+  unsigned Opcode = N->getOpcode();
+  bool IsPRMLAOpcode =
+      Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
+
+  if (IsPRMLAOpcode) {
+    if (TLI.getPartialReduceMLAAction(N->getValueType(0),
+                                      N->getOperand(1).getValueType()) !=
+        TargetLowering::Custom)
+      return false;
+  } else {
+    if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+      return false;
+  }
 
   SmallVector<SDValue, 8> Results;
   if (LegalizeResult)
@@ -946,7 +957,6 @@ bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   return true;
 }
 
-
 /// Widen the node's results with custom code provided by the target and return
 /// "true", or do nothing and return "false".
 bool DAGTypeLegalizer::CustomWidenLowerNode(SDNode *N, EVT VT) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6e7f13c20db68..e98307fba88dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1869,7 +1869,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
 
+    // 8to64
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+    // USDOT
+    if (Subtarget->hasMatMulInt8())
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
@@ -27533,6 +27538,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
     if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
       Results.push_back(Res);
     return;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Results.push_back(LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG));
+    return;
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29481,21 +29490,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
-/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
-/// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
-  SDLoc DL(Op);
+  if (SDValue UsdotNode = LowerPARTIAL_REDUCE_MLAToUSDOT(Op, DAG))
+    return UsdotNode;
 
-  SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
-  SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
+  /// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type
+  /// pairing of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We
+  /// can however still make use of the dot product instruction by instead
+  /// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+  if (ResultVT != MVT::nxv2i64 || LHS.getValueType() != MVT::nxv16i8)
+    return SDValue();
 
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(2);
   SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
                                 DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
 
@@ -29515,6 +29527,80 @@
   return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
 }
 
+// partial.reduce.umla(acc, mul(zext(mulOpLHS), sext(mulOpRHS)), splat(1))
+// -> USDOT(acc, mulOpLHS, mulOpRHS)
+// partial.reduce.smla(acc, mul(sext(mulOpLHS), zext(mulOpRHS)), splat(1))
+// -> USDOT(acc, mulOpRHS, mulOpLHS)
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
+                                                      SelectionDAG &DAG) const {
+  bool Scalable = Op.getValueType().isScalableVector();
+  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
+    return SDValue();
+  if (!Scalable && (!Subtarget.isNeonAvailable() || !Subtarget.hasDotProd()))
+    return SDValue();
+  if (!Subtarget.hasMatMulInt8())
+    return SDValue();
+  SDLoc DL(Op);
+
+  if (Op.getOperand(1).getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue Acc = Op.getOperand(0);
+  SDValue Mul = Op.getOperand(1);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op.getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  SDValue ExtMulOpLHS = Mul.getOperand(0);
+  SDValue ExtMulOpRHS = Mul.getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS.getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS.getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+    return SDValue();
+
+  SDValue MulOpLHS = ExtMulOpLHS.getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS.getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
+
+  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  if (LHSIsSigned == RHSIsSigned)
+    return SDValue();
+
+  EVT AccVT = Acc.getValueType();
+  // There is no nxv2i64 version of usdot
+  if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
+    return SDValue();
+
+  // USDOT expects the signed operand to be last
+  if (!RHSIsSigned)
+    std::swap(MulOpLHS, MulOpRHS);
+
+  unsigned Opcode = AArch64ISD::USDOT;
+  // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+  // product followed by a zero / sign extension
+  // Don't want this to be split because there is no nxv2i64 version of usdot
+  if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
+      (AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
+    EVT AccVTI32 = AccVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
+
+    SDValue DotI32 =
+        DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),
+                    MulOpLHS, MulOpRHS);
+    SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
+    return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
+  }
+
+  return DAG.getNode(Opcode, DL, AccVT, Acc, MulOpLHS, MulOpRHS);
+}
+
 SDValue
 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9d8d1c22258be..6f0fb03bae0ea 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1182,6 +1182,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..0c5ec2908a16d 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -3,7 +3,7 @@
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 ; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
+; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
 
 define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: udot:
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: usdot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    uunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
 ;
 ; CHECK-NEWLOWERING-LABEL: sudot:
 ; CHECK-NEWLOWERING:       // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT:    sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    ptrue p0.s
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    sunpklo z5.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpklo z6.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    usdot z0.s, z2.b, z1.b
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
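For reference, a minimal IR sketch of the mixed-sign extend/multiply/partial-reduce pattern the new lowering matches (the function name and exact intrinsic mangling here are illustrative, not taken from the patch; the shape follows the usdot test above):

define <vscale x 4 x i32> @usdot_sketch(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  ; Mixed-sign extends: %a is zero-extended, %b is sign-extended.
  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  ; Becomes partial.reduce.umla(acc, mul(zext, sext), splat(1)) in the DAG;
  ; with +i8mm, LowerPARTIAL_REDUCE_MLAToUSDOT folds this whole chain into:
  ;   usdot z0.s, z1.b, z2.b
  %red = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %red
}

declare <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32>, <vscale x 16 x i32>)

With the extends swapped (sext %a, zext %b), the lowering instead swaps the multiply operands, giving usdot z0.s, z2.b, z1.b as in the sudot test, since USDOT expects the signed operand to be last.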