Skip to content

Commit e96ccd2

Browse files
authored
Fixup diff malloc fb (#1872)
* Diffuse malloc fb * fixup * fixup
1 parent 2250522 commit e96ccd2

File tree

6 files changed

+319
-135
lines changed

6 files changed

+319
-135
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 8 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -6288,134 +6288,15 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
62886288
return;
62896289

62906290
bool useConstantFallback =
6291-
gutils->isConstantInstruction(&call) &&
6292-
(gutils->isConstantValue(&call) || !shadowReturnUsed);
6293-
if (useConstantFallback && Mode != DerivativeMode::ForwardMode &&
6294-
Mode != DerivativeMode::ForwardModeError) {
6295-
// if there is an escaping allocation, which is deduced needed in
6296-
// reverse pass, we need to do the recursive procedure to perform the
6297-
// free.
6298-
6299-
// First test if the return is a potential pointer and needed for the
6300-
// reverse pass
6301-
bool escapingNeededAllocation = false;
6302-
6303-
if (!isNoEscapingAllocation(&call)) {
6304-
escapingNeededAllocation = EnzymeGlobalActivity;
6305-
6306-
std::map<UsageKey, bool> CacheResults;
6307-
for (auto pair : gutils->knownRecomputeHeuristic) {
6308-
if (!pair.second || gutils->unnecessaryIntermediates.count(
6309-
cast<Instruction>(pair.first))) {
6310-
CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
6311-
}
6312-
}
6313-
6314-
if (!escapingNeededAllocation &&
6315-
!(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) {
6316-
if (TR.query(&call)[{-1}].isPossiblePointer()) {
6317-
auto found = gutils->knownRecomputeHeuristic.find(&call);
6318-
if (found != gutils->knownRecomputeHeuristic.end()) {
6319-
if (!found->second) {
6320-
CacheResults.erase(UsageKey(&call, QueryType::Primal));
6321-
escapingNeededAllocation =
6322-
DifferentialUseAnalysis::is_value_needed_in_reverse<
6323-
QueryType::Primal>(gutils, &call,
6324-
DerivativeMode::ReverseModeGradient,
6325-
CacheResults, oldUnreachable);
6326-
}
6327-
} else {
6328-
escapingNeededAllocation =
6329-
DifferentialUseAnalysis::is_value_needed_in_reverse<
6330-
QueryType::Primal>(gutils, &call,
6331-
DerivativeMode::ReverseModeGradient,
6332-
CacheResults, oldUnreachable);
6333-
}
6334-
}
6335-
}
6336-
6337-
// Next test if any allocation could be stored into one of the
6338-
// arguments.
6339-
if (!escapingNeededAllocation)
6340-
#if LLVM_VERSION_MAJOR >= 14
6341-
for (unsigned i = 0; i < call.arg_size(); ++i)
6342-
#else
6343-
for (unsigned i = 0; i < call.getNumArgOperands(); ++i)
6344-
#endif
6345-
{
6346-
Value *a = call.getOperand(i);
6347-
6348-
if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
6349-
continue;
6350-
6351-
auto vd = TR.query(a);
6352-
if (!vd[{-1}].isPossiblePointer())
6353-
continue;
6354-
6355-
if (!vd[{-1, -1}].isPossiblePointer())
6356-
continue;
6357-
6358-
if (isReadOnly(&call, i))
6359-
continue;
6360-
6361-
// An allocation could only be needed in the reverse pass if it
6362-
// escapes into an argument. However, is the parameter by which it
6363-
// escapes could capture the pointer, the rest of Enzyme's caching
6364-
// mechanisms cannot assume that the allocation itself is
6365-
// reloadable, since it may have been captured and overwritten
6366-
// elsewhere.
6367-
// TODO: this justification will need revisiting in the future as
6368-
// the caching algorithm becomes increasingly sophisticated.
6369-
if (!isNoCapture(&call, i))
6370-
continue;
6371-
6372-
escapingNeededAllocation = true;
6373-
}
6374-
}
6375-
6376-
// If desired this can become even more aggressive by looking through the
6377-
// called function for any allocations.
6378-
if (auto F = getFunctionFromCall(&call)) {
6379-
SmallVector<Function *, 1> todo = {F};
6380-
SmallPtrSet<Function *, 1> done;
6381-
bool seenAllocation = false;
6382-
while (todo.size() && !seenAllocation) {
6383-
auto cur = todo.pop_back_val();
6384-
if (done.count(cur))
6385-
continue;
6386-
done.insert(cur);
6387-
// assume empty functions allocate.
6388-
if (cur->empty()) {
6389-
// unless they are marked
6390-
if (isNoEscapingAllocation(cur))
6391-
continue;
6392-
seenAllocation = true;
6393-
break;
6394-
}
6395-
for (auto &BB : *cur)
6396-
for (auto &I : BB)
6397-
if (auto CB = dyn_cast<CallBase>(&I)) {
6398-
if (isNoEscapingAllocation(CB))
6399-
continue;
6400-
if (isAllocationCall(CB, gutils->TLI)) {
6401-
seenAllocation = true;
6402-
goto finish;
6403-
}
6404-
if (auto F = getFunctionFromCall(CB)) {
6405-
todo.push_back(F);
6406-
continue;
6407-
}
6408-
// Conservatively assume indirect functions allocate.
6409-
seenAllocation = true;
6410-
goto finish;
6411-
}
6412-
finish:;
6413-
}
6414-
if (!seenAllocation)
6415-
escapingNeededAllocation = false;
6291+
DifferentialUseAnalysis::callShouldNotUseDerivative(gutils, call);
6292+
if (!useConstantFallback) {
6293+
if (gutils->isConstantInstruction(&call) &&
6294+
gutils->isConstantValue(&call)) {
6295+
EmitWarning("ConstnatFallback", call,
6296+
"Call was deduced inactive but still doing differential "
6297+
"rewrite as it may escape an allocation",
6298+
call);
64166299
}
6417-
if (escapingNeededAllocation)
6418-
useConstantFallback = false;
64196300
}
64206301
if (useConstantFallback) {
64216302
if (!gutils->isConstantValue(&call)) {

enzyme/Enzyme/DifferentialUseAnalysis.cpp

Lines changed: 154 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <set>
3131

3232
#include "DifferentialUseAnalysis.h"
33+
#include "Utils.h"
3334

3435
#include "llvm/IR/BasicBlock.h"
3536
#include "llvm/IR/Instruction.h"
@@ -721,8 +722,13 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
721722
return false;
722723
}
723724

724-
bool neededFB = !gutils->isConstantInstruction(user) ||
725-
!gutils->isConstantValue(const_cast<Instruction *>(user));
725+
bool neededFB = false;
726+
if (auto CB = dyn_cast<CallBase>(const_cast<Instruction *>(user))) {
727+
neededFB = !callShouldNotUseDerivative(gutils, *CB);
728+
} else {
729+
neededFB = !gutils->isConstantInstruction(user) ||
730+
!gutils->isConstantValue(const_cast<Instruction *>(user));
731+
}
726732
if (neededFB) {
727733
if (EnzymePrintDiffUse)
728734
llvm::errs() << " Need direct primal of " << *val
@@ -960,3 +966,149 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI,
960966
}
961967
return;
962968
}
969+
970+
bool DifferentialUseAnalysis::callShouldNotUseDerivative(
971+
const GradientUtils *gutils, CallBase &call) {
972+
bool shadowReturnUsed = false;
973+
auto smode = gutils->mode;
974+
if (smode == DerivativeMode::ReverseModeGradient)
975+
smode = DerivativeMode::ReverseModePrimal;
976+
(void)gutils->getReturnDiffeType(&call, nullptr, &shadowReturnUsed, smode);
977+
978+
bool useConstantFallback =
979+
gutils->isConstantInstruction(&call) &&
980+
(gutils->isConstantValue(&call) || !shadowReturnUsed);
981+
if (useConstantFallback && gutils->mode != DerivativeMode::ForwardMode &&
982+
gutils->mode != DerivativeMode::ForwardModeError) {
983+
// if there is an escaping allocation, which is deduced needed in
984+
// reverse pass, we need to do the recursive procedure to perform the
985+
// free.
986+
987+
// First test if the return is a potential pointer and needed for the
988+
// reverse pass
989+
bool escapingNeededAllocation = false;
990+
991+
if (!isNoEscapingAllocation(&call)) {
992+
escapingNeededAllocation = EnzymeGlobalActivity;
993+
994+
std::map<UsageKey, bool> CacheResults;
995+
for (auto pair : gutils->knownRecomputeHeuristic) {
996+
if (!pair.second || gutils->unnecessaryIntermediates.count(
997+
cast<Instruction>(pair.first))) {
998+
CacheResults[UsageKey(pair.first, QueryType::Primal)] = false;
999+
}
1000+
}
1001+
1002+
if (!escapingNeededAllocation &&
1003+
!(EnzymeJuliaAddrLoad && isSpecialPtr(call.getType()))) {
1004+
if (gutils->TR.anyPointer(&call)) {
1005+
auto found = gutils->knownRecomputeHeuristic.find(&call);
1006+
if (found != gutils->knownRecomputeHeuristic.end()) {
1007+
if (!found->second) {
1008+
CacheResults.erase(UsageKey(&call, QueryType::Primal));
1009+
escapingNeededAllocation =
1010+
DifferentialUseAnalysis::is_value_needed_in_reverse<
1011+
QueryType::Primal>(gutils, &call,
1012+
DerivativeMode::ReverseModeGradient,
1013+
CacheResults, gutils->notForAnalysis);
1014+
}
1015+
} else {
1016+
escapingNeededAllocation =
1017+
DifferentialUseAnalysis::is_value_needed_in_reverse<
1018+
QueryType::Primal>(gutils, &call,
1019+
DerivativeMode::ReverseModeGradient,
1020+
CacheResults, gutils->notForAnalysis);
1021+
}
1022+
}
1023+
}
1024+
1025+
// Next test if any allocation could be stored into one of the
1026+
// arguments.
1027+
if (!escapingNeededAllocation)
1028+
#if LLVM_VERSION_MAJOR >= 14
1029+
for (unsigned i = 0; i < call.arg_size(); ++i)
1030+
#else
1031+
for (unsigned i = 0; i < call.getNumArgOperands(); ++i)
1032+
#endif
1033+
{
1034+
Value *a = call.getOperand(i);
1035+
1036+
if (EnzymeJuliaAddrLoad && isSpecialPtr(a->getType()))
1037+
continue;
1038+
1039+
if (!gutils->TR.anyPointer(a))
1040+
continue;
1041+
1042+
auto vd = gutils->TR.query(a);
1043+
1044+
if (!vd[{-1, -1}].isPossiblePointer())
1045+
continue;
1046+
1047+
if (isReadOnly(&call, i))
1048+
continue;
1049+
1050+
// An allocation could only be needed in the reverse pass if it
1051+
// escapes into an argument. However, is the parameter by which it
1052+
// escapes could capture the pointer, the rest of Enzyme's caching
1053+
// mechanisms cannot assume that the allocation itself is
1054+
// reloadable, since it may have been captured and overwritten
1055+
// elsewhere.
1056+
// TODO: this justification will need revisiting in the future as
1057+
// the caching algorithm becomes increasingly sophisticated.
1058+
if (!isNoCapture(&call, i))
1059+
continue;
1060+
1061+
escapingNeededAllocation = true;
1062+
}
1063+
}
1064+
1065+
// If desired this can become even more aggressive by looking through the
1066+
// called function for any allocations.
1067+
if (auto F = getFunctionFromCall(&call)) {
1068+
SmallVector<Function *, 1> todo = {F};
1069+
SmallPtrSet<Function *, 1> done;
1070+
bool seenAllocation = false;
1071+
while (todo.size() && !seenAllocation) {
1072+
auto cur = todo.pop_back_val();
1073+
if (done.count(cur))
1074+
continue;
1075+
done.insert(cur);
1076+
// assume empty functions allocate.
1077+
if (cur->empty()) {
1078+
// unless they are marked
1079+
if (isNoEscapingAllocation(cur))
1080+
continue;
1081+
seenAllocation = true;
1082+
break;
1083+
}
1084+
auto UR = getGuaranteedUnreachable(cur);
1085+
for (auto &BB : *cur) {
1086+
if (UR.count(&BB))
1087+
continue;
1088+
for (auto &I : BB)
1089+
if (auto CB = dyn_cast<CallBase>(&I)) {
1090+
if (isNoEscapingAllocation(CB))
1091+
continue;
1092+
if (isAllocationCall(CB, gutils->TLI)) {
1093+
seenAllocation = true;
1094+
goto finish;
1095+
}
1096+
if (auto F = getFunctionFromCall(CB)) {
1097+
todo.push_back(F);
1098+
continue;
1099+
}
1100+
// Conservatively assume indirect functions allocate.
1101+
seenAllocation = true;
1102+
goto finish;
1103+
}
1104+
}
1105+
finish:;
1106+
}
1107+
if (!seenAllocation)
1108+
escapingNeededAllocation = false;
1109+
}
1110+
if (escapingNeededAllocation)
1111+
useConstantFallback = false;
1112+
}
1113+
return useConstantFallback;
1114+
}

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,11 @@ forEachDifferentialUser(llvm::function_ref<void(llvm::Value *)> f,
538538
}
539539
}
540540
}
541+
542+
//! Return whether or not this is a constant and should use reverse pass
543+
bool callShouldNotUseDerivative(const GradientUtils *gutils,
544+
llvm::CallBase &orig);
545+
541546
}; // namespace DifferentialUseAnalysis
542547

543548
#endif

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5922,7 +5922,7 @@ bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
59225922
if (dt != BaseType::Anything && dt != BaseType::Unknown)
59235923
return dt.isFloat();
59245924

5925-
if (val->getType()->isTokenTy())
5925+
if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
59265926
return false;
59275927
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
59285928
SmallSet<size_t, 8> offs;
@@ -5958,7 +5958,7 @@ bool TypeResults::anyPointer(Value *val) const {
59585958
auto dt = q[{-1}];
59595959
if (dt != BaseType::Anything && dt != BaseType::Unknown)
59605960
return dt == BaseType::Pointer;
5961-
if (val->getType()->isTokenTy())
5961+
if (val->getType()->isTokenTy() || val->getType()->isVoidTy())
59625962
return false;
59635963

59645964
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();

0 commit comments

Comments
 (0)