diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h index 75c051712ae43..b790786760742 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h @@ -162,6 +162,8 @@ class LLVM_ABI CallLowering { /// True if this call results in convergent operations. bool IsConvergent = true; + + GlobalValue *DeactivationSymbol = nullptr; }; /// Argument handling is mostly uniform between the four places that diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h index 7a598bb77b356..8062ef07f876d 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h @@ -56,6 +56,7 @@ struct MachineIRBuilderState { MDNode *PCSections = nullptr; /// MMRA Metadata to be set on any instruction we create. MDNode *MMRA = nullptr; + Value *DS = nullptr; /// \name Fields describing the insertion point. /// @{ @@ -369,6 +370,7 @@ class LLVM_ABI MachineIRBuilder { State.II = MI.getIterator(); setPCSections(MI.getPCSections()); setMMRAMetadata(MI.getMMRAMetadata()); + setDeactivationSymbol(MI.getDeactivationSymbol()); } /// @} @@ -405,6 +407,9 @@ class LLVM_ABI MachineIRBuilder { /// Set the PC sections metadata to \p MD for all the next build instructions. void setMMRAMetadata(MDNode *MMRA) { State.MMRA = MMRA; } + Value *getDeactivationSymbol() { return State.DS; } + void setDeactivationSymbol(Value *DS) { State.DS = DS; } + /// Get the current instruction's MMRA metadata. MDNode *getMMRAMetadata() { return State.MMRA; } diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h index 465e4a0a9d0d8..64382df7202ed 100644 --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -1563,6 +1563,8 @@ enum NodeType { // Outputs: Output Chain CLEAR_CACHE, + DEACTIVATION_SYMBOL, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h index 94d04b82666be..1935195913336 100644 --- a/llvm/include/llvm/CodeGen/MachineInstr.h +++ b/llvm/include/llvm/CodeGen/MachineInstr.h @@ -866,6 +866,15 @@ class MachineInstr return nullptr; } + // FIXME: Move to Info. + Value *DeactivationSymbol = nullptr; + Value *getDeactivationSymbol() const { + return DeactivationSymbol; + } + void setDeactivationSymbol(MachineFunction &MF, Value *DeactivationSymbol) { + this->DeactivationSymbol = DeactivationSymbol; + } + /// Helper to extract a CFI type hash if one has been added. uint32_t getCFIType() const { if (!Info) diff --git a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h index e705d7d99544c..caeb430d6fd1c 100644 --- a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h +++ b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h @@ -70,29 +70,44 @@ enum { } // end namespace RegState /// Set of metadata that should be preserved when using BuildMI(). This provides -/// a more convenient way of preserving DebugLoc, PCSections and MMRA. +/// a more convenient way of preserving certain data from the original +/// instruction. class MIMetadata { public: MIMetadata() = default; - MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) - : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA) {} + MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr, + Value *DeactivationSymbol = nullptr) + : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA), + DeactivationSymbol(DeactivationSymbol) {} MIMetadata(const DILocation *DI, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) : DL(DI), PCSections(PCSections), MMRA(MMRA) {} explicit MIMetadata(const Instruction &From) : DL(From.getDebugLoc()), - PCSections(From.getMetadata(LLVMContext::MD_pcsections)) {} + PCSections(From.getMetadata(LLVMContext::MD_pcsections)), + DeactivationSymbol(getDeactivationSymbol(&From)) {} explicit MIMetadata(const MachineInstr &From) - : DL(From.getDebugLoc()), PCSections(From.getPCSections()) {} + : DL(From.getDebugLoc()), PCSections(From.getPCSections()), + DeactivationSymbol(From.getDeactivationSymbol()) {} const DebugLoc &getDL() const { return DL; } MDNode *getPCSections() const { return PCSections; } MDNode *getMMRAMetadata() const { return MMRA; } + Value *getDeactivationSymbol() const { return DeactivationSymbol; } private: DebugLoc DL; MDNode *PCSections = nullptr; MDNode *MMRA = nullptr; + Value *DeactivationSymbol = nullptr; + + static inline Value *getDeactivationSymbol(const Instruction *I) { + if (auto *CB = dyn_cast(I)) + if (auto Bundle = + CB->getOperandBundle(llvm::LLVMContext::OB_deactivation_symbol)) + return Bundle->Inputs[0].get(); + return nullptr; + } }; class MachineInstrBuilder { @@ -348,6 +363,8 @@ class MachineInstrBuilder { MI->setPCSections(*MF, MIMD.getPCSections()); if (MIMD.getMMRAMetadata()) MI->setMMRAMetadata(*MF, MIMD.getMMRAMetadata()); + if (MIMD.getDeactivationSymbol()) + MI->setDeactivationSymbol(*MF, MIMD.getDeactivationSymbol()); return *this; } diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index b856b4786573b..6151c7b7d9b4e 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -763,6 +763,7 @@ class SelectionDAG { int64_t offset = 0, unsigned TargetFlags = 0) { return getGlobalAddress(GV, DL, VT, offset, true, TargetFlags); } + LLVM_ABI SDValue getDeactivationSymbol(const GlobalValue *GV); LLVM_ABI SDValue getFrameIndex(int FI, EVT VT, bool isTarget = false); SDValue getTargetFrameIndex(int FI, EVT VT) { return getFrameIndex(FI, VT, true); diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h index a6a3928230c3d..824cae23a7ff4 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h @@ -150,6 +150,7 @@ class SelectionDAGISel { OPC_RecordChild7, OPC_RecordMemRef, OPC_CaptureGlueInput, + OPC_CaptureDeactivationSymbol, OPC_MoveChild, OPC_MoveChild0, OPC_MoveChild1, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 5d9937f832396..8a96260242587 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1980,6 +1980,23 @@ class GlobalAddressSDNode : public SDNode { } }; +class DeactivationSymbolSDNode : public SDNode { + friend class SelectionDAG; + + const GlobalValue *TheGlobal; + + DeactivationSymbolSDNode(const GlobalValue *GV, SDVTList VTs) + : SDNode(ISD::DEACTIVATION_SYMBOL, 0, DebugLoc(), VTs), + TheGlobal(GV) {} + +public: + const GlobalValue *getGlobal() const { return TheGlobal; } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::DEACTIVATION_SYMBOL; + } +}; + class FrameIndexSDNode : public SDNode { friend class SelectionDAG; diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index fa46d296bf533..70618e98c1bfc 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4710,6 +4710,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { SmallVector InVals; const ConstantInt *CFIType = nullptr; SDValue ConvergenceControlToken; + GlobalValue *DeactivationSymbol = nullptr; std::optional PAI; @@ -4855,6 +4856,11 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { return *this; } + CallLoweringInfo &setDeactivationSymbol(GlobalValue *Sym) { + DeactivationSymbol = Sym; + return *this; + } + ArgListTy &getArgs() { return Args; } diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h index 852a3a4e2f638..c2d2d46da6ddd 100644 --- a/llvm/include/llvm/IR/LLVMContext.h +++ b/llvm/include/llvm/IR/LLVMContext.h @@ -97,6 +97,7 @@ class LLVMContext { OB_ptrauth = 7, // "ptrauth" OB_kcfi = 8, // "kcfi" OB_convergencectrl = 9, // "convergencectrl" + OB_deactivation_symbol = 10, // "deactivation-symbol" }; /// getMDKindID - Return a unique non-zero ID for the specified metadata kind. diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td index 4c83f8a580aa0..77589e4664746 100644 --- a/llvm/include/llvm/Target/Target.td +++ b/llvm/include/llvm/Target/Target.td @@ -682,6 +682,7 @@ class Instruction : InstructionEncoding { // If so, make sure to override // TargetInstrInfo::getInsertSubregLikeInputs. bit variadicOpsAreDefs = false; // Are variadic operands definitions? + bit supportsDeactivationSymbol = false; // Does the instruction have side effects that are not captured by any // operands of the instruction or other flags? diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp index 012d8735180be..de7eae6ad1d7b 100644 --- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp @@ -195,6 +195,10 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB, assert(Info.CFIType->getType()->isIntegerTy(32) && "Invalid CFI type"); } + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + Info.DeactivationSymbol = cast(Bundle->Inputs[0]); + } + Info.CB = &CB; Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees); Info.CallConv = CallConv; diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index ef39fc74554c9..3ba9c0d633bdc 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2857,6 +2857,9 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { } } + if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_deactivation_symbol)) + MIB->setDeactivationSymbol(*MF, Bundle->Inputs[0].get()); + return true; } diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp index 121d7e80251c7..436d33e3ad828 100644 --- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp +++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp @@ -38,8 +38,10 @@ void MachineIRBuilder::setMF(MachineFunction &MF) { //------------------------------------------------------------------------------ MachineInstrBuilder MachineIRBuilder::buildInstrNoInsert(unsigned Opcode) { - return BuildMI(getMF(), {getDL(), getPCSections(), getMMRAMetadata()}, - getTII().get(Opcode)); + return BuildMI( + getMF(), + {getDL(), getPCSections(), getMMRAMetadata(), getDeactivationSymbol()}, + getTII().get(Opcode)); } MachineInstrBuilder MachineIRBuilder::insertInstr(MachineInstrBuilder MIB) { diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp index da3665b3b6a0b..23b601b08b900 100644 --- a/llvm/lib/CodeGen/MachineInstr.cpp +++ b/llvm/lib/CodeGen/MachineInstr.cpp @@ -120,7 +120,7 @@ MachineInstr::MachineInstr(MachineFunction &MF, const MCInstrDesc &TID, MachineInstr::MachineInstr(MachineFunction &MF, const MachineInstr &MI) : MCID(&MI.getDesc()), NumOperands(0), Flags(0), AsmPrinterFlags(0), Info(MI.Info), DbgLoc(MI.getDebugLoc()), DebugInstrNum(0), - Opcode(MI.getOpcode()) { + Opcode(MI.getOpcode()), DeactivationSymbol(MI.getDeactivationSymbol()) { assert(DbgLoc.hasTrivialDestructor() && "Expected trivial destructor"); CapOperands = OperandCapacity::get(MI.getNumOperands()); @@ -728,6 +728,8 @@ bool MachineInstr::isIdenticalTo(const MachineInstr &Other, // Call instructions with different CFI types are not identical. if (isCall() && getCFIType() != Other.getCFIType()) return false; + if (getDeactivationSymbol() != Other.getDeactivationSymbol()) + return false; return true; } @@ -2029,6 +2031,8 @@ void MachineInstr::print(raw_ostream &OS, ModuleSlotTracker &MST, OS << ','; OS << " cfi-type " << CFIType; } + if (getDeactivationSymbol()) + OS << ", deactivation-symbol " << getDeactivationSymbol()->getName(); if (DebugInstrNum) { if (!FirstOp) diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp index 03d3e8eab35d0..9c49b2492f8b1 100644 --- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp @@ -15,10 +15,12 @@ #include "InstrEmitter.h" #include "SDNodeDbgValue.h" #include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetLowering.h" @@ -61,6 +63,8 @@ static unsigned countOperands(SDNode *Node, unsigned NumExpUses, unsigned N = Node->getNumOperands(); while (N && Node->getOperand(N - 1).getValueType() == MVT::Glue) --N; + if (N && Node->getOperand(N - 1).getOpcode() == ISD::DEACTIVATION_SYMBOL) + --N; // Ignore deactivation symbol if it exists. if (N && Node->getOperand(N - 1).getValueType() == MVT::Other) --N; // Ignore chain if it exists. @@ -1219,15 +1223,23 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned, } } - if (SDNode *GluedNode = Node->getGluedNode()) { - // FIXME: Possibly iterate over multiple glue nodes? - if (GluedNode->getOpcode() == - ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { - Register VReg = getVR(GluedNode->getOperand(0), VRBaseMap); - MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, - /*isImp=*/true); - MIB->addOperand(MO); - } + unsigned Op = Node->getNumOperands(); + if (Op != 0 && Node->getOperand(Op - 1)->getOpcode() == + ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { + Register VReg = getVR(Node->getOperand(Op - 1)->getOperand(0), VRBaseMap); + MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, + /*isImp=*/true); + MIB->addOperand(MO); + Op--; + } + + if (Op != 0 && + Node->getOperand(Op - 1)->getOpcode() == ISD::DEACTIVATION_SYMBOL) { + MI->setDeactivationSymbol( + *MF, const_cast( + cast(Node->getOperand(Op - 1)) + ->getGlobal())); + Op--; } // Run post-isel target hook to adjust this instruction if needed. @@ -1248,7 +1260,8 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned, llvm_unreachable("This target-independent node should have been selected!"); case ISD::EntryToken: case ISD::MERGE_VALUES: - case ISD::TokenFactor: // fall thru + case ISD::TokenFactor: + case ISD::DEACTIVATION_SYMBOL: break; case ISD::CopyToReg: { Register DestReg = cast(Node->getOperand(1))->getReg(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index c1356239ad206..e39f908beb1ab 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1917,6 +1917,21 @@ SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL, return SDValue(N, 0); } +SDValue SelectionDAG::getDeactivationSymbol(const GlobalValue *GV) { + SDVTList VTs = getVTList(MVT::Untyped); + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::DEACTIVATION_SYMBOL, VTs, {}); + ID.AddPointer(GV); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) + return SDValue(E, 0); + + auto *N = newSDNode(GV, VTs); + CSEMap.InsertNode(N, IP); + InsertNode(N); + return SDValue(N, 0); +} + SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) { unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex; SDVTList VTs = getVTList(VT); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index ecd1ff87e7fbc..75569083734c1 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -45,6 +45,7 @@ #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/SelectionDAGTargetInfo.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/SwiftErrorValueTracking.h" @@ -5360,6 +5361,13 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, // Create the node. SDValue Result; + if (auto Bundle = I.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + auto *Sym = Bundle->Inputs[0].get(); + SDValue SDSym = getValue(Sym); + SDSym = DAG.getDeactivationSymbol(cast(Sym)); + Ops.push_back(SDSym); + } + if (auto Bundle = I.getOperandBundle(LLVMContext::OB_convergencectrl)) { auto *Token = Bundle->Inputs[0].get(); SDValue ConvControlToken = getValue(Token); @@ -8947,6 +8955,11 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, ConvControlToken = getValue(Token); } + GlobalValue *DeactivationSymbol = nullptr; + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + DeactivationSymbol = cast(Bundle->Inputs[0].get()); + } + TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(getCurSDLoc()) .setChain(getRoot()) @@ -8956,7 +8969,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, .setIsPreallocated( CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0) .setCFIType(CFIType) - .setConvergenceControlToken(ConvControlToken); + .setConvergenceControlToken(ConvControlToken) + .setDeactivationSymbol(DeactivationSymbol); // Set the pointer authentication info if we have it. if (PAI) { @@ -9572,7 +9586,8 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) { {LLVMContext::OB_deopt, LLVMContext::OB_funclet, LLVMContext::OB_cfguardtarget, LLVMContext::OB_preallocated, LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_kcfi, - LLVMContext::OB_convergencectrl})) + LLVMContext::OB_convergencectrl, + LLVMContext::OB_deactivation_symbol})) reportFatalUsageError("cannot lower calls with arbitrary operand bundles!"); SDValue Callee = getValue(I.getCalledOperand()); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index 26071ed70c9db..ff01019182e73 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3268,6 +3268,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case ISD::LIFETIME_START: case ISD::LIFETIME_END: case ISD::PSEUDO_PROBE: + case ISD::DEACTIVATION_SYMBOL: NodeToMatch->setNodeId(-1); // Mark selected. return; case ISD::AssertSext: @@ -3346,7 +3347,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // These are the current input chain and glue for use when generating nodes. // Various Emit operations change these. For example, emitting a copytoreg // uses and updates these. - SDValue InputChain, InputGlue; + SDValue InputChain, InputGlue, DeactivationSymbol; // ChainNodesMatched - If a pattern matches nodes that have input/output // chains, the OPC_EmitMergeInputChains operation is emitted which indicates @@ -3499,6 +3500,15 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, InputGlue = N->getOperand(N->getNumOperands()-1); continue; + case OPC_CaptureDeactivationSymbol: + // If the current node has a deactivation symbol, capture it in + // DeactivationSymbol. + if (N->getNumOperands() != 0 && + N->getOperand(N->getNumOperands() - 1).getOpcode() == + ISD::DEACTIVATION_SYMBOL) + DeactivationSymbol = N->getOperand(N->getNumOperands()-1); + continue; + case OPC_MoveChild: { unsigned ChildNo = MatcherTable[MatcherIndex++]; if (ChildNo >= N.getNumOperands()) @@ -4180,6 +4190,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // If this has chain/glue inputs, add them. if (EmitNodeInfo & OPFL_Chain) Ops.push_back(InputChain); + if (DeactivationSymbol.getNode() != nullptr) + Ops.push_back(DeactivationSymbol); if ((EmitNodeInfo & OPFL_GlueInput) && InputGlue.getNode() != nullptr) Ops.push_back(InputGlue); diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp index 57532cd491dd6..647c1661083a4 100644 --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -53,6 +53,8 @@ static StringRef knownBundleName(unsigned BundleTagID) { return "kcfi"; case LLVMContext::OB_convergencectrl: return "convergencectrl"; + case LLVMContext::OB_deactivation_symbol: + return "deactivation-symbol"; default: llvm_unreachable("unknown bundle id"); } @@ -76,7 +78,7 @@ LLVMContext::LLVMContext() : pImpl(new LLVMContextImpl(*this)) { } for (unsigned BundleTagID = LLVMContext::OB_deopt; - BundleTagID <= LLVMContext::OB_convergencectrl; ++BundleTagID) { + BundleTagID <= LLVMContext::OB_deactivation_symbol; ++BundleTagID) { [[maybe_unused]] const auto *Entry = pImpl->getOrInsertBundleTag(knownBundleName(BundleTagID)); assert(Entry->second == BundleTagID && "operand bundle id drifted!"); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index ee0485ad910e7..53ca97cb21d14 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -177,7 +177,12 @@ class AArch64AsmPrinter : public AsmPrinter { const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, - uint64_t PACDisc, Register PACAddrDisc); + uint64_t PACDisc, Register PACAddrDisc, Value *DS); + + // Emit R_AARCH64_PATCHINST, the deactivation symbol relocation. Returns true + // if no instruction should be emitted because the deactivation symbol is + // defined in the current module so this function emitted a NOP instead. + bool emitDeactivationSymbolRelocation(Value *DS); // Emit the sequence to compute the discriminator. // @@ -2078,11 +2083,30 @@ void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) { /*ShouldTrap=*/true, /*OnFailure=*/nullptr); } +bool AArch64AsmPrinter::emitDeactivationSymbolRelocation(Value *DS) { + if (DS) { + if (isa(DS)) { + // Just emit the nop directly. + EmitToStreamer(MCInstBuilder(AArch64::HINT).addImm(0)); + return true; + } + MCSymbol *Dot = OutContext.createTempSymbol(); + OutStreamer->emitLabel(Dot); + const MCExpr *DeactDotExpr = MCSymbolRefExpr::create(Dot, OutContext); + + const MCExpr *DSExpr = MCSymbolRefExpr::create( + OutContext.getOrCreateSymbol(DS->getName()), OutContext); + OutStreamer->emitRelocDirective(*DeactDotExpr, "R_AARCH64_PATCHINST", + DSExpr, SMLoc(), *TM.getMCSubtargetInfo()); + } + return false; +} + void AArch64AsmPrinter::emitPtrauthAuthResign( Register AUTVal, AArch64PACKey::ID AUTKey, uint64_t AUTDisc, const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, uint64_t PACDisc, - Register PACAddrDisc) { + Register PACAddrDisc, Value *DS) { const bool IsAUTPAC = PACKey.has_value(); // We expand AUT/AUTPAC into a sequence of the form @@ -2129,15 +2153,17 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( bool AUTZero = AUTDiscReg == AArch64::XZR; unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero); - // autiza x16 ; if AUTZero - // autia x16, x17 ; if !AUTZero - MCInst AUTInst; - AUTInst.setOpcode(AUTOpc); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - if (!AUTZero) - AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); - EmitToStreamer(*OutStreamer, AUTInst); + if (!emitDeactivationSymbolRelocation(DS)) { + // autiza x16 ; if AUTZero + // autia x16, x17 ; if !AUTZero + MCInst AUTInst; + AUTInst.setOpcode(AUTOpc); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + if (!AUTZero) + AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); + EmitToStreamer(*OutStreamer, AUTInst); + } // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done. if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) @@ -2909,6 +2935,9 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { OutStreamer->emitLabel(LOHLabel); } + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + AArch64TargetStreamer *TS = static_cast(OutStreamer->getTargetStreamer()); // Do any manual lowerings. @@ -3019,17 +3048,18 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { } case AArch64::AUTx16x17: - emitPtrauthAuthResign(AArch64::X16, - (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), - AArch64::X17, std::nullopt, 0, 0); + emitPtrauthAuthResign( + AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, + std::nullopt, 0, 0, MI->getDeactivationSymbol()); return; case AArch64::AUTxMxN: emitPtrauthAuthResign(MI->getOperand(0).getReg(), (AArch64PACKey::ID)MI->getOperand(3).getImm(), MI->getOperand(4).getImm(), &MI->getOperand(5), - MI->getOperand(1).getReg(), std::nullopt, 0, 0); + MI->getOperand(1).getReg(), std::nullopt, 0, 0, + MI->getDeactivationSymbol()); return; case AArch64::AUTPAC: @@ -3037,7 +3067,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), MI->getOperand(5).getReg()); + MI->getOperand(4).getImm(), MI->getOperand(5).getReg(), + MI->getDeactivationSymbol()); return; case AArch64::LOADauthptrstatic: diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index eca7ca566cfc2..7db4367f8a8dc 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -1535,7 +1535,10 @@ void AArch64DAGToDAGISel::SelectPtrauthAuth(SDNode *N) { extractPtrauthBlendDiscriminators(AUTDisc, CurDAG); if (!Subtarget->isX16X17Safer()) { - SDValue Ops[] = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + std::vector Ops = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + // Copy deactivation symbol if present. + if (N->getNumOperands() > 4) + Ops.push_back(N->getOperand(4)); SDNode *AUT = CurDAG->getMachineNode(AArch64::AUTxMxN, DL, MVT::i64, MVT::i64, Ops); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 6dc1b550303bb..72912d1eddf91 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9474,6 +9474,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (InGlue.getNode()) Ops.push_back(InGlue); + if (CLI.DeactivationSymbol) + Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol)); + // If we're doing a tall call, use a TC_RETURN here rather than an // actual call instruction. if (IsTailCall) { diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index ba7cbccc0bcd6..3cc4380138b1a 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -2402,6 +2402,7 @@ class BImm pattern> let Inst{25-0} = addr; let DecoderMethod = "DecodeUnconditionalBranch"; + let supportsDeactivationSymbol = true; } class BranchImm pattern> @@ -2459,6 +2460,7 @@ class SignAuthOneData opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = Rn; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthZero opcode_prefix, bits<2> opcode, string asm, @@ -2472,6 +2474,7 @@ class SignAuthZero opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = 0b11111; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthTwoOperand opc, string asm, diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 010d0aaa46e7f..c44cbea010ec6 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -1421,6 +1421,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } else if (Info.CFIType) { MIB->setCFIType(MF, Info.CFIType->getZExtValue()); } + MIB->setDeactivationSymbol(MF, Info.DeactivationSymbol); MIB.add(Info.Callee); diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp index d3fba6a92357a..477f25e6c0688 100644 --- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp @@ -957,6 +957,13 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, } } const EmitNodeMatcherCommon *EN = cast(N); + bool SupportsDeactivationSymbol = + EN->getInstruction().TheDef->getValueAsBit( + "supportsDeactivationSymbol"); + if (SupportsDeactivationSymbol) { + OS << "OPC_CaptureDeactivationSymbol,\n"; + OS.indent(FullIndexWidth + Indent); + } bool IsEmitNode = isa(EN); OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo"); bool CompressVTs = EN->getNumVTs() < 3; @@ -1049,8 +1056,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, OS << '\n'; } - return 4 + !CompressVTs + !CompressNodeInfo + NumTypeBytes + - NumOperandBytes + NumCoveredBytes; + return 4 + SupportsDeactivationSymbol + !CompressVTs + !CompressNodeInfo + + NumTypeBytes + NumOperandBytes + NumCoveredBytes; } case Matcher::CompleteMatch: { const CompleteMatchMatcher *CM = cast(N);