@@ -101,8 +101,7 @@ class Enzyme : public ModulePass {
101101 }
102102
103103 // / Return whether successful
104- template <typename T>
105- bool HandleAutoDiff (T *CI, TargetLibraryInfo &TLI, bool PostOpt,
104+ bool HandleAutoDiff (CallInst *CI, TargetLibraryInfo &TLI, bool PostOpt,
106105 bool fwdMode) {
107106
108107 Value *fn = CI->getArgOperand (0 );
@@ -575,9 +574,65 @@ class Enzyme : public ModulePass {
575574
576575 bool Changed = false ;
577576
577+ for (BasicBlock &BB : F)
578+ if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator ())) {
579+
580+ Function *Fn = II->getCalledFunction ();
581+
582+ #if LLVM_VERSION_MAJOR >= 11
583+ if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand ()))
584+ #else
585+ if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledValue ()))
586+ #endif
587+ {
588+ if (castinst->isCast ())
589+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 )))
590+ Fn = fn;
591+ }
592+ if (!Fn)
593+ continue ;
594+
595+ if (!(Fn->getName () == " __enzyme_float" ||
596+ Fn->getName () == " __enzyme_double" ||
597+ Fn->getName () == " __enzyme_integer" ||
598+ Fn->getName () == " __enzyme_pointer" ||
599+ Fn->getName ().contains (" __enzyme_call_inactive" ) ||
600+ Fn->getName ().contains (" __enzyme_autodiff" ) ||
601+ Fn->getName ().contains (" __enzyme_fwddiff" )))
602+ continue ;
603+
604+ SmallVector<Value *, 16 > CallArgs (II->arg_begin (), II->arg_end ());
605+ SmallVector<OperandBundleDef, 1 > OpBundles;
606+ II->getOperandBundlesAsDefs (OpBundles);
607+ // Insert a normal call instruction...
608+ #if LLVM_VERSION_MAJOR >= 8
609+ CallInst *NewCall =
610+ CallInst::Create (II->getFunctionType (), II->getCalledOperand (),
611+ CallArgs, OpBundles, " " , II);
612+ #else
613+ CallInst *NewCall =
614+ CallInst::Create (II->getFunctionType (), II->getCalledValue (),
615+ CallArgs, OpBundles, " " , II);
616+ #endif
617+ NewCall->takeName (II);
618+ NewCall->setCallingConv (II->getCallingConv ());
619+ NewCall->setAttributes (II->getAttributes ());
620+ NewCall->setDebugLoc (II->getDebugLoc ());
621+ II->replaceAllUsesWith (NewCall);
622+
623+ // Insert an unconditional branch to the normal destination.
624+ BranchInst::Create (II->getNormalDest (), II);
625+
626+ // Remove any PHI node entries from the exception destination.
627+ II->getUnwindDest ()->removePredecessor (&BB);
628+
629+ // Remove the invoke instruction now.
630+ BB.getInstList ().erase (II);
631+ Changed = true ;
632+ }
633+
578634 std::set<CallInst *> toLowerAuto;
579635 std::set<CallInst *> toLowerFwd;
580- std::set<InvokeInst *> toLowerI;
581636 std::set<CallInst *> InactiveCalls;
582637 retry:;
583638 for (BasicBlock &BB : F) {
@@ -752,15 +807,9 @@ class Enzyme : public ModulePass {
752807 }
753808 }
754809
755- bool autoDiff = Fn && (Fn->getName () == " __enzyme_autodiff" ||
756- Fn->getName () == " enzyme_autodiff_" ||
757- Fn->getName ().startswith (" __enzyme_autodiff" ) ||
758- Fn->getName ().contains (" __enzyme_autodiff" ));
810+ bool autoDiff = Fn && Fn->getName ().contains (" __enzyme_autodiff" );
759811
760- bool fwdDiff = Fn && (Fn->getName () == " __enzyme_fwddiff" ||
761- Fn->getName () == " enzyme_fwddiff_" ||
762- Fn->getName ().startswith (" __enzyme_fwddiff" ) ||
763- Fn->getName ().contains (" __enzyme_fwddiff" ));
812+ bool fwdDiff = Fn && Fn->getName ().contains (" __enzyme_fwddiff" );
764813
765814 if (autoDiff || fwdDiff) {
766815 if (autoDiff) {
@@ -845,13 +894,6 @@ class Enzyme : public ModulePass {
845894 break ;
846895 }
847896
848- for (auto CI : toLowerI) {
849- successful &= HandleAutoDiff (CI, TLI, PostOpt, /* fwdMode*/ false );
850- Changed = true ;
851- if (!successful)
852- break ;
853- }
854-
855897 if (Changed) {
856898 // TODO consider enabling when attributor does not delete
857899 // dead internal functions, which invalidates Enzyme's cache
0 commit comments