@@ -8938,6 +8938,7 @@ class AdjointGenerator
89388938 anti = gutils->cacheForReverse (
89398939 bb, anti, getIndex (&call, CacheType::Shadow));
89408940 } else {
8941+ bool zeroed = false ;
89418942 auto rule = [&]() {
89428943#if LLVM_VERSION_MAJOR >= 11
89438944 Value *anti = bb.CreateCall (call.getFunctionType (),
@@ -9006,6 +9007,19 @@ class AdjointGenerator
90069007#endif
90079008 }
90089009 }
9010+ if (Mode == DerivativeMode::ReverseModeCombined ||
9011+ (Mode == DerivativeMode::ReverseModePrimal &&
9012+ forwardsShadow) ||
9013+ (Mode == DerivativeMode::ReverseModeGradient &&
9014+ backwardsShadow) ||
9015+ (Mode == DerivativeMode::ForwardModeSplit &&
9016+ backwardsShadow)) {
9017+ if (!inLoop) {
9018+ zeroKnownAllocation (bb, anti, args, funcName, gutils->TLI ,
9019+ &call);
9020+ zeroed = true ;
9021+ }
9022+ }
90099023 }
90109024 return anti;
90119025 };
@@ -9024,6 +9038,7 @@ class AdjointGenerator
90249038 else {
90259039 if (auto MD = hasMetadata (&call, " enzyme_fromstack" )) {
90269040 isAlloca = true ;
9041+ bb.SetInsertPoint (cast<Instruction>(anti));
90279042 Value *Size;
90289043 if (funcName == " malloc" )
90299044 Size = args[0 ];
@@ -9058,55 +9073,60 @@ class AdjointGenerator
90589073#if LLVM_VERSION_MAJOR >= 15
90599074 }
90609075#endif
9061- Value *replacement = bb.CreateAlloca (elTy, Size, name);
9062- if (name.size () == 0 )
9063- replacement->takeName (anti);
9064- else
9065- anti->setName (" " );
9066- auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
9067- MD->getOperand (0 ))
9068- ->getValue ())
9069- ->getLimitedValue ();
9070- if (Alignment) {
9076+ auto rule = [&](Value *anti) {
9077+ Value *replacement = bb.CreateAlloca (elTy, Size, name);
9078+ if (name.size () == 0 )
9079+ replacement->takeName (anti);
9080+ else
9081+ anti->setName (" " );
9082+ auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
9083+ MD->getOperand (0 ))
9084+ ->getValue ())
9085+ ->getLimitedValue ();
9086+ if (Alignment) {
90719087#if LLVM_VERSION_MAJOR >= 10
9072- cast<AllocaInst>(replacement)
9073- ->setAlignment (Align (Alignment));
9088+ cast<AllocaInst>(replacement)
9089+ ->setAlignment (Align (Alignment));
90749090#else
9075- cast<AllocaInst>(replacement)->setAlignment (Alignment);
9091+ cast<AllocaInst>(replacement)->setAlignment (Alignment);
90769092#endif
9077- }
9093+ }
90789094#if LLVM_VERSION_MAJOR >= 15
9079- if (call.getContext ().supportsTypedPointers ()) {
9095+ if (call.getContext ().supportsTypedPointers ()) {
90809096#endif
9081- if (anti->getType ()->getPointerElementType () != elTy)
9082- replacement = bb.CreatePointerCast (
9083- replacement,
9084- PointerType::getUnqual (
9085- anti->getType ()->getPointerElementType ()));
9097+ if (anti->getType ()->getPointerElementType () != elTy)
9098+ replacement = bb.CreatePointerCast (
9099+ replacement,
9100+ PointerType::getUnqual (
9101+ anti->getType ()->getPointerElementType ()));
90869102#if LLVM_VERSION_MAJOR >= 15
9087- }
9103+ }
90889104#endif
90899105
9090- if (int AS = cast<PointerType>(anti->getType ())
9091- ->getAddressSpace ()) {
9092- llvm::PointerType *PT;
9106+ if (int AS = cast<PointerType>(anti->getType ())
9107+ ->getAddressSpace ()) {
9108+ llvm::PointerType *PT;
90939109#if LLVM_VERSION_MAJOR >= 15
9094- if (call.getContext ().supportsTypedPointers ()) {
9110+ if (call.getContext ().supportsTypedPointers ()) {
90959111#endif
9096- PT = PointerType::get (
9097- anti->getType ()->getPointerElementType (), AS);
9112+ PT = PointerType::get (
9113+ anti->getType ()->getPointerElementType (), AS);
90989114#if LLVM_VERSION_MAJOR >= 15
9099- } else {
9100- PT = PointerType::get (anti->getContext (), AS);
9101- }
9115+ } else {
9116+ PT = PointerType::get (anti->getContext (), AS);
9117+ }
91029118#endif
9103- replacement = bb.CreateAddrSpaceCast (replacement, PT);
9104- cast<Instruction>(replacement)
9105- ->setMetadata (
9106- " enzyme_backstack" ,
9107- MDNode::get (replacement->getContext (), {}));
9108- }
9119+ replacement = bb.CreateAddrSpaceCast (replacement, PT);
9120+ cast<Instruction>(replacement)
9121+ ->setMetadata (
9122+ " enzyme_backstack" ,
9123+ MDNode::get (replacement->getContext (), {}));
9124+ }
9125+ return replacement;
9126+ };
91099127
9128+ auto replacement =
9129+ applyChainRule (call.getType (), bb, rule, anti);
91109130 gutils->replaceAWithB (cast<Instruction>(anti), replacement);
91119131 gutils->erase (cast<Instruction>(anti));
91129132 anti = replacement;
@@ -9121,13 +9141,7 @@ class AdjointGenerator
91219141 (Mode == DerivativeMode::ForwardModeSplit &&
91229142 backwardsShadow)) {
91239143 if (!inLoop) {
9124- applyChainRule (
9125- bb,
9126- [&](Value *anti) {
9127- zeroKnownAllocation (bb, anti, args, funcName,
9128- gutils->TLI , &call);
9129- },
9130- anti);
9144+ assert (zeroed);
91319145 }
91329146 }
91339147 }
0 commit comments