@@ -269,10 +269,15 @@ class AdjointGenerator
269269 IRBuilder<> BuilderZ (newi);
270270 Value *newip = nullptr ;
271271
272- bool needShadow = is_value_needed_in_reverse<ValueType::ShadowPtr>(
273- TR, gutils, &I,
274- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
275- oldUnreachable);
272+ // TODO: In the case of fwd mode this should be true if the loaded value
273+ // itself is used as a pointer.
274+ bool needShadow =
275+ Mode == DerivativeMode::ForwardMode
276+ ? false
277+ : is_value_needed_in_reverse<ValueType::ShadowPtr>(
278+ TR, gutils, &I,
279+ /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
280+ oldUnreachable);
276281
277282 switch (Mode) {
278283
@@ -291,8 +296,8 @@ class AdjointGenerator
291296 gutils->invertedPointers [&I] = newip;
292297 break ;
293298 }
294-
295- case DerivativeMode::ReverseModeGradient : {
299+ case DerivativeMode::ReverseModeGradient:
300+ case DerivativeMode::ForwardMode : {
296301 // only make shadow where caching needed
297302 if (can_modref && needShadow) {
298303 newip = gutils->cacheForReverse (BuilderZ, placeholder,
@@ -322,13 +327,18 @@ class AdjointGenerator
322327
323328 Value *inst = newi;
324329
330+ // TODO: In the case of fwd mode this should be true if the loaded value
331+ // itself is used as a pointer.
332+ bool primalNeededInReverse =
333+ Mode == DerivativeMode::ForwardMode
334+ ? false
335+ : is_value_needed_in_reverse<ValueType::Primal>(
336+ TR, gutils, &I,
337+ /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
338+ oldUnreachable);
325339 // ! Store loads that need to be cached for use in reverse pass
326340 if (cache_reads_always ||
327- (!cache_reads_never && can_modref &&
328- is_value_needed_in_reverse<ValueType::Primal>(
329- TR, gutils, &I,
330- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
331- oldUnreachable))) {
341+ (!cache_reads_never && can_modref && primalNeededInReverse)) {
332342 if (!gutils->unnecessaryIntermediates .count (&I)) {
333343 IRBuilder<> BuilderZ (gutils->getNewFromOriginal (&I)->getNextNode ());
334344 // auto tbaa = inst->getMetadata(LLVMContext::MD_tbaa);
@@ -379,15 +389,36 @@ class AdjointGenerator
379389 }
380390
381391 if (isfloat) {
382- IRBuilder<> Builder2 (parent);
383- getReverseBuilder (Builder2);
384- auto prediff = diffe (&I, Builder2);
385- setDiffe (&I, Constant::getNullValue (type), Builder2);
386392
387- if (!gutils->isConstantValue (I.getOperand (0 ))) {
388- ((DiffeGradientUtils *)gutils)
389- ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
390- alignment, OrigOffset);
393+ switch (Mode) {
394+ case DerivativeMode::ForwardMode: {
395+ IRBuilder<> Builder2 (&I);
396+ getForwardBuilder (Builder2);
397+
398+ if (!gutils->isConstantValue (&I)) {
399+ auto diff = Builder2.CreateLoad (
400+ gutils->invertPointerM (I.getOperand (0 ), Builder2));
401+ setDiffe (&I, diff, Builder2);
402+ }
403+ break ;
404+ }
405+ case DerivativeMode::ReverseModeGradient:
406+ case DerivativeMode::ReverseModeCombined: {
407+ IRBuilder<> Builder2 (parent);
408+ getReverseBuilder (Builder2);
409+
410+ auto prediff = diffe (&I, Builder2);
411+ setDiffe (&I, Constant::getNullValue (type), Builder2);
412+
413+ if (!gutils->isConstantValue (I.getOperand (0 ))) {
414+ ((DiffeGradientUtils *)gutils)
415+ ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
416+ alignment, OrigOffset);
417+ }
418+ break ;
419+ }
420+ case DerivativeMode::ReverseModePrimal:
421+ break ;
391422 }
392423 }
393424 }
@@ -494,8 +525,9 @@ class AdjointGenerator
494525
495526 if (FT) {
496527 // ! Only need to update the reverse function
497- if (Mode == DerivativeMode::ReverseModeGradient ||
498- Mode == DerivativeMode::ReverseModeCombined) {
528+ switch (Mode) {
529+ case DerivativeMode::ReverseModeGradient:
530+ case DerivativeMode::ReverseModeCombined: {
499531 IRBuilder<> Builder2 (SI.getParent ());
500532 getReverseBuilder (Builder2);
501533
@@ -512,13 +544,29 @@ class AdjointGenerator
512544 ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
513545 addToDiffe (orig_val, dif1, Builder2, FT);
514546 }
547+ break ;
548+ }
549+ case DerivativeMode::ForwardMode: {
550+ IRBuilder<> Builder2 (&SI);
551+ getForwardBuilder (Builder2);
552+
553+ if (constantval) {
554+ ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
555+ } else {
556+ auto diff = diffe (orig_val, Builder2);
557+
558+ ts = setPtrDiffe (orig_ptr, diff, Builder2);
559+ }
560+ break ;
561+ }
515562 }
516563
517564 // ! Storing an integer or pointer
518565 } else {
519566 // ! Only need to update the forward function
520567 if (Mode == DerivativeMode::ReverseModePrimal ||
521- Mode == DerivativeMode::ReverseModeCombined) {
568+ Mode == DerivativeMode::ReverseModeCombined ||
569+ Mode == DerivativeMode::ForwardMode) {
522570 IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&SI));
523571
524572 Value *valueop = nullptr ;
@@ -935,25 +983,12 @@ class AdjointGenerator
935983 setDiffe (&IVI, Constant::getNullValue (IVI.getType ()), Builder2);
936984 }
937985
938- inline void getReverseBuilder (IRBuilder<> &Builder2, bool original = true ) {
939- BasicBlock *BB = Builder2.GetInsertBlock ();
940- if (original)
941- BB = gutils->getNewFromOriginal (BB);
942- BasicBlock *BB2 = gutils->reverseBlocks [BB].back ();
943- if (!BB2) {
944- llvm::errs () << " oldFunc: " << *gutils->oldFunc << " \n " ;
945- llvm::errs () << " newFunc: " << *gutils->newFunc << " \n " ;
946- llvm::errs () << " could not invert " << *BB;
947- }
948- assert (BB2);
949-
950- if (BB2->getTerminator ())
951- Builder2.SetInsertPoint (BB2->getTerminator ());
952- else
953- Builder2.SetInsertPoint (BB2);
954- Builder2.SetCurrentDebugLocation (
955- gutils->getNewFromOriginal (Builder2.getCurrentDebugLocation ()));
956- Builder2.setFastMathFlags (getFast ());
986+ void getReverseBuilder (IRBuilder<> &Builder2, bool original = true ) {
987+ ((GradientUtils *)gutils)->getReverseBuilder (Builder2, original);
988+ }
989+
990+ void getForwardBuilder (IRBuilder<> &Builder2) {
991+ ((GradientUtils *)gutils)->getForwardBuilder (Builder2);
957992 }
958993
959994 Value *diffe (Value *val, IRBuilder<> &Builder) {
@@ -1398,19 +1433,7 @@ class AdjointGenerator
13981433
13991434 void createBinaryOperatorDual (llvm::BinaryOperator &BO) {
14001435 IRBuilder<> Builder2 (&BO);
1401-
1402- Instruction *nBO = gutils->getNewFromOriginal (&BO);
1403-
1404- assert (nBO);
1405- assert (nBO->getNextNode ());
1406-
1407- if (nBO->getNextNode ()) {
1408- Builder2.SetInsertPoint (nBO->getNextNode ());
1409- }
1410-
1411- Builder2.SetCurrentDebugLocation (
1412- gutils->getNewFromOriginal (Builder2.getCurrentDebugLocation ()));
1413- Builder2.setFastMathFlags (getFast ());
1436+ getForwardBuilder (Builder2);
14141437
14151438 Value *orig_op0 = BO.getOperand (0 );
14161439 Value *orig_op1 = BO.getOperand (1 );
0 commit comments