@@ -608,51 +608,75 @@ class AdjointGenerator
608608 I.getOpcode () == CastInst::CastOps::PtrToInt)
609609 return ;
610610
611- if (Mode == DerivativeMode::ReverseModePrimal)
611+ switch (Mode) {
612+ case DerivativeMode::ReverseModePrimal: {
612613 return ;
614+ }
615+ case DerivativeMode::ReverseModeGradient:
616+ case DerivativeMode::ReverseModeCombined: {
617+ Value *orig_op0 = I.getOperand (0 );
618+ Value *op0 = gutils->getNewFromOriginal (orig_op0);
613619
614- Value *orig_op0 = I. getOperand ( 0 );
615- Value *op0 = gutils-> getNewFromOriginal (orig_op0 );
620+ IRBuilder<> Builder2 (I. getParent () );
621+ getReverseBuilder (Builder2 );
616622
617- IRBuilder<> Builder2 (I. getParent ());
618- getReverseBuilder ( Builder2);
623+ if (!gutils-> isConstantValue (orig_op0)) {
624+ Value *dif = diffe (&I, Builder2);
619625
620- if (!gutils->isConstantValue (orig_op0)) {
621- Value *dif = diffe (&I, Builder2);
622-
623- size_t size = 1 ;
624- if (orig_op0->getType ()->isSized ())
625- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
626- orig_op0->getType ()) +
627- 7 ) /
628- 8 ;
629- Type *FT = TR.addingType (size, orig_op0);
630- if (!FT) {
631- llvm::errs () << " " << *gutils->oldFunc << " \n " ;
632- TR.dump ();
633- llvm::errs () << " " << *orig_op0 << " \n " ;
626+ size_t size = 1 ;
627+ if (orig_op0->getType ()->isSized ())
628+ size =
629+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
630+ orig_op0->getType ()) +
631+ 7 ) /
632+ 8 ;
633+ Type *FT = TR.addingType (size, orig_op0);
634+ if (!FT) {
635+ llvm::errs () << " " << *gutils->oldFunc << " \n " ;
636+ TR.dump ();
637+ llvm::errs () << " " << *orig_op0 << " \n " ;
638+ }
639+ assert (FT);
640+ if (I.getOpcode () == CastInst::CastOps::FPTrunc ||
641+ I.getOpcode () == CastInst::CastOps::FPExt) {
642+ addToDiffe (orig_op0, Builder2.CreateFPCast (dif, op0->getType ()),
643+ Builder2, FT);
644+ } else if (I.getOpcode () == CastInst::CastOps::BitCast) {
645+ addToDiffe (orig_op0, Builder2.CreateBitCast (dif, op0->getType ()),
646+ Builder2, FT);
647+ } else if (I.getOpcode () == CastInst::CastOps::Trunc) {
648+ // TODO CHECK THIS
649+ auto trunced = Builder2.CreateZExt (dif, op0->getType ());
650+ addToDiffe (orig_op0, trunced, Builder2, FT);
651+ } else {
652+ TR.dump ();
653+ llvm::errs () << *I.getParent ()->getParent () << " \n "
654+ << *I.getParent () << " \n " ;
655+ llvm::errs () << " cannot handle above cast " << I << " \n " ;
656+ report_fatal_error (" unknown instruction" );
657+ }
634658 }
635- assert (FT);
636- if (I.getOpcode () == CastInst::CastOps::FPTrunc ||
637- I.getOpcode () == CastInst::CastOps::FPExt) {
638- addToDiffe (orig_op0, Builder2.CreateFPCast (dif, op0->getType ()),
639- Builder2, FT);
640- } else if (I.getOpcode () == CastInst::CastOps::BitCast) {
641- addToDiffe (orig_op0, Builder2.CreateBitCast (dif, op0->getType ()),
642- Builder2, FT);
643- } else if (I.getOpcode () == CastInst::CastOps::Trunc) {
644- // TODO CHECK THIS
645- auto trunced = Builder2.CreateZExt (dif, op0->getType ());
646- addToDiffe (orig_op0, trunced, Builder2, FT);
659+ setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
660+
661+ break ;
662+ }
663+ case DerivativeMode::ForwardMode: {
664+ Value *orig_op0 = I.getOperand (0 );
665+
666+ IRBuilder<> Builder2 (&I);
667+ getForwardBuilder (Builder2);
668+
669+ if (!gutils->isConstantValue (orig_op0)) {
670+ Value *dif = diffe (orig_op0, Builder2);
671+ setDiffe (&I, Builder2.CreateCast (I.getOpcode (), dif, I.getType ()),
672+ Builder2);
647673 } else {
648- TR.dump ();
649- llvm::errs () << *I.getParent ()->getParent () << " \n "
650- << *I.getParent () << " \n " ;
651- llvm::errs () << " cannot handle above cast " << I << " \n " ;
652- report_fatal_error (" unknown instruction" );
674+ setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
653675 }
676+
677+ break ;
678+ }
654679 }
655- setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
656680 }
657681
658682 void visitSelectInst (llvm::SelectInst &SI) {
0 commit comments