@@ -584,16 +584,16 @@ CallInst *DiffeGradientUtils::freeCache(BasicBlock *forwardPreheader,
584584
585585#if LLVM_VERSION_MAJOR >= 10
586586void DiffeGradientUtils::addToInvertedPtrDiffe (Instruction *orig,
587- Type *addingType, unsigned start ,
588- unsigned size, Value *origptr ,
589- Value *dif,
587+ Value *origVal, Type *addingType ,
588+ unsigned start, unsigned size ,
589+ Value *origptr, Value * dif,
590590 IRBuilder<> &BuilderM,
591591 MaybeAlign align, Value *mask)
592592#else
593593void DiffeGradientUtils::addToInvertedPtrDiffe (Instruction *orig,
594- Type *addingType, unsigned start ,
595- unsigned size, Value *origptr ,
596- Value *dif,
594+ Value *origVal, Type *addingType ,
595+ unsigned start, unsigned size ,
596+ Value *origptr, Value * dif,
597597 IRBuilder<> &BuilderM,
598598 unsigned align, Value *mask)
599599#endif
@@ -892,8 +892,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
892892
893893 SmallVector<Metadata *, 1 > scopeMD = {
894894 getDerivativeAliasScope (origptr, idx)};
895- if (orig )
896- if (auto MD = orig ->getMetadata (LLVMContext::MD_alias_scope)) {
895+ if (auto origValI = dyn_cast_or_null<Instruction>(origVal) )
896+ if (auto MD = origValI ->getMetadata (LLVMContext::MD_alias_scope)) {
897897 auto MDN = cast<MDNode>(MD);
898898 for (auto &o : MDN->operands ())
899899 scopeMD.push_back (o);
@@ -907,8 +907,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
907907 if (j != (ssize_t )idx)
908908 MDs.push_back (getDerivativeAliasScope (origptr, j));
909909 }
910- if (orig )
911- if (auto MD = orig ->getMetadata (LLVMContext::MD_noalias)) {
910+ if (auto origValI = dyn_cast_or_null<Instruction>(origVal) )
911+ if (auto MD = origValI ->getMetadata (LLVMContext::MD_noalias)) {
912912 auto MDN = cast<MDNode>(MD);
913913 for (auto &o : MDN->operands ())
914914 MDs.push_back (o);
@@ -918,17 +918,19 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
918918 LI->setMetadata (LLVMContext::MD_noalias, noscope);
919919 st->setMetadata (LLVMContext::MD_noalias, noscope);
920920
921- if (orig && start == 0 &&
922- size == (DL.getTypeSizeInBits (orig ->getType ()) + 7 ) / 8 ) {
923- LI-> copyMetadata (*orig, MD_ToCopy );
924- LI->setDebugLoc ( getNewFromOriginal (orig-> getDebugLoc ()) );
921+ if (origVal && isa<Instruction>(origVal) && start == 0 &&
922+ size == (DL.getTypeSizeInBits (origVal ->getType ()) + 7 ) / 8 ) {
923+ auto origValI = cast<Instruction>(origVal );
924+ LI->copyMetadata (*origValI, MD_ToCopy );
925925 unsigned int StoreData[] = {LLVMContext::MD_tbaa,
926926 LLVMContext::MD_tbaa_struct};
927927 for (auto MD : StoreData)
928- st->setMetadata (MD, orig->getMetadata (MD));
929- st->setDebugLoc (getNewFromOriginal (orig->getDebugLoc ()));
928+ st->setMetadata (MD, origValI->getMetadata (MD));
930929 }
931930
931+ LI->setDebugLoc (getNewFromOriginal (orig->getDebugLoc ()));
932+ st->setDebugLoc (getNewFromOriginal (orig->getDebugLoc ()));
933+
932934 if (align) {
933935#if LLVM_VERSION_MAJOR >= 10
934936 auto alignv = align ? align.getValue ().value () : 0 ;
@@ -987,14 +989,14 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
987989
988990#if LLVM_VERSION_MAJOR >= 10
989991void DiffeGradientUtils::addToInvertedPtrDiffe (
990- llvm::Instruction *orig, TypeTree vd, unsigned LoadSize ,
991- llvm::Value *origptr, llvm::Value *prediff, llvm::IRBuilder<> &Builder2 ,
992- MaybeAlign alignment, llvm::Value *premask)
992+ llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd ,
993+ unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
994+ llvm::IRBuilder<> &Builder2, MaybeAlign alignment, llvm::Value *premask)
993995#else
994996void DiffeGradientUtils::addToInvertedPtrDiffe (
995- llvm::Instruction *orig, TypeTree vd, unsigned LoadSize ,
996- llvm::Value *origptr, llvm::Value *prediff, llvm::IRBuilder<> &Builder2 ,
997- unsigned alignment, llvm::Value *premask)
997+ llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd ,
998+ unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
999+ llvm::IRBuilder<> &Builder2, unsigned alignment, llvm::Value *premask)
9981000#endif
9991001{
10001002
@@ -1026,12 +1028,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
10261028
10271029 if (Type *isfloat = dt.isFloat ()) {
10281030
1029- if (orig ) {
1031+ if (origVal ) {
10301032 if (start == 0 && nextStart == LoadSize) {
1031- setDiffe (orig, Constant::getNullValue (getShadowType (orig->getType ())),
1033+ setDiffe (origVal,
1034+ Constant::getNullValue (getShadowType (origVal->getType ())),
10321035 Builder2);
10331036 } else {
1034- Value *tostore = getDifferential (orig );
1037+ Value *tostore = getDifferential (origVal );
10351038
10361039 auto i8 = Type::getInt8Ty (tostore->getContext ());
10371040 if (start != 0 ) {
@@ -1074,8 +1077,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
10741077 // Masked partial type is unhanled.
10751078 if (premask)
10761079 assert (start == 0 && nextStart == LoadSize);
1077- addToInvertedPtrDiffe (orig, isfloat, start, nextStart - start, origptr ,
1078- prediff, Builder2, alignment, premask);
1080+ addToInvertedPtrDiffe (orig, origVal, isfloat, start, nextStart - start,
1081+ origptr, prediff, Builder2, alignment, premask);
10791082 }
10801083 }
10811084
0 commit comments