Skip to content

Commit 44c2032

Browse files
authored
Atomicrmw undef fix (#1900)
* Atomicrmw undef fix
* fixup
* earlyinc
* fix diffuse custreamsync
* Fix cusync
1 parent 374031e commit 44c2032

File tree

4 files changed

+60
-6
lines changed

4 files changed

+60
-6
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
344344
Constant::getNullValue(gutils->getShadowType(inst.getType())),
345345
Builder2);
346346
}
347-
if (!inst.getType()->isVoidTy())
348-
gutils->getNewFromOriginal(&inst)->replaceAllUsesWith(
349-
UndefValue::get(inst.getType()));
347+
#if LLVM_VERSION_MAJOR >= 12
348+
if (!inst.getType()->isVoidTy()) {
349+
for (auto &U :
350+
make_early_inc_range(gutils->getNewFromOriginal(&inst)->uses())) {
351+
U.set(UndefValue::get(inst.getType()));
352+
}
353+
}
354+
#endif
350355
eraseIfUnused(inst, /*erase*/ true, /*check*/ false);
351356
return;
352357
}
@@ -920,8 +925,14 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
920925
setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())),
921926
BuilderZ);
922927
}
923-
gutils->replaceAWithB(gutils->getNewFromOriginal(&I),
924-
UndefValue::get(I.getType()));
928+
#if LLVM_VERSION_MAJOR >= 12
929+
if (!I.getType()->isVoidTy()) {
930+
for (auto &U :
931+
make_early_inc_range(gutils->getNewFromOriginal(&I)->uses())) {
932+
U.set(UndefValue::get(I.getType()));
933+
}
934+
}
935+
#endif
925936
eraseIfUnused(I, /*erase*/ true, /*check*/ false);
926937
return;
927938
}

enzyme/Enzyme/DifferentialUseAnalysis.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,15 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
525525

526526
if (!shadow) {
527527

528+
// Need the primal stream (argument 0 of cuStreamSynchronize) in reverse.
529+
if (funcName == "cuStreamSynchronize")
530+
if (val == CI->getArgOperand(0)) {
531+
if (EnzymePrintDiffUse)
532+
llvm::errs() << " Need: primal(" << to_string(qtype) << ") of "
533+
<< *val << " in reverse for cuda sync " << *CI << "\n";
534+
return true;
535+
}
536+
528537
// Only need the primal request.
529538
if (funcName == "MPI_Wait" || funcName == "PMPI_Wait")
530539
if (val != CI->getArgOperand(0))

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8978,7 +8978,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) {
89788978
// Check that the replacement doesn't already exist in the mapping
89798979
// thereby resulting in a conflict.
89808980
#ifndef NDEBUG
8981-
{
8981+
if (!isa<UndefValue>(B)) {
89828982
auto found = newToOriginalFn.find(A);
89838983
if (found != newToOriginalFn.end()) {
89848984
auto foundB = newToOriginalFn.find(B);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -ge 12 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi
2+
; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s; fi
3+
4+
declare void @__enzyme_reverse(...)
5+
6+
declare i64 @foo() "enzyme_inactive" "enzyme_nofree" "enzyme_no_escaping_allocation"
7+
8+
declare void @cuStreamSynchronize(i64)
9+
10+
define void @square() {
11+
entry:
12+
%z = call i64 @foo()
13+
call void @cuStreamSynchronize(i64 %z)
14+
ret void
15+
}
16+
17+
define void @dsquare() {
18+
entry:
19+
tail call void (...) @__enzyme_reverse(void ()* nonnull @square, i8* null)
20+
ret void
21+
}
22+
23+
; CHECK: define internal void @diffesquare(i8* %tapeArg)
24+
; CHECK-NEXT: entry:
25+
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i64*
26+
; CHECK-NEXT: %z = load i64, i64* %0
27+
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
28+
; CHECK-NEXT: br label %invertentry
29+
30+
; CHECK: invertentry: ; preds = %entry
31+
; CHECK-NEXT: call void @cuStreamSynchronize(i64 %z)
32+
; CHECK-NEXT: ret void
33+
; CHECK-NEXT: }
34+

0 commit comments

Comments
 (0)