[InstCombine] Implement processSMulSExtIdiom for Signed Multiplication Overflow Detection #131461

Open
wants to merge 2 commits into main

Conversation

@DTeachs DTeachs commented Mar 15, 2025

Alive2 link:
https://alive2.llvm.org/ce/z/aX4UoK

Summary

This pull request adds a processSMulSExtIdiom function to the LLVM InstCombine pass. The function recognizes idioms that test for signed multiplication overflow and, when it finds one in a comparison instruction, replaces the multiplication with a call to the llvm.smul.with.overflow intrinsic.

Features

  • New Functionality: The processSMulSExtIdiom function is designed to handle cases where a signed multiplication operation is followed by an addition and a comparison, allowing for more efficient overflow detection.
  • Pattern Recognition: The implementation identifies comparisons of the form I = icmp u (add (mul(sext A, sext B), V)), W and replaces them with the appropriate intrinsic call (see the IR sketch after this list).
  • Type Handling: The function calculates the necessary types and widths for the multiplication operation, ensuring that the transformation is valid and does not break existing code.
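
For concreteness, here is the i8 instance of the idiom, based on the @smul_sext_add_pattern test added below; the comments are annotations for this description, not part of the patch:

  ; Before: the product of two sign-extended i8 values, shifted into the
  ; unsigned range and range-checked.
  %a.ext = sext i8 %a to i32
  %b.ext = sext i8 %b to i32
  %mul = mul nsw i32 %a.ext, %b.ext
  %add = add i32 %mul, 128             ; 128 = signedMax(i8) + 1
  %cmp = icmp ult i32 %add, 256        ; 256 = unsignedMax(i8) + 1
  ; %cmp is true exactly when %mul fits in i8 (no overflow).

  ; After: InstCombine emits the intrinsic and inverts the overflow bit.
  %smul = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
  %ov = extractvalue { i8, i1 } %smul, 1
  %cmp = xor i1 %ov, true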

Added Tests

  • Extensive tests have been added to verify the correctness of the new functionality, including various scenarios for signed multiplication overflow checks.

Implementation Notes

The processSMulSExtIdiom function was derived from the existing processUMulZExtIdiom function, with modifications made to accommodate the specifics of signed multiplication overflow detection.
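
As a worked check of the constants in the ICMP_ULT case, take the i64/i32 instance from the @smul test below (comments are annotations for this description, not part of the patch):

  %mul = mul nsw i64 %conv1, %conv       ; product of two sign-extended i32 values
  %1 = add nsw i64 %mul, -2147483648     ; -2147483648 = -2^31 = min(i32)
  %2 = icmp ult i64 %1, -4294967296      ; -4294967296 = -2^32 = 2 * min(i32)
  ; If %mul fits in i32, it lies in [-2^31, 2^31 - 1], so %1 lies in
  ; [-2^32, -1], i.e. an unsigned value of at least 2^64 - 2^32, and the ult
  ; is false. Any product outside that range makes the ult true, so %2 is
  ; exactly the overflow bit and the whole check folds to
  ; extractvalue { i32, i1 } %smul, 1.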

@llvmbot
Member

llvmbot commented Mar 15, 2025

@llvm/pr-subscribers-llvm-transforms

Author: None (DTeachs)


Patch is 23.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131461.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+179)
  • (modified) llvm/test/Transforms/InstCombine/overflow-mul.ll (+390)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 76020d2b1dbf4..cfbfd6c969a8c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6526,6 +6526,172 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   return ExtractValueInst::Create(Call, 1);
 }
 
+/// Recognize and process idiom involving test for signed multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+///   I = icmp u (add (mul(sext A, sext B), AddVal)), OtherVal
+/// The function checks if this is a test for overflow and if so replaces
+/// multiplication with call to 'smul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of 'mul' instruction.  It is one of the arguments of
+///               the compare instruction.  Must be of integer type.
+/// \param AddVal Constant added to the product before the compare.
+/// \param OtherVal Constant the adjusted product is compared against.
+/// \returns Instruction which must replace the compare instruction, NULL if no
+///          replacement required.
+static Instruction *processSMulSExtIdiom(ICmpInst &I, Value *MulVal,
+                                         const APInt *AddVal,
+                                         const APInt *OtherVal,
+                                         InstCombinerImpl &IC) {
+  // Don't bother doing this transformation for pointers, don't do it for
+  // vectors.
+  if (!isa<IntegerType>(MulVal->getType()))
+    return nullptr;
+
+  auto *MulInstr = dyn_cast<Instruction>(MulVal);
+  if (!MulInstr)
+    return nullptr;
+  assert(MulInstr->getOpcode() == Instruction::Mul);
+
+  auto *LHS = cast<SExtInst>(MulInstr->getOperand(0)),
+       *RHS = cast<SExtInst>(MulInstr->getOperand(1));
+  assert(LHS->getOpcode() == Instruction::SExt);
+  assert(RHS->getOpcode() == Instruction::SExt);
+  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+  // Calculate type and width of the result produced by mul.with.overflow.
+  Type *TyA = A->getType(), *TyB = B->getType();
+  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+           WidthB = TyB->getPrimitiveSizeInBits();
+  unsigned MulWidth;
+  Type *MulType;
+  if (WidthB > WidthA) {
+    MulWidth = WidthB;
+    MulType = TyB;
+  } else {
+    MulWidth = WidthA;
+    MulType = TyA;
+  }
+
+  // In order to replace the original mul with a narrower mul.with.overflow,
+  // all uses must ignore upper bits of the product.  The number of used low
+  // bits must be not greater than the width of mul.with.overflow.
+  if (MulVal->hasNUsesOrMore(2))
+    for (User *U : MulVal->users()) {
+      if (U == &I)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        // Check if truncation ignores bits above MulWidth.
+        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+        if (TruncWidth > MulWidth)
+          return nullptr;
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        // Check if AND ignores bits above MulWidth.
+        if (BO->getOpcode() != Instruction::And)
+          return nullptr;
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+          const APInt &CVal = CI->getValue();
+          if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+            return nullptr;
+        } else {
+          // In this case we could have the operand of the binary operation
+          // being defined in another block, and performing the replacement
+          // could break the dominance relation.
+          return nullptr;
+        }
+      } else {
+        // Other uses prohibit this transformation.
+        return nullptr;
+      }
+    }
+
+  // Recognize patterns
+  bool IsInverse = false;
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_ULT: {
+    // Recognize pattern:
+    //   mulval = mul(sext A, sext B)
+    //   addval = add (mulval, min)
+    //   cmp ult addval, min * 2
+    APInt MinVal = APInt::getSignedMinValue(MulWidth);
+    MinVal = MinVal.sext(OtherVal->getBitWidth());
+    APInt MinMinVal = APInt::getSignedMinValue(MulWidth + 1);
+    MinMinVal = MinMinVal.sext(OtherVal->getBitWidth());
+    if (MinVal.eq(*AddVal) && MinMinVal.eq(*OtherVal))
+      break; // Recognized
+
+    // Recognize pattern:
+    //   mulval = mul(sext A, sext B)
+    //   addval = add (mulval, signedMax + 1)
+    //   cmp ult addval, unsignedMax + 1
+    APInt MaxVal = APInt::getSignedMaxValue(MulWidth);
+    MaxVal = MaxVal.zext(OtherVal->getBitWidth()) + 1;
+    APInt MaxMaxVal = APInt::getMaxValue(MulWidth);
+    MaxMaxVal = MaxMaxVal.zext(OtherVal->getBitWidth()) + 1;
+    if (MaxVal.eq(*AddVal) && MaxMaxVal.eq(*OtherVal)) {
+      IsInverse = true;
+      break; // Recognized
+    }
+    return nullptr;
+  }
+
+  default:
+    return nullptr;
+  }
+
+  InstCombiner::BuilderTy &Builder = IC.Builder;
+  Builder.SetInsertPoint(MulInstr);
+
+  // Replace: mul(sext A, sext B) --> mul.with.overflow(A, B)
+  Value *MulA = A, *MulB = B;
+  if (WidthA < MulWidth)
+    MulA = Builder.CreateSExt(A, MulType);
+  if (WidthB < MulWidth)
+    MulB = Builder.CreateSExt(B, MulType);
+  Function *F = Intrinsic::getOrInsertDeclaration(
+      I.getModule(), Intrinsic::smul_with_overflow, MulType);
+  CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "smul");
+  IC.addToWorklist(MulInstr);
+
+  // If there are uses of mul result other than the comparison, we know that
+  // they are truncation or binary AND. Change them to use result of
+  // mul.with.overflow and adjust properly mask/size.
+  if (MulVal->hasNUsesOrMore(2)) {
+    Value *Mul = Builder.CreateExtractValue(Call, 0, "smul.value");
+    for (User *U : make_early_inc_range(MulVal->users())) {
+      if (U == &I)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+          IC.replaceInstUsesWith(*TI, Mul);
+        else
+          TI->setOperand(0, Mul);
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        assert(BO->getOpcode() == Instruction::And);
+        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
+        APInt ShortMask = CI->getValue().trunc(MulWidth);
+        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+        Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
+        IC.replaceInstUsesWith(*BO, Zext);
+      } else {
+        llvm_unreachable("Unexpected Binary operation");
+      }
+      IC.addToWorklist(cast<Instruction>(U));
+    }
+  }
+
+  // The original icmp gets replaced with the overflow value, maybe inverted
+  // depending on predicate.
+  if (IsInverse) {
+    Value *Res = Builder.CreateExtractValue(Call, 1);
+    return BinaryOperator::CreateNot(Res);
+  }
+
+  return ExtractValueInst::Create(Call, 1);
+}
+
 /// When performing a comparison against a constant, it is possible that not all
 /// the bits in the LHS are demanded. This helper method computes the mask that
 /// IS demanded.
@@ -7651,6 +7817,19 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
         return R;
     }
 
+    // (add ((sext X) * (sext Y)), C) pred C1  -->  llvm.smul.with.overflow
+    // Detects patterns for signed multiplication overflow checks where the
+    // product is adjusted by a constant and then compared against another
+    // constant.
+    const APInt *C1;
+    if (match(Op0, m_Add(m_NSWMul(m_SExt(m_Value(X)), m_SExt(m_Value(Y))),
+                         m_APInt(C))) &&
+        match(Op1, m_APInt(C1))) {
+      if (Instruction *R = processSMulSExtIdiom(
+              I, cast<Instruction>(Op0)->getOperand(0), C, C1, *this))
+        return R;
+    }
+
     // Signbit test folds
     // Fold (X u>> BitWidth - 1 Pred ZExt(i1))  -->  X s< 0 Pred i1
     // Fold (X s>> BitWidth - 1 Pred SExt(i1))  -->  X s< 0 Pred i1
diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 1d18d9ffd46d2..d1a5813ffba8a 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -12,6 +12,7 @@ target datalayout = "i32:8:8"
 ; The mask is no longer in the form 2^n-1  and this prevents the transformation.
 
 declare void @use.i64(i64)
+declare void @use.i32(i32)
 
 ; return mul(zext x, zext y) > MAX
 define i32 @pr4917_1(i32 %x, i32 %y) nounwind {
@@ -343,3 +344,392 @@ define i32 @extra_and_use_mask_too_large(i32 %x, i32 %y) {
   %retval = zext i1 %overflow to i32
   ret i32 %retval
 }
+
+define i32 @smul(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul(
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[CONV3]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %conv3 = zext i1 %2 to i32
+  ret i32 %conv3
+}
+
+define i32 @smul2(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul2(
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 [[B:%.*]], i32 [[A:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i32, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], true
+; CHECK-NEXT:    [[CONV3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[CONV3]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %cmp = icmp sle i64 %mul, 2147483647
+  %cmp2 = icmp sge i64 %mul, -2147483648
+  %1 = select i1 %cmp, i1 %cmp2, i1 false
+  %conv3 = zext i1 %1 to i32
+  ret i32 %conv3
+}
+
+define i1 @smul_sext_add_pattern(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_pattern(
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i8, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = xor i1 [[TMP1]], true
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_wrong_constants(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_wrong_constants(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[MUL]], 58
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 42
+  %cmp = icmp slt i32 %add, 100
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_eq_predicate(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_eq_predicate(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[MUL]], 128
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 128
+  %cmp = icmp eq i32 %add, 256
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_different_widths(i4 %a, i16 %b) {
+; CHECK-LABEL: @smul_sext_add_different_widths(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i4 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i16 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], 128
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD]], 256
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i4 %a to i32
+  %b.ext = sext i16 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_no_nsw(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_no_nsw(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], 128
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD]], 256
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul i32 %a.ext, %b.ext  ; No nsw flag
+  %add = add i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  ret i1 %cmp
+}
+
+define <2 x i1> @smul_sext_add_vector(<2 x i8> %a, <2 x i8> %b) {
+; CHECK-LABEL: @smul_sext_add_vector(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext <2 x i8> [[B:%.*]] to <2 x i32>
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw <2 x i32> [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw <2 x i32> [[MUL]], splat (i32 128)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[ADD]], splat (i32 256)
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %a.ext = sext <2 x i8> %a to <2 x i32>
+  %b.ext = sext <2 x i8> %b to <2 x i32>
+  %mul = mul nsw <2 x i32> %a.ext, %b.ext
+  %add = add <2 x i32> %mul, <i32 128, i32 128>
+  %cmp = icmp ult <2 x i32> %add, <i32 256, i32 256>
+  ret <2 x i1> %cmp
+}
+
+define i1 @smul_sext_add_negative2(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_negative2(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[MUL]], 128
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %cmp = icmp ult i32 %mul, 128
+  %add = add i32 %mul, 128
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_multiple_uses(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_multiple_uses(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], 128
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[ADD]], 256
+; CHECK-NEXT:    call void @use.i32(i32 [[MUL]])
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  call void @use.i32(i32 %mul)
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_extreme_constants(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_extreme_constants(
+; CHECK-NEXT:    ret i1 false
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add i32 %mul, 2147483647  ; INT_MAX
+  %cmp = icmp slt i32 %add, -2147483648  ; INT_MIN
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_nsw(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_nsw(
+; CHECK-NEXT:    [[SMUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = extractvalue { i8, i1 } [[SMUL]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = xor i1 [[TMP1]], true
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add nsw i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  ret i1 %cmp
+}
+
+define i1 @smul_sext_add_nuw_negative(i8 %a, i8 %b) {
+; CHECK-LABEL: @smul_sext_add_nuw_negative(
+; CHECK-NEXT:    [[A_EXT:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[B_EXT:%.*]] = sext i8 [[B:%.*]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[A_EXT]], [[B_EXT]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[MUL]], 128
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a.ext = sext i8 %a to i32
+  %b.ext = sext i8 %b to i32
+  %mul = mul nsw i32 %a.ext, %b.ext
+  %add = add nuw i32 %mul, 128
+  %cmp = icmp ult i32 %add, 256
+  ret i1 %cmp
+}
+
+define i32 @smul_extra_and_use(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul_extra_and_use(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT:    [[AND:%.*]] = and i64 [[MUL]], 4294967295
+; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
+; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[RETVAL]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %and = and i64 %mul, 4294967295
+  call void @use.i64(i64 %and)
+  %retval = zext i1 %2 to i32
+  ret i32 %retval
+}
+
+define i32 @smul_extra_trunc_use(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul_extra_trunc_use(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i64 [[MUL]] to i32
+; CHECK-NEXT:    call void @use.i32(i32 [[TRUNC]])
+; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[RETVAL]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %trunc = trunc i64 %mul to i32
+  call void @use.i32(i32 %trunc)
+  %retval = zext i1 %2 to i32
+  ret i32 %retval
+}
+
+define i32 @smul_extra_and_use_small_mask(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul_extra_and_use_small_mask(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT:    [[AND:%.*]] = and i64 [[MUL]], 268435455
+; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
+; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[RETVAL]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %and = and i64 %mul, u0xfffffff
+  call void @use.i64(i64 %and)
+  %retval = zext i1 %2 to i32
+  ret i32 %retval
+}
+
+define i32 @smul_multiple_uses(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul_multiple_uses(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], -4294967296
+; CHECK-NEXT:    [[AND:%.*]] = and i64 [[MUL]], 4294967295
+; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i64 [[MUL]] to i32
+; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TRUNC]])
+; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[RETVAL]]
+;
+  %conv = sext i32 %a to i64
+  %conv1 = sext i32 %b to i64
+  %mul = mul nsw i64 %conv1, %conv
+  %1 = add nsw i64 %mul, -2147483648
+  %2 = icmp ult i64 %1, -4294967296
+  %and = and i64 %mul, 4294967295
+  %trunc = trunc i64 %mul to i32
+  call void @use.i64(i64 %and)
+  call void @use.i32(i32 %trunc)
+  %retval = zext i1 %2 to i32
+  ret i32 %retval
+}
+
+define i32 @smul_extra_and_use_mask_too_large(i32 %a, i32 %b) {
+; CHECK-LABEL: @smul_extra_and_use_mask_too_large(
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[MUL]], -2147483648
+; CHECK-NEXT:    [[TMP2:%....
[truncated]


⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@DTeachs DTeachs force-pushed the overflow-mul branch 6 times, most recently from e86eaeb to 6d37890 on March 15, 2025 16:30
Tiaunna added 2 commits March 15, 2025 14:00
…n Overflow Detection

Alive2 Proof:
https://alive2.llvm.org/ce/z/aX4UoK

@DTeachs
Author

DTeachs commented Mar 16, 2025

@dtcxzyw Ping?

@topperc
Collaborator

topperc commented Mar 17, 2025

To add to the bigger picture here. Historically the backends had problems supporting smul.overflow. SelectionDAG could only handle smul.overflow on types that fit in a legal register size. Or maybe twice the size of a legal register. There were also some cases that would call a library function that only exists in compiler-rt and not libgcc.

umul.overflow was better supported in the backend.

I don't think any of the broken cases were reachable from software using builtins like __builtin_smul_overflow, and the middle end would never create smul.overflow.

35fa7b8 from 3.5 years ago started creating smul.overflow in the middle end. That spawned these patches to prevent calls to non-existent library functions: 124bcc1, c1a31ee, 39e5dd1, d0eeb64, e9b3f25, d8b6ae0, c8c176d, 5c91b98. I created b2ca4dc to avoid crashes on large types, but it generates poor code. Earlier this year I created e30a4fc to improve the generated code for larger types.

So the situation in the backend has improved a lot in the last 3.5 years, but it is possible this patch causes someone to find a new backend issue.

@DTeachs
Author

DTeachs commented May 14, 2025

@topperc Well, if that is the case, I can bear responsibility for it.

@DTeachs DTeachs requested a review from topperc May 14, 2025 22:37