@@ -73,6 +73,12 @@ using namespace llvm;
7373
7474enum class DerivativeMode { Forward, Reverse, Both };
7575
76+ #include " llvm-c/Core.h"
77+
78+ extern std::map<std::string, std::function<llvm::Value *(
79+ IRBuilder<> &, CallInst *, ArrayRef<Value *>)>>
80+ shadowHandlers;
81+
7682static inline std::string to_string (DerivativeMode mode) {
7783 switch (mode) {
7884 case DerivativeMode::Forward:
@@ -525,6 +531,22 @@ class GradientUtils : public CacheUtility {
525531 for (unsigned i = 0 ; i < orig->getNumArgOperands (); ++i) {
526532 args.push_back (getNewFromOriginal (orig->getArgOperand (i)));
527533 }
534+
535+ if (shadowHandlers.find (orig->getCalledFunction ()->getName ().str ()) !=
536+ shadowHandlers.end ()) {
537+ Value *anti = shadowHandlers[orig->getCalledFunction ()->getName ().str ()](
538+ bb, orig, args);
539+ invertedPointers[orig] = anti;
540+ // assert(placeholder != anti);
541+ bb.SetInsertPoint (placeholder->getNextNode ());
542+ replaceAWithB (placeholder, anti);
543+ erase (placeholder);
544+
545+ anti = cacheForReverse (bb, anti, idx);
546+ invertedPointers[orig] = anti;
547+ return anti;
548+ }
549+
528550 Value *anti =
529551 bb.CreateCall (orig->getCalledFunction (), args, orig->getName () + " 'mi" );
530552 cast<CallInst>(anti)->setAttributes (orig->getAttributes ());
@@ -575,16 +597,22 @@ class GradientUtils : public CacheUtility {
575597 *orig);
576598 }
577599 }
578- auto dst_arg = bb.CreateBitCast (
579- anti, Type::getInt8PtrTy (orig->getContext (),
580- anti->getType ()->getPointerAddressSpace ()));
600+
601+ Value *dst_arg = anti;
602+
603+ dst_arg = bb.CreateBitCast (
604+ dst_arg,
605+ Type::getInt8PtrTy (orig->getContext (),
606+ anti->getType ()->getPointerAddressSpace ()));
607+
581608 auto val_arg = ConstantInt::get (Type::getInt8Ty (orig->getContext ()), 0 );
582609 Value *size;
583610 // todo check if this memset is legal and if a write barrier is needed
584- if (orig->getCalledFunction ()->getName () == " julia.gc_alloc_obj" )
611+ if (orig->getCalledFunction ()->getName () == " julia.gc_alloc_obj" ) {
585612 size = args[1 ];
586- else
613+ } else {
587614 size = args[0 ];
615+ }
588616 auto len_arg =
589617 bb.CreateZExtOrTrunc (size, Type::getInt64Ty (orig->getContext ()));
590618 auto volatile_arg = ConstantInt::getFalse (orig->getContext ());
0 commit comments