[InstCombine] Narrow trunc(lshr) in more cases #139645
Conversation
We can narrow `trunc(lshr(i32)) to i8` to `trunc(lshr(i16)) to i8` even when the bits that we are shifting in are not zero, in the cases where the MSBs of the shifted value don't matter and end up being truncated away anyway. This kind of narrowing does not remove the trunc, but it can help the vectorizer generate better code in a smaller type.

Motivation: libyuv, functions like ARGBToUV444Row_C().

Proof: https://alive2.llvm.org/ce/z/9Ao2aJ
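For context, a C-level sketch of the kind of source pattern this targets (the function and names are hypothetical, not taken from libyuv): because of integer promotion, the add and the shift happen in 32-bit arithmetic even though only the low 8 bits of the shifted value are ever used.

```cpp
#include <cstdint>

// Hypothetical illustration only: a and b are promoted to 32-bit int, so the
// add and the shift are performed as i32 in the IR, yet only the low 8 bits
// of the shifted value survive the final cast. The patch lets InstCombine
// evaluate the shift in a narrower type (e.g. i16), which can help the
// vectorizer work on narrower vector elements.
uint8_t narrow_shift(uint16_t a, uint16_t b, unsigned shift /* known < 8 */) {
  return static_cast<uint8_t>((a + b) >> shift);
}
```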
@llvm/pr-subscribers-llvm-transforms

Author: Usman Nadeem (UsmanNadeem)

Changes: We can narrow `trunc(lshr(i32)) to i8` to `trunc(lshr(i16)) to i8` even when the bits being shifted in are not zero, as long as the affected MSBs end up truncated away. This kind of narrowing does not remove the trunc but can help the vectorizer generate better code in a smaller type.

Proof: https://alive2.llvm.org/ce/z/9Ao2aJ

Full diff: https://github.com/llvm/llvm-project/pull/139645.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d6c99366e6f00..b47a82a542bfb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -51,6 +51,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
+ if (Opc == Instruction::LShr || Opc == Instruction::AShr)
+ Res->setIsExact(I->isExact());
break;
}
case Instruction::Trunc:
@@ -319,13 +321,21 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// zero - use AmtKnownBits.getMaxValue().
uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
uint32_t BitWidth = Ty->getScalarSizeInBits();
- KnownBits AmtKnownBits =
- llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+ KnownBits AmtKnownBits = IC.computeKnownBits(I->getOperand(1), 0, CxtI);
+ APInt MaxShiftAmt = AmtKnownBits.getMaxValue();
APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
- if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
- IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ if (MaxShiftAmt.ult(BitWidth)) {
+ if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI))
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ // If the only user is a trunc then we can narrow the shift if any new
+ // MSBs are not going to be used.
+ if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
+ auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
+ if (MaxShiftAmt.ule(BitWidth - DemandedBits))
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ }
}
break;
}
diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 0f957e22ad17b..8485a01e3180a 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -5,6 +5,7 @@
; RUN: opt < %s -passes=instcombine -S -data-layout="E-p:64:64:64-p1:32:32:32-p2:64:64:64-p3:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128-n8:16:32:64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ALL,BE
; RUN: opt < %s -passes=instcombine -S -data-layout="e-p:64:64:64-p1:32:32:32-p2:64:64:64-p3:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128-n8:16:32:64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ALL,LE
+declare void @use_i8(i8)
declare void @use_i32(i32)
declare void @use_v2i32(<2 x i32>)
@@ -2041,6 +2042,101 @@ define <2 x i8> @trunc_lshr_zext_uses1(<2 x i8> %A) {
ret <2 x i8> %D
}
+define i8 @trunc_lshr_ext_halfWidth(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth(
+; ALL-NEXT: [[ADD:%.*]] = add i16 [[A:%.*]], [[B:%.*]]
+; ALL-NEXT: [[SHR:%.*]] = lshr i16 [[ADD]], [[SHIFTAMT:%.*]]
+; ALL-NEXT: [[TRUNC:%.*]] = trunc i16 [[SHR]] to i8
+; ALL-NEXT: ret i8 [[TRUNC]]
+;
+ %zext_a = zext i16 %a to i32
+ %zext_b = zext i16 %b to i32
+ %zext_shiftAmt = zext i16 %shiftAmt to i32
+ %add = add nuw nsw i32 %zext_a, %zext_b
+ %shr = lshr i32 %add, %zext_shiftAmt
+ %trunc = trunc i32 %shr to i8
+ ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_rhsRange_neg(i16 %a, i16 %b, i16 %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_rhsRange_neg(
+; ALL-NEXT: [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT: [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT: [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT: [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT: ret i8 [[TRUNC]]
+;
+ %zext_a = zext i16 %a to i32
+ %zext_b = zext i16 %b to i32
+ %zext_shiftAmt = zext i16 %shiftAmt to i32
+ %add = add nuw nsw i32 %zext_a, %zext_b
+ %shr = lshr i32 %add, %zext_shiftAmt
+ %trunc = trunc i32 %shr to i8
+ ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_twouse_neg1(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_twouse_neg1(
+; ALL-NEXT: [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT: [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT: [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT: call void @use_i32(i32 [[ADD]])
+; ALL-NEXT: [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT: ret i8 [[TRUNC]]
+;
+ %zext_a = zext i16 %a to i32
+ %zext_b = zext i16 %b to i32
+ %zext_shiftAmt = zext i16 %shiftAmt to i32
+ %add = add nuw nsw i32 %zext_a, %zext_b
+ call void @use_i32(i32 %add)
+ %shr = lshr i32 %add, %zext_shiftAmt
+ %trunc = trunc i32 %shr to i8
+ ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_twouse_neg2(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_twouse_neg2(
+; ALL-NEXT: [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT: [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT: [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT: [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT: call void @use_i32(i32 [[SHR]])
+; ALL-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT: ret i8 [[TRUNC]]
+;
+ %zext_a = zext i16 %a to i32
+ %zext_b = zext i16 %b to i32
+ %zext_shiftAmt = zext i16 %shiftAmt to i32
+ %add = add nuw nsw i32 %zext_a, %zext_b
+ %shr = lshr i32 %add, %zext_shiftAmt
+ call void @use_i32(i32 %shr)
+ %trunc = trunc i32 %shr to i8
+ ret i8 %trunc
+}
+
+; The narrowing transform only happens for integer types.
+define <2 x i8> @trunc_lshr_ext_halfWidth_vector_neg(<2 x i16> %a, <2 x i16> %b) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_vector_neg(
+; ALL-NEXT: [[ZEXT_A:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
+; ALL-NEXT: [[ZEXT_B:%.*]] = zext <2 x i16> [[B:%.*]] to <2 x i32>
+; ALL-NEXT: [[ADD:%.*]] = add nuw nsw <2 x i32> [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT: [[SHR:%.*]] = lshr <2 x i32> [[ADD]], splat (i32 6)
+; ALL-NEXT: [[TRUNC:%.*]] = trunc <2 x i32> [[SHR]] to <2 x i8>
+; ALL-NEXT: ret <2 x i8> [[TRUNC]]
+;
+ %zext_a = zext <2 x i16> %a to <2 x i32>
+ %zext_b = zext <2 x i16> %b to <2 x i32>
+ %add = add nuw nsw <2 x i32> %zext_a, %zext_b
+ %shr = lshr <2 x i32> %add, <i32 6, i32 6>
+ %trunc = trunc <2 x i32> %shr to <2 x i8>
+ ret <2 x i8> %trunc
+}
+
; The following four tests sext + lshr + trunc patterns.
; PR33078
@@ -51,6 +51,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
if (Opc == Instruction::LShr || Opc == Instruction::AShr)
  Res->setIsExact(I->isExact());
This change looks unrelated. Is it necessary to avoid regression?
Yes, some other tests regress without it.
// MSBs are not going to be used.
if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
  auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
  if (MaxShiftAmt.ule(BitWidth - DemandedBits))
`BitWidth - DemandedBits` may wrap. Both `DemandedBits` and `BitWidth` are only guaranteed to be less than `OrigBitWidth`, but we don't know which one is larger.
Rewrote it to `if ((MaxShiftAmt + DemandedBits).ule(BitWidth))`.
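To make the concern and the fix concrete, here is a small standalone sketch with made-up widths; the values are illustrative only and are not taken from the patch or its tests.

```cpp
#include <cassert>
#include <cstdint>
#include "llvm/ADT/APInt.h"
using namespace llvm;

int main() {
  uint32_t BitWidth = 8;      // hypothetical width we would narrow the shift to
  unsigned DemandedBits = 16; // hypothetical width of the trunc destination
  APInt MaxShiftAmt(32, 2);   // hypothetical upper bound on the shift amount

  // Original guard: BitWidth - DemandedBits is plain unsigned arithmetic, so
  // when DemandedBits > BitWidth it wraps to a huge value and the check
  // passes for essentially any shift amount.
  assert(MaxShiftAmt.ule(BitWidth - DemandedBits)); // 2 <= 4294967288

  // Reworked guard: the sum is formed before the comparison, so nothing wraps
  // here and the narrowing is correctly rejected for these widths.
  assert(!(MaxShiftAmt + DemandedBits).ule(BitWidth)); // 18 <= 8 is false
  return 0;
}
```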
LGTM. Thank you!

If we want to further generalize this optimization, we may need to add a parameter like `DemandedLowBits`.
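A rough sketch of what that generalization might check; this is a hypothetical helper under assumed semantics, not LLVM API and not part of this patch.

```cpp
#include "llvm/ADT/APInt.h"
using namespace llvm;

// Hypothetical helper: the legality check the patch adds for lshr/ashr, but
// parameterized on how many low bits the caller actually demands instead of
// inferring that from a single trunc user.
static bool shiftNarrowingKeepsDemandedBits(const APInt &MaxShiftAmt,
                                            uint32_t BitWidth,
                                            unsigned DemandedLowBits) {
  // Unknown shifted-in MSBs are harmless as long as they all land strictly
  // above the demanded low bits of the narrowed value.
  return (MaxShiftAmt + DemandedLowBits).ule(BitWidth);
}
```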