Skip to content

Commit 02bb48b

Browse files
authored
Add asserting vh (#1238)
* Add asserting vh * fix * fix * Add capi gutils erase/replace
1 parent 949cb65 commit 02bb48b

File tree

9 files changed

+47
-27
lines changed

9 files changed

+47
-27
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5378,7 +5378,7 @@ class AdjointGenerator
53785378
gutils->getNewFromOriginal(call.getDebugLoc()));
53795379
BuilderZ.SetInsertPoint(
53805380
gutils->getNewFromOriginal(&call)->getNextNode());
5381-
gutils->getNewFromOriginal(&call)->eraseFromParent();
5381+
gutils->erase(gutils->getNewFromOriginal(&call));
53825382
} else {
53835383
assert(0 && "unhandled unknown outline");
53845384
}

enzyme/Enzyme/CApi.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,15 @@ void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) {
304304
return (void *)&G->TR.analyzer;
305305
}
306306

307+
void EnzymeGradientUtilsErase(GradientUtils *G, LLVMValueRef I) {
308+
return G->erase(cast<Instruction>(unwrap(I)));
309+
}
310+
311+
void EnzymeGradientUtilsReplaceAWithB(GradientUtils *G, LLVMValueRef A,
312+
LLVMValueRef B) {
313+
return G->replaceAWithB(unwrap(A), unwrap(B));
314+
}
315+
307316
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
308317
CustomShadowFree FHandle) {
309318
shadowHandlers[std::string(Name)] =

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ DiffeGradientUtils::DiffeGradientUtils(
5454
ValueToValueMapTy &invertedPointers_,
5555
const SmallPtrSetImpl<Value *> &constantvalues_,
5656
const SmallPtrSetImpl<Value *> &returnvals_, DIFFE_TYPE ActiveReturn,
57-
ArrayRef<DIFFE_TYPE> constant_values, ValueToValueMapTy &origToNew_,
57+
ArrayRef<DIFFE_TYPE> constant_values,
58+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
5859
DerivativeMode mode, unsigned width, bool omp)
5960
: GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
6061
constantvalues_, returnvals_, ActiveReturn, constant_values,
@@ -90,7 +91,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
9091
SmallPtrSet<Instruction *, 4> constants;
9192
SmallPtrSet<Instruction *, 20> nonconstant;
9293
SmallPtrSet<Value *, 2> returnvals;
93-
ValueToValueMapTy originalToNew;
94+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> originalToNew;
9495

9596
SmallPtrSet<Value *, 4> constant_values;
9697
SmallPtrSet<Value *, 4> nonconstant_values;

enzyme/Enzyme/DiffeGradientUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ class DiffeGradientUtils final : public GradientUtils {
6767
const llvm::SmallPtrSetImpl<llvm::Value *> &constantvalues_,
6868
const llvm::SmallPtrSetImpl<llvm::Value *> &returnvals_,
6969
DIFFE_TYPE ActiveReturn, llvm::ArrayRef<DIFFE_TYPE> constant_values,
70-
llvm::ValueToValueMapTy &origToNew_, DerivativeMode mode, unsigned width,
71-
bool omp);
70+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
71+
DerivativeMode mode, unsigned width, bool omp);
7272

