Skip to content

Commit cc89bab

Browse files
committed
Fix custom allocation handler
1 parent bb12649 commit cc89bab

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

enzyme/Enzyme/GradientUtils.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -689,14 +689,18 @@ class GradientUtils : public CacheUtility {
689689
if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) !=
690690
shadowHandlers.end()) {
691691
bb.SetInsertPoint(placeholder);
692-
Value *anti = shadowHandlers[orig->getCalledFunction()->getName().str()](
693-
bb, orig, args);
692+
Value *anti = placeholder;
694693

695-
invertedPointers.erase(found);
696-
bb.SetInsertPoint(placeholder);
694+
if (mode != DerivativeMode::ReverseModeGradient) {
695+
anti = shadowHandlers[orig->getCalledFunction()->getName().str()](
696+
bb, orig, args);
697+
698+
invertedPointers.erase(found);
699+
bb.SetInsertPoint(placeholder);
697700

698-
replaceAWithB(placeholder, anti);
699-
erase(placeholder);
701+
replaceAWithB(placeholder, anti);
702+
erase(placeholder);
703+
}
700704

701705
if (auto inst = dyn_cast<Instruction>(anti))
702706
bb.SetInsertPoint(inst);

0 commit comments

Comments
 (0)