@@ -30,6 +30,7 @@
 #include <set>
 
 #include "DifferentialUseAnalysis.h"
+#include "Utils.h"
 
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Instruction.h"
@@ -721,8 +722,13 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
     return false;
   }
 
-  bool neededFB = !gutils->isConstantInstruction(user) ||
-                  !gutils->isConstantValue(const_cast<Instruction *>(user));
+  bool neededFB = false;
+  if (auto CB = dyn_cast<CallBase>(const_cast<Instruction *>(user))) {
+    neededFB = !callShouldNotUseDerivative(gutils, *CB);
+  } else {
+    neededFB = !gutils->isConstantInstruction(user) ||
+               !gutils->isConstantValue(const_cast<Instruction *>(user));
+  }
   if (neededFB) {
     if (EnzymePrintDiffUse)
       llvm::errs() << " Need direct primal of " << *val
@@ -960,3 +966,149 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
   }
   return;
 }
+
+bool DifferentialUseAnalysis::callShouldNotUseDerivative(
+    const GradientUtils *gutils, CallBase &call) {
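+  // Determine whether the shadow (derivative) of the call's return value is
+  // used; in the split gradient pass, query as the primal pass would see it.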
+  bool shadowReturnUsed = false;
+  auto smode = gutils->mode;
+  if (smode == DerivativeMode::ReverseModeGradient)
+    smode = DerivativeMode::ReverseModePrimal;
+  (void)gutils->getReturnDiffeType(&call, nullptr, &shadowReturnUsed, smode);
+
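+  // The constant (non-differentiated) fallback is only a candidate if the
+  // call is an inactive instruction and either its result is an inactive
+  // value or its shadow return is unused.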
+  bool useConstantFallback =
+      gutils->isConstantInstruction(&call) &&
+      (gutils->isConstantValue(&call) || !shadowReturnUsed);
+  if (useConstantFallback && gutils->mode != DerivativeMode::ForwardMode &&
+      gutils->mode != DerivativeMode::ForwardModeError) {
+    // If there is an escaping allocation that is deduced to be needed in the
+    // reverse pass, we need to do the recursive procedure to perform the
+    // free.
+
+    // First, test whether the return is a potential pointer that is needed
+    // for the reverse pass.
+    bool escapingNeededAllocation = false;
+
+    if (!isNoEscapingAllocation(&call)) {
+      escapingNeededAllocation = EnzymeGlobalActivity;
+
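+      // Pre-seed the usage cache: values that will not be recomputed, or
+      // that are unnecessary intermediates, are marked as not needing their
+      // primal.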
+      std::map<UsageKey, bool> CacheResults;
+      for (auto pair : gutils->knownRecomputeHeuristic) {
+        if (!pair.second || gutils->unnecessaryIntermediates.count(
+                                cast<Instruction>(pair.first))) {
+          CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
+        }
+      }
+
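+      // Unless EnzymeGlobalActivity already forces conservatism (or the
+      // result is a special Julia pointer), check whether the call's pointer
+      // result is itself needed in the reverse pass.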
+      if (!escapingNeededAllocation &&
+          !(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) {
+        if (gutils->TR.anyPointer(&call)) {
+          auto found = gutils->knownRecomputeHeuristic.find(&call);
+          if (found != gutils->knownRecomputeHeuristic.end()) {
+            if (!found->second) {
+              CacheResults.erase(UsageKey(&call, QueryType::Primal));
+              escapingNeededAllocation =
+                  DifferentialUseAnalysis::is_value_needed_in_reverse<
+                      QueryType::Primal>(gutils, &call,
+                                         DerivativeMode::ReverseModeGradient,
+                                         CacheResults, gutils->notForAnalysis);
+            }
+          } else {
+            escapingNeededAllocation =
+                DifferentialUseAnalysis::is_value_needed_in_reverse<
+                    QueryType::Primal>(gutils, &call,
+                                       DerivativeMode::ReverseModeGradient,
+                                       CacheResults, gutils->notForAnalysis);
+          }
+        }
+      }
+
+      // Next test if any allocation could be stored into one of the
+      // arguments.
+      if (!escapingNeededAllocation)
+#if LLVM_VERSION_MAJOR >= 14
+        for (unsigned i = 0; i < call.arg_size(); ++i)
+#else
+        for (unsigned i = 0; i < call.getNumArgOperands(); ++i)
+#endif
+        {
+          Value *a = call.getOperand(i);
+
+          if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
+            continue;
+
+          if (!gutils->TR.anyPointer(a))
+            continue;
+
+          auto vd = gutils->TR.query(a);
+
+          if (!vd[{-1, -1}].isPossiblePointer())
+            continue;
+
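+          // A read-only parameter cannot be written through, so no
+          // allocation can escape into it.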
+          if (isReadOnly(&call, i))
+            continue;
+
+          // An allocation could only be needed in the reverse pass if it
+          // escapes into an argument. However, if the parameter by which it
+          // escapes could capture the pointer, the rest of Enzyme's caching
+          // mechanisms cannot assume that the allocation itself is
+          // reloadable, since it may have been captured and overwritten
+          // elsewhere.
+          // TODO: this justification will need revisiting in the future as
+          // the caching algorithm becomes increasingly sophisticated.
+          if (!isNoCapture(&call, i))
+            continue;
+
+          escapingNeededAllocation = true;
+        }
+    }
+
+    // If desired, this can become even more aggressive by looking through
+    // the called function for any allocations.
+    if (auto F = getFunctionFromCall(&call)) {
+      SmallVector<Function *, 1> todo = {F};
+      SmallPtrSet<Function *, 1> done;
+      bool seenAllocation = false;
+      while (todo.size() && !seenAllocation) {
+        auto cur = todo.pop_back_val();
+        if (done.count(cur))
+          continue;
+        done.insert(cur);
+        // Assume empty (declaration-only) functions allocate...
+        if (cur->empty()) {
+          // ...unless they are explicitly marked as non-allocating.
+          if (isNoEscapingAllocation(cur))
+            continue;
+          seenAllocation = true;
+          break;
+        }
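+        // Calls in blocks that are guaranteed unreachable can never
+        // execute, so they cannot allocate.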
+        auto UR = getGuaranteedUnreachable(cur);
+        for (auto &BB : *cur) {
+          if (UR.count(&BB))
+            continue;
+          for (auto &I : BB)
+            if (auto CB = dyn_cast<CallBase>(&I)) {
+              if (isNoEscapingAllocation(CB))
+                continue;
+              if (isAllocationCall(CB, gutils->TLI)) {
+                seenAllocation = true;
+                goto finish;
+              }
+              if (auto F = getFunctionFromCall(CB)) {
+                todo.push_back(F);
+                continue;
+              }
+              // Conservatively assume that indirect calls may allocate.
+              seenAllocation = true;
+              goto finish;
+            }
+        }
+      finish:;
+      }
+      if (!seenAllocation)
+        escapingNeededAllocation = false;
+    }
+    if (escapingNeededAllocation)
+      useConstantFallback = false;
+  }
+  return useConstantFallback;
+}
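
The callee scan in callShouldNotUseDerivative is a conservative reachability check over the static call graph: the constant fallback is rejected if any transitively reachable callee may allocate. Below is a minimal standalone sketch of the same decision procedure, using hypothetical Func/Call stand-ins in place of LLVM's Function/CallBase; none of these names come from the patch itself.

#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for llvm::Function and llvm::CallBase.
struct Func;

struct Call {
  Func *callee = nullptr;            // nullptr models an indirect call
  bool isAllocation = false;         // e.g. malloc/new in the real analysis
  bool noEscapingAllocation = false; // models the no-escaping-allocation marker
};

struct Func {
  bool hasBody = false;              // declaration-only functions have no body
  bool noEscapingAllocation = false;
  std::vector<Call> calls;           // calls in reachable blocks only
};

// Conservative worklist scan: returns true if any function transitively
// reachable from `root` may perform an allocation.
static bool mayAllocate(const Func *root) {
  std::vector<const Func *> todo = {root};
  std::unordered_set<const Func *> done;
  while (!todo.empty()) {
    const Func *cur = todo.back();
    todo.pop_back();
    if (!done.insert(cur).second)
      continue; // already visited
    if (!cur->hasBody) {
      // Assume declaration-only functions allocate, unless marked otherwise.
      if (cur->noEscapingAllocation)
        continue;
      return true;
    }
    for (const Call &c : cur->calls) {
      if (c.noEscapingAllocation)
        continue;                 // explicitly marked: ignore this call site
      if (c.isAllocation)
        return true;              // direct allocation found
      if (c.callee)
        todo.push_back(c.callee); // enqueue known callee for scanning
      else
        return true;              // indirect call: conservatively may allocate
    }
  }
  return false; // exhausted the call graph without finding an allocation
}

Returning early here replaces the patch's goto finish; the behavior is the same because the sketch isolates the scan in its own function, whereas the patch records the result in seenAllocation so it can feed the surrounding escapingNeededAllocation logic.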