7373
public:
7474
/// Whether to free memory in reverse pass or split forward.

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,9 +1499,11 @@ bool legalCombinedForwardReverse(
14991499
return;
15001500
}
15011501
// Do not try moving an instruction that modifies memory, if we already
1502-
// moved it
1502+
// moved it. We need the originalToNew check because we may have deleted
1503+
// the instruction, which wont require the failed to move.
15031504
if (!isa<StoreInst>(I) || unnecessaryInstructions.count(I) == 0)
15041505
if (I->mayReadOrWriteMemory() &&
1506+
gutils->originalToNewFn.find(I) != gutils->originalToNewFn.end() &&
15051507
gutils->getNewFromOriginal(I)->getParent() !=
15061508
gutils->getNewFromOriginal(I->getParent())) {
15071509
legal = false;
@@ -2944,7 +2946,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
29442946
ei->replaceAllUsesWith(rep);
29452947
ei->eraseFromParent();
29462948
}
2947-
user->eraseFromParent();
2949+
if (user->getParent()->getParent() == gutils->newFunc)
2950+
gutils->erase(user);
2951+
else
2952+
user->eraseFromParent();
29482953
} else {
29492954
user->setCalledFunction(NewF);
29502955
}
@@ -4185,7 +4190,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
41854190
DerivativeMode::ReverseModeCombined);
41864191
}
41874192
if (newBB->getTerminator())
4188-
newBB->getTerminator()->eraseFromParent();
4193+
gutils->erase(newBB->getTerminator());
41894194
IRBuilder<> builder(newBB);
41904195
builder.CreateUnreachable();
41914196
continue;

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,7 +2105,8 @@ Function *PreProcessCache::CloneFunctionWithReturns(
21052105
ValueToValueMapTy &ptrInputs, ArrayRef<DIFFE_TYPE> constant_args,
21062106
SmallPtrSetImpl<Value *> &constants, SmallPtrSetImpl<Value *> &nonconstant,
21072107
SmallPtrSetImpl<Value *> &returnvals, ReturnType returnValue,
2108-
DIFFE_TYPE returnType, Twine name, ValueToValueMapTy *VMapO,
2108+
DIFFE_TYPE returnType, Twine name,
2109+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
21092110
bool diffeReturnArg, llvm::Type *additionalArg) {
21102111
assert(!F->empty());
21112112
F = preprocessForClone(F, mode);
@@ -2169,7 +2170,9 @@ Function *PreProcessCache::CloneFunctionWithReturns(
21692170
#endif
21702171
CloneOrigin[NewF] = F;
21712172
if (VMapO) {
2172-
VMapO->insert(VMap.begin(), VMap.end());
2173+
for (const auto &data : VMap)
2174+
VMapO->insert(std::pair<const llvm::Value *, AssertingReplacingVH>(
2175+
data.first, (llvm::Value *)data.second));
21732176
VMapO->getMDMap() = VMap.getMDMap();
21742177
}
21752178

enzyme/Enzyme/FunctionUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class PreProcessCache {
8484
llvm::SmallPtrSetImpl<llvm::Value *> &constants,
8585
llvm::SmallPtrSetImpl<llvm::Value *> &nonconstant,
8686
llvm::SmallPtrSetImpl<llvm::Value *> &returnvals, ReturnType returnValue,
87-
DIFFE_TYPE returnType, llvm::Twine name, llvm::ValueToValueMapTy *VMapO,
87+
DIFFE_TYPE returnType, llvm::Twine name,
88+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> *VMapO,
8889
bool diffeReturnArg, llvm::Type *additionalArg = nullptr);
8990

9091
void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,15 @@ static bool isPotentialLastLoopValue(llvm::Value *val,
158158
return false;
159159
}
160160

161-
GradientUtils::GradientUtils(EnzymeLogic &Logic, Function *newFunc_,
162-
Function *oldFunc_, TargetLibraryInfo &TLI_,
163-
TypeAnalysis &TA_, TypeResults TR_,
164-
ValueToValueMapTy &invertedPointers_,
165-
const SmallPtrSetImpl<Value *> &constantvalues_,
166-
const SmallPtrSetImpl<Value *> &activevals_,
167-
DIFFE_TYPE ReturnActivity,
168-
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
169-
ValueToValueMapTy &originalToNewFn_,
170-
DerivativeMode mode, unsigned width, bool omp)
161+
GradientUtils::GradientUtils(
162+
EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_,
163+
TargetLibraryInfo &TLI_, TypeAnalysis &TA_, TypeResults TR_,
164+
ValueToValueMapTy &invertedPointers_,
165+
const SmallPtrSetImpl<Value *> &constantvalues_,
166+
const SmallPtrSetImpl<Value *> &activevals_, DIFFE_TYPE ReturnActivity,
167+
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
168+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &originalToNewFn_,
169+
DerivativeMode mode, unsigned width, bool omp)
171170
: CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
172171
invertedPointers(),
173172
OrigDT(Logic.PPC.FAM.getResult<llvm::DominatorTreeAnalysis>(*oldFunc_)),
@@ -592,7 +591,8 @@ BasicBlock *GradientUtils::getOriginalFromNew(const BasicBlock *newinst) const {
592591
assert(newinst->getParent() == newFunc);
593592
auto found = newToOriginalFn.find(newinst);
594593
assert(found != newToOriginalFn.end());
595-
return cast<BasicBlock>(found->second);
594+
Value *res = found->second;
595+
return cast<BasicBlock>(res);
596596
}
597597

598598
Value *GradientUtils::isOriginal(const Value *newinst) const {
@@ -4327,7 +4327,7 @@ GradientUtils *GradientUtils::CreateFromClone(
43274327
SmallPtrSet<Instruction *, 4> constants;
43284328
SmallPtrSet<Instruction *, 20> nonconstant;
43294329
SmallPtrSet<Value *, 2> returnvals;
4330-
ValueToValueMapTy originalToNew;
4330+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> originalToNew;
43314331

43324332
SmallPtrSet<Value *, 4> constant_values;
43334333
SmallPtrSet<Value *, 4> nonconstant_values;

enzyme/Enzyme/GradientUtils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ class GradientUtils : public CacheUtility {
169169
llvm::SmallPtrSet<llvm::Instruction *, 4> TapesToPreventRecomputation;
170170

171171
llvm::ValueMap<llvm::PHINode *, llvm::WeakTrackingVH> fictiousPHIs;
172-
llvm::ValueToValueMapTy originalToNewFn;
173-
llvm::ValueToValueMapTy newToOriginalFn;
172+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> originalToNewFn;
173+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> newToOriginalFn;
174174
llvm::SmallVector<llvm::CallInst *, 4> originalCalls;
175175

176176
llvm::SmallPtrSet<llvm::Instruction *, 4> unnecessaryIntermediates;
@@ -372,8 +372,9 @@ class GradientUtils : public CacheUtility {
372372
const llvm::SmallPtrSetImpl<llvm::Value *> &activevals_,
373373
DIFFE_TYPE ReturnActivity,
374374
llvm::ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
375-
llvm::ValueToValueMapTy &originalToNewFn_, DerivativeMode mode,
376-
unsigned width, bool omp);
375+
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH>
376+
&originalToNewFn_,
377+
DerivativeMode mode, unsigned width, bool omp);
377378

378379
public:
379380
DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const;

0 commit comments

Comments
 (0)