Skip to content

Commit a18e093

Browse files
committed
Additional julia updates
1 parent 0b2624f commit a18e093

File tree

6 files changed

+104
-26
lines changed

6 files changed

+104
-26
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
//===----------------------------------------------------------------------===//
2424
#include "CApi.h"
2525
#include "EnzymeLogic.h"
26+
#include "LibraryFuncs.h"
2627
#include "SCEV/TargetLibraryInfo.h"
2728

2829
#include "llvm/ADT/Triple.h"
@@ -184,6 +185,23 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) {
184185
delete TA;
185186
}
186187

188+
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
189+
CustomShadowFree FHandle) {
190+
shadowHandlers[std::string(Name)] =
191+
[=](IRBuilder<> &B, CallInst *CI,
192+
ArrayRef<Value *> Args) -> llvm::Value * {
193+
SmallVector<LLVMValueRef, 3> refs;
194+
for (auto a : Args)
195+
refs.push_back(wrap(a));
196+
return unwrap(AHandle(wrap(&B), wrap(CI), Args.size(), refs.data()));
197+
};
198+
shadowErasers[std::string(Name)] = [=](IRBuilder<> &B, Value *ToFree,
199+
Function *AllocF) -> llvm::CallInst * {
200+
return cast_or_null<CallInst>(
201+
unwrap(FHandle(wrap(&B), wrap(ToFree), wrap(AllocF))));
202+
};
203+
}
204+
187205
LLVMValueRef EnzymeCreatePrimalAndGradient(
188206
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
189207
size_t constant_args_size, EnzymeTypeAnalysisRef TA,

enzyme/Enzyme/CApi.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ LLVMValueRef
145145
EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret);
146146
LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret);
147147

148+
typedef LLVMValueRef (*CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef,
149+
size_t /*numArgs*/, LLVMValueRef *);
150+
typedef LLVMValueRef (*CustomShadowFree)(LLVMBuilderRef, LLVMValueRef,
151+
LLVMValueRef);
152+
153+
void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
154+
CustomShadowFree FHandle);
155+
148156
#ifdef __cplusplus
149157
}
150158
#endif

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@
4545

4646
#include <algorithm>
4747

48+
std::map<std::string, std::function<llvm::Value *(IRBuilder<> &, CallInst *,
49+
ArrayRef<Value *>)>>
50+
shadowHandlers;
51+
std::map<std::string,
52+
std::function<llvm::CallInst *(IRBuilder<> &, Value *, Function *)>>
53+
shadowErasers;
54+
4855
llvm::cl::opt<bool>
4956
EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden,
5057
cl::desc("Use new cache decision algorithm"));

enzyme/Enzyme/GradientUtils.h

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ using namespace llvm;
7373

7474
enum class DerivativeMode { Forward, Reverse, Both };
7575

76+
#include "llvm-c/Core.h"
77+
78+
extern std::map<std::string, std::function<llvm::Value *(
79+
IRBuilder<> &, CallInst *, ArrayRef<Value *>)>>
80+
shadowHandlers;
81+
7682
static inline std::string to_string(DerivativeMode mode) {
7783
switch (mode) {
7884
case DerivativeMode::Forward:
@@ -525,6 +531,22 @@ class GradientUtils : public CacheUtility {
525531
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
526532
args.push_back(getNewFromOriginal(orig->getArgOperand(i)));
527533
}
534+
535+
if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) !=
536+
shadowHandlers.end()) {
537+
Value *anti = shadowHandlers[orig->getCalledFunction()->getName().str()](
538+
bb, orig, args);
539+
invertedPointers[orig] = anti;
540+
// assert(placeholder != anti);
541+
bb.SetInsertPoint(placeholder->getNextNode());
542+
replaceAWithB(placeholder, anti);
543+
erase(placeholder);
544+
545+
anti = cacheForReverse(bb, anti, idx);
546+
invertedPointers[orig] = anti;
547+
return anti;
548+
}
549+
528550
Value *anti =
529551
bb.CreateCall(orig->getCalledFunction(), args, orig->getName() + "'mi");
530552
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
@@ -575,16 +597,22 @@ class GradientUtils : public CacheUtility {
575597
*orig);
576598
}
577599
}
578-
auto dst_arg = bb.CreateBitCast(
579-
anti, Type::getInt8PtrTy(orig->getContext(),
580-
anti->getType()->getPointerAddressSpace()));
600+
601+
Value *dst_arg = anti;
602+
603+
dst_arg = bb.CreateBitCast(
604+
dst_arg,
605+
Type::getInt8PtrTy(orig->getContext(),
606+
anti->getType()->getPointerAddressSpace()));
607+
581608
auto val_arg = ConstantInt::get(Type::getInt8Ty(orig->getContext()), 0);
582609
Value *size;
583610
// todo check if this memset is legal and if a write barrier is needed
584-
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj")
611+
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") {
585612
size = args[1];
586-
else
613+
} else {
587614
size = args[0];
615+
}
588616
auto len_arg =
589617
bb.CreateZExtOrTrunc(size, Type::getInt64Ty(orig->getContext()));
590618
auto volatile_arg = ConstantInt::getFalse(orig->getContext());

