@@ -7193,15 +7193,19 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
 }
 
 // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
+static bool findEltLoadSrc(SDValue Elt, MemSDNode *&Ld, int64_t &ByteOffset) {
+  if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
     Ld = BaseLd;
     ByteOffset = 0;
     return true;
-  }
+  } else if (auto *BaseLd = dyn_cast<LoadSDNode>(Elt))
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
 
   switch (Elt.getOpcode()) {
   case ISD::BITCAST:
@@ -7254,7 +7258,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);
 
-  SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
+  SmallVector<MemSDNode *, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
 
   // For each element in the initializer, see if we've found a load, zero or an
@@ -7304,7 +7308,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  MemSDNode *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7318,15 +7322,18 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
 
   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](MemSDNode *Base, int EltIdx) {
+    MemSDNode *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
       return (0 <= BaseIdx && BaseIdx < (int)NumElems && LoadMask[BaseIdx] &&
              Loads[BaseIdx] == Ld && ByteOffsets[BaseIdx] == 0);
     }
-    return DAG.areNonVolatileConsecutiveLoads(Ld, Base, BaseSizeInBytes,
+    auto *L = dyn_cast<LoadSDNode>(Ld);
+    auto *B = dyn_cast<LoadSDNode>(Base);
+    return L && B &&
+           DAG.areNonVolatileConsecutiveLoads(L, B, BaseSizeInBytes,
                                               EltIdx - FirstLoadedElt);
   };
 
@@ -7347,7 +7354,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }
 
-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, MemSDNode *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -60539,6 +60546,35 @@ static SDValue combineINTRINSIC_VOID(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue combineVZEXT_LOAD(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Find the TokenFactor to locate the associated AtomicLoad.
+  SDNode *ALD = nullptr;
+  for (auto &TF : N->uses())
+    if (TF.getUser()->getOpcode() == ISD::TokenFactor) {
+      SDValue L = TF.getUser()->getOperand(0);
+      SDValue R = TF.getUser()->getOperand(1);
+      if (L.getNode() == N)
+        ALD = R.getNode();
+      else if (R.getNode() == N)
+        ALD = L.getNode();
+    }
+
+  if (!ALD)
+    return SDValue();
+  if (!isa<AtomicSDNode>(ALD))
+    return SDValue();
+
+  // Replace the VZEXT_LOAD with the AtomicLoad.
+  SDLoc dl(N);
+  SDValue SV =
+      DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
+                  N->getValueType(0).changeTypeToInteger(), SDValue(ALD, 0));
+  SDValue BC = DAG.getNode(ISD::BITCAST, dl, N->getValueType(0), SV);
+  BC = DCI.CombineTo(N, BC, SDValue(ALD, 1));
+  return BC;
+}
+
 SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -60735,6 +60771,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
+  case X86ISD::VZEXT_LOAD: return combineVZEXT_LOAD(N, DAG, DCI);
   // clang-format on
   }