@@ -7093,10 +7093,115 @@ class AdjointGenerator
70937093 return ;
70947094 }
70957095
7096- bool modifyPrimal = shouldAugmentCall (orig, gutils, TR);
7097-
70987096 bool foreignFunction = called == nullptr || called->empty ();
70997097
7098+ FnTypeInfo nextTypeInfo (called);
7099+
7100+ if (called) {
7101+ nextTypeInfo = TR.getCallInfo (*orig, *called);
7102+ }
7103+
7104+ if (Mode == DerivativeMode::ForwardMode) {
7105+ IRBuilder<> Builder2 (&call);
7106+ getForwardBuilder (Builder2);
7107+
7108+ bool retUsed = subretused;
7109+
7110+ SmallVector<Value *, 8 > args;
7111+ std::vector<DIFFE_TYPE> argsInverted;
7112+ std::map<int , Type *> gradByVal;
7113+
7114+ for (unsigned i = 0 ; i < orig->getNumArgOperands (); ++i) {
7115+
7116+ auto argi = gutils->getNewFromOriginal (orig->getArgOperand (i));
7117+
7118+ #if LLVM_VERSION_MAJOR >= 9
7119+ if (orig->isByValArgument (i)) {
7120+ gradByVal[args.size ()] = orig->getParamByValType (i);
7121+ }
7122+ #endif
7123+ args.push_back (argi);
7124+
7125+ if (gutils->isConstantValue (orig->getArgOperand (i)) &&
7126+ !foreignFunction) {
7127+ argsInverted.push_back (DIFFE_TYPE::CONSTANT);
7128+ continue ;
7129+ }
7130+
7131+ auto argType = argi->getType ();
7132+
7133+ if (!argType->isFPOrFPVectorTy () &&
7134+ (TR.query (orig->getArgOperand (i)).Inner0 ().isPossiblePointer () ||
7135+ foreignFunction)) {
7136+ DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
7137+ if (argType->isPointerTy ()) {
7138+ #if LLVM_VERSION_MAJOR >= 12
7139+ auto at = getUnderlyingObject (orig->getArgOperand (i), 100 );
7140+ #else
7141+ auto at = GetUnderlyingObject (
7142+ orig->getArgOperand (i),
7143+ gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
7144+ #endif
7145+ if (auto arg = dyn_cast<Argument>(at)) {
7146+ if (constant_args[arg->getArgNo ()] == DIFFE_TYPE::DUP_NONEED) {
7147+ ty = DIFFE_TYPE::DUP_NONEED;
7148+ }
7149+ }
7150+ }
7151+ args.push_back (
7152+ gutils->invertPointerM (orig->getArgOperand (i), Builder2));
7153+ argsInverted.push_back (ty);
7154+
7155+ // Note sometimes whattype mistakenly says something should be
7156+ // constant [because composed of integer pointers alone]
7157+ assert (whatType (argType, Mode) == DIFFE_TYPE::DUP_ARG ||
7158+ whatType (argType, Mode) == DIFFE_TYPE::CONSTANT);
7159+ } else {
7160+ if (foreignFunction)
7161+ assert (!argType->isIntOrIntVectorTy ());
7162+
7163+ args.push_back (diffe (orig->getArgOperand (i), Builder2));
7164+ argsInverted.push_back (DIFFE_TYPE::DUP_ARG);
7165+ }
7166+ }
7167+
7168+ auto newcalled = gutils->Logic .CreatePrimalAndGradient (
7169+ cast<Function>(called), subretType, argsInverted, gutils->TLI ,
7170+ TR.analyzer .interprocedural , /* returnValue*/ retUsed,
7171+ /* subdretptr*/ false , DerivativeMode::ForwardMode, nullptr ,
7172+ nextTypeInfo, uncacheable_args, nullptr ,
7173+ /* AtomicAdd*/ gutils->AtomicAdd );
7174+
7175+ assert (newcalled);
7176+ FunctionType *FT = cast<FunctionType>(
7177+ cast<PointerType>(newcalled->getType ())->getElementType ());
7178+
7179+ CallInst *diffes = Builder2.CreateCall (FT, newcalled, args);
7180+ diffes->setCallingConv (orig->getCallingConv ());
7181+ diffes->setDebugLoc (gutils->getNewFromOriginal (orig->getDebugLoc ()));
7182+ #if LLVM_VERSION_MAJOR >= 9
7183+ for (auto pair : gradByVal) {
7184+ diffes->addParamAttr (
7185+ pair.first ,
7186+ Attribute::getWithByValType (diffes->getContext (), pair.second ));
7187+ }
7188+ #endif
7189+
7190+ if (!gutils->isConstantValue (&call)) {
7191+ unsigned structidx = retUsed ? 1 : 0 ;
7192+ Value *diffe = Builder2.CreateExtractValue (diffes, {structidx});
7193+ setDiffe (&call, diffe, Builder2);
7194+ }
7195+
7196+ if (!subretused) {
7197+ eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
7198+ }
7199+
7200+ return ;
7201+ }
7202+
7203+ bool modifyPrimal = shouldAugmentCall (orig, gutils, TR);
7204+
71007205 SmallVector<Value *, 8 > args;
71017206 SmallVector<Value *, 8 > pre_args;
71027207 std::vector<DIFFE_TYPE> argsInverted;
@@ -7202,12 +7307,6 @@ class AdjointGenerator
72027307 CallInst *augmentcall = nullptr ;
72037308 Value *cachereplace = nullptr ;
72047309
7205- FnTypeInfo nextTypeInfo (called);
7206-
7207- if (called) {
7208- nextTypeInfo = TR.getCallInfo (*orig, *called);
7209- }
7210-
72117310 // llvm::Optional<std::map<std::pair<Instruction*, std::string>,
72127311 // unsigned>> sub_index_map;
72137312 Optional<int > tapeIdx;
0 commit comments