enzyme/Enzyme/LibraryFuncs.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@
2828
#include "llvm/IR/IRBuilder.h"
2929
#include "llvm/IR/Instructions.h"
3030

31+
extern std::map<std::string, std::function<llvm::Value *(
32+
llvm::IRBuilder<> &, llvm::CallInst *,
33+
llvm::ArrayRef<llvm::Value *>)>>
34+
shadowHandlers;
35+
extern std::map<std::string,
36+
std::function<llvm::CallInst *(
37+
llvm::IRBuilder<> &, llvm::Value *, llvm::Function *)>>
38+
shadowErasers;
39+
3140
/// Return whether a given function is a known C/C++ memory allocation function
3241
/// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp
3342
static inline bool isAllocationFunction(const llvm::Function &F,
@@ -38,6 +47,9 @@ static inline bool isAllocationFunction(const llvm::Function &F,
3847
return true;
3948
if (F.getName() == "julia.gc_alloc_obj")
4049
return true;
50+
if (shadowHandlers.find(F.getName().str()) != shadowHandlers.end())
51+
return true;
52+
4153
using namespace llvm;
4254
llvm::LibFunc libfunc;
4355
if (!TLI.getLibFunc(F, libfunc))
@@ -193,8 +205,12 @@ freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree,
193205
allocationfn.getName() == "__rust_alloc_zeroed") {
194206
llvm_unreachable("todo - hook in rust allocation fns");
195207
}
196-
if (allocationfn.getName() == "julia.gc_alloc_obj") {
208+
if (allocationfn.getName() == "julia.gc_alloc_obj")
197209
return nullptr;
210+
211+
if (shadowErasers.find(allocationfn.getName().str()) != shadowErasers.end()) {
212+
return shadowErasers[allocationfn.getName().str()](builder, tofree,
213+
&allocationfn);
198214
}
199215

200216
llvm::LibFunc libfunc;

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,26 +2500,7 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
25002500
analyzeFuncTypes<__VA_ARGS__>(::fn, call, *this); \
25012501
return; \
25022502
}
2503-
// All these are always valid => no direction check
2504-
// CONSIDER(malloc)
2505-
// TODO consider handling other allocation functions integer inputs
2506-
if (isAllocationFunction(*ci, interprocedural.TLI)) {
2507-
size_t Idx = 0;
2508-
for (auto &Arg : ci->args()) {
2509-
if (Arg.getType()->isIntegerTy()) {
2510-
updateAnalysis(call.getOperand(Idx),
2511-
TypeTree(BaseType::Integer).Only(-1), &call);
2512-
}
2513-
Idx++;
2514-
}
2515-
assert(ci->getReturnType()->isPointerTy());
2516-
updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call);
2517-
return;
2518-
}
2519-
if (ci->getName().startswith("_ZN3std2io5stdio6_print") ||
2520-
ci->getName().startswith("_ZN4core3fmt")) {
2521-
return;
2522-
}
2503+
25232504
auto customrule = interprocedural.CustomRules.find(ci->getName().str());
25242505
if (customrule != interprocedural.CustomRules.end()) {
25252506
auto returnAnalysis = getAnalysis(&call);
@@ -2544,6 +2525,26 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
25442525
}
25452526
return;
25462527
}
2528+
// All these are always valid => no direction check
2529+
// CONSIDER(malloc)
2530+
// TODO consider handling other allocation functions integer inputs
2531+
if (isAllocationFunction(*ci, interprocedural.TLI)) {
2532+
size_t Idx = 0;
2533+
for (auto &Arg : ci->args()) {
2534+
if (Arg.getType()->isIntegerTy()) {
2535+
updateAnalysis(call.getOperand(Idx),
2536+
TypeTree(BaseType::Integer).Only(-1), &call);
2537+
}
2538+
Idx++;
2539+
}
2540+
assert(ci->getReturnType()->isPointerTy());
2541+
updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call);
2542+
return;
2543+
}
2544+
if (ci->getName().startswith("_ZN3std2io5stdio6_print") ||
2545+
ci->getName().startswith("_ZN4core3fmt")) {
2546+
return;
2547+
}
25472548
/// MPI
25482549
if (ci->getName() == "MPI_Init") {
25492550
TypeTree ptrint;

0 commit comments

Comments
 (0)