Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fold umax(nuw_mul(x, C0), x + 1) into (x == 0 ? 1 : nuw_mul(x, C0)) #123468

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Ruhung
Copy link
Contributor

@Ruhung Ruhung commented Jan 18, 2025

This PR introduces the following transformations:

  • If C0 is not 0:
    umax(nuw_shl(x, C0), x + 1) -> x == 0 ? 1 : nuw_shl(x, C0)
  • If C0 is not 0 or 1:
    umax(nuw_mul(x, C0), x + 1) -> x == 0 ? 1 : nuw_mul(x, C0)

Fixes : #122388
Alive2 proof : https://alive2.llvm.org/ce/z/rkp_8U

@llvmbot
Copy link
Member

llvmbot commented Jan 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: None (Ruhung)

Changes

This PR introduces the following transformations:

  • If C0 is not 0:
    umax(nuw_shl(x, C0), x + 1) -> x == 0 ? 1 : nuw_shl(x, C0)
  • If C0 is not 0 or 1:
    umax(nuw_mul(x, C0), x + 1) -> x == 0 ? 1 : nuw_mul(x, C0)

Fixes : #122388
Alive2 proof : https://alive2.llvm.org/ce/z/rkp_8U


Full diff: https://github.com/llvm/llvm-project/pull/123468.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+31)
  • (added) llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll (+300)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c55c40c88bc845..0db7d8818fbd0b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1847,6 +1847,37 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
       }
     }
+    // If C0 is not 0:
+    //   umax(nuw_shl(x, C0), x + 1) -> x == 0 ? 1 : nuw_shl(x, C0)
+    // If C0 is not 0 or 1:
+    //   umax(nuw_mul(x, C0), x + 1) -> x == 0 ? 1 : nuw_mul(x, C0)
+    ConstantInt *C0;
+    bool isShl = false;
+    BinaryOperator *Op = nullptr;
+    auto matchShiftOrMul = [&](Value *I) {
+      if (match(I, m_OneUse(m_NUWShl(m_Value(X), m_ConstantInt(C0))))) {
+        isShl = true;
+        return true;
+      } else if (match(I, m_OneUse(m_NUWMul(m_Value(X), m_ConstantInt(C0)))) &&
+                 C0 && !C0->isOne()) {
+        isShl = false;
+        return true;
+      }
+      return false;
+    };
+    if (((matchShiftOrMul(I0) &&
+          match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) ||
+         (matchShiftOrMul(I1) &&
+          match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) &&
+        C0 && !C0->isZero()) {
+      Op = isShl ? BinaryOperator::CreateNUWShl(X, C0)
+                 : BinaryOperator::CreateNUWMul(X, C0);
+      Builder.Insert(Op);
+      Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(X->getType(), 0));
+      Value *NewSelect =
+          Builder.CreateSelect(Cmp, ConstantInt::get(X->getType(), 1), Op);
+      return replaceInstUsesWith(*II, NewSelect);
+    }
     // If both operands of unsigned min/max are sign-extended, it is still ok
     // to narrow the operation.
     [[fallthrough]];
diff --git a/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll b/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll
new file mode 100644
index 00000000000000..86ce5dc06ce031
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll
@@ -0,0 +1,300 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+; When C0 is neither 0 nor 1:
+;   umax(nuw_mul(x, C0), x + 1) is optimized to:
+;   x == 0 ? 1 : nuw_mul(x, C0)
+; When C0 is not 0:
+;   umax(nuw_shl(x, C0), x + 1) is optimized to:
+;   x == 0 ? 1 : nuw_shl(x, C0)
+
+; Positive Test Cases for `shl`
+
+define i64 @test_shl_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_shl_by_5(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_5(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 5
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 5
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+; Commuted Test Cases for `shl`
+
+define i64 @test_shl_umax_commuted(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_umax_commuted(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[SHL]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %x1, i64 %shl)
+  ret i64 %max
+}
+
+; Negative Test Cases for `shl`
+
+define i64 @test_shl_by_zero(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_by_zero(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 0
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_shl_add_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_add_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 2
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 2
+  %shl = shl nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_shl_without_nuw(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_without_nuw(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X]], 2
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+; Multi-use Test Cases for `shl`
+declare void @use(i64)
+
+define i64 @test_shl_multi_use_add(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_add(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    call void @use(i64 [[X1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 3
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  call void @use(i64 %x1)
+  %shl = shl nuw i64 %x, 3
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_shl_multi_use_shl(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_shl(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 [[X]], 2
+; CHECK-NEXT:    call void @use(i64 [[SHL]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 2
+  call void @use(i64 %shl)
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_shl_multi_use_max(i64 %x) {
+; CHECK-LABEL: define i64 @test_shl_multi_use_max(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT:    call void @use(i64 [[MAX]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %shl = shl nuw i64 %x, 3
+  %max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
+  call void @use(i64 %max)
+  ret i64 %max
+}
+
+; Positive Test Cases for `mul`
+
+define i64 @test_mul_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_by_5(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_5(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i64 [[X]], 5
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[MUL]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 5
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+; Commuted Test Cases for `mul`
+
+define i64 @test_mul_max_commuted(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_max_commuted(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP2]], i64 1, i64 [[MUL]]
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %x1, i64 %mul)
+  ret i64 %max
+}
+
+; Negative Test Cases for `mul`
+
+define i64 @test_mul_by_zero(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_zero(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    ret i64 [[X1]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 0
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_by_1(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_by_1(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 1
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_add_by_2(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_add_by_2(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 2
+  %mul = mul nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_without_nuw(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_without_nuw(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[MUL:%.*]] = shl i64 [[X]], 1
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+; Multi-use Test Cases for `mul`
+
+define i64 @test_mul_multi_use_add(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_add(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    call void @use(i64 [[X1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  call void @use(i64 %x1)
+  %mul = mul nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_multi_use_mul(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_mul(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[X1:%.*]] = add i64 [[X]], 1
+; CHECK-NEXT:    [[MUL:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    call void @use(i64 [[MUL]])
+; CHECK-NEXT:    [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 2
+  call void @use(i64 %mul)
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  ret i64 %max
+}
+
+define i64 @test_mul_multi_use_max(i64 %x) {
+; CHECK-LABEL: define i64 @test_mul_multi_use_max(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[X]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[X]], 0
+; CHECK-NEXT:    [[MAX:%.*]] = select i1 [[TMP3]], i64 1, i64 [[TMP2]]
+; CHECK-NEXT:    call void @use(i64 [[MAX]])
+; CHECK-NEXT:    ret i64 [[MAX]]
+;
+  %x1 = add i64 %x, 1
+  %mul = mul nuw i64 %x, 2
+  %max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
+  call void @use(i64 %max)
+  ret i64 %max
+}

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) &&
C0 && !C0->isZero()) {
Op = isShl ? BinaryOperator::CreateNUWShl(X, C0)
: BinaryOperator::CreateNUWMul(X, C0);
Copy link
Contributor

@nikic nikic Jan 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you reuse the original mul/shl? Then you also don't need the one-use restriction. This will also preserve the nsw flag, which is valid per https://alive2.llvm.org/ce/z/BfYZmT. (There should be a test for the nuw nsw combination.)

@nikic nikic requested a review from dtcxzyw January 18, 2025 16:57
match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) ||
(matchShiftOrMul(I1) &&
match(I0, m_OneUse(m_Add(m_Specific(X), m_One()))))) &&
C0 && !C0->isZero()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking C0 is isn't nullptr is redundant here and above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the check.

%mul = mul nuw i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
call void @use(i64 %max)
ret i64 %max
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mul tests need to have non-power of two multipliers, otherwise they are actually going through the shl pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Thanks!

const APInt *C0;
auto matchShiftOrMul = [&](Value *I) {
return ((match(I, m_OneUse(m_NUWShl(m_Value(X), m_APInt(C0))))) ||
(match(I, m_OneUse(m_NUWMul(m_Value(X), m_APInt(C0)))) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one-use restriction on mul/shl are unnecessary.

(((matchI0 = matchShiftOrMul(I0)) &&
match(I1, m_OneUse(m_Add(m_Specific(X), m_One())))) ||
(matchShiftOrMul(I1) &&
match(I0, m_OneUse(m_Add(m_Specific(X), m_One())))))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be cleaner if you structured the fold like moveNotAfterMinMax below (i.e. implement it in a lambda, call with commuted args).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[InstCombine] Missed fold: umax(x *nuw 2, x + 1) => x == 0 ? 1 : x *nuw 2
4 participants