Skip to content

Commit 41f5895

Browse files
authored
Fix incorrect replacement of switch default branch if all other branches are unreachable in reverse split mode (#1381)
1 parent 70ee9d9 commit 41f5895

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,11 @@ void restoreCache(
18931893
while (cases.count(legalNot))
18941894
legalNot++;
18951895
repVal = ConstantInt::getSigned(condition->getType(), legalNot);
1896+
cast<SwitchInst>(gutils->getNewFromOriginal(si))
1897+
->setCondition(repVal);
1898+
// knowing which input was provided for the default dest is not
1899+
// possible at compile time, give up on other use replacement
1900+
continue;
18961901
} else {
18971902
for (auto c : si->cases()) {
18981903
if (c.getCaseSuccessor() == reachables[0]) {
@@ -4127,6 +4132,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
41274132
gutils->erase(newBB->getTerminator());
41284133
IRBuilder<> builder(newBB);
41294134
builder.CreateUnreachable();
4135+
41304136
continue;
41314137
}
41324138

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=0 -enzyme -mem2reg -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=0 -passes="enzyme,function(mem2reg)" -S | FileCheck %s
3+
4+
declare double @__enzyme_autodiff(i8*, ...)
5+
6+
define double @sfunc(double %x, i32 %val) {
7+
entry:
8+
switch i32 %val, label %b1 [
9+
i32 0, label %err
10+
i32 1, label %err
11+
]
12+
13+
err:
14+
unreachable
15+
16+
b1:
17+
%q = sitofp i32 %val to double
18+
%m = fmul double %q, %x
19+
ret double %m
20+
}
21+
22+
define double @outer(double %x, i32 %val) {
23+
%v = call double @sfunc(double %x, i32 %val)
24+
%m = fmul double %v, %v
25+
ret double %m
26+
}
27+
28+
define void @main(double %q, i32 %val) {
29+
entry:
30+
%call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double, i32)* @outer to i8*), double %q, i32 %val)
31+
ret void
32+
}
33+
34+
; CHECK: define internal { double } @diffesfunc(double %x, i32 %val, double %differeturn)
35+
; CHECK-NEXT: entry:
36+
; CHECK-NEXT: switch i32 2, label %b1 [
37+
; CHECK-NEXT: i32 0, label %err
38+
; CHECK-NEXT: i32 1, label %err
39+
; CHECK-NEXT: ]
40+
41+
; CHECK: err: ; preds = %entry, %entry
42+
; CHECK-NEXT: unreachable
43+
44+
; CHECK: b1: ; preds = %entry
45+
; CHECK-NEXT: %q = sitofp i32 %val to double
46+
; CHECK-NEXT: br label %invertb1
47+
48+
; CHECK: invertentry: ; preds = %invertb1
49+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %2, 0
50+
; CHECK-NEXT: ret { double } %0
51+
52+
; CHECK: invertb1: ; preds = %b1
53+
; CHECK-NEXT: %1 = fmul fast double %differeturn, %q
54+
; CHECK-NEXT: %2 = fadd fast double 0.000000e+00, %1
55+
; CHECK-NEXT: br label %invertentry
56+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)