File tree Expand file tree Collapse file tree 1 file changed +10
-6
lines changed
Expand file tree Collapse file tree 1 file changed +10
-6
lines changed Original file line number Diff line number Diff line change @@ -689,14 +689,18 @@ class GradientUtils : public CacheUtility {
689689 if (shadowHandlers.find (orig->getCalledFunction ()->getName ().str ()) !=
690690 shadowHandlers.end ()) {
691691 bb.SetInsertPoint (placeholder);
692- Value *anti = shadowHandlers[orig->getCalledFunction ()->getName ().str ()](
693- bb, orig, args);
692+ Value *anti = placeholder;
694693
695- invertedPointers.erase (found);
696- bb.SetInsertPoint (placeholder);
694+ if (mode != DerivativeMode::ReverseModeGradient) {
695+ anti = shadowHandlers[orig->getCalledFunction ()->getName ().str ()](
696+ bb, orig, args);
697+
698+ invertedPointers.erase (found);
699+ bb.SetInsertPoint (placeholder);
697700
698- replaceAWithB (placeholder, anti);
699- erase (placeholder);
701+ replaceAWithB (placeholder, anti);
702+ erase (placeholder);
703+ }
700704
701705 if (auto inst = dyn_cast<Instruction>(anti))
702706 bb.SetInsertPoint (inst);
You can’t perform that action at this time.
0 commit comments