Skip to content

Commit eb09439

Browse files
authored
Nice error message for undifferentiable functions (#1451)
* Nice error message for undifferentiable functions * Don't clone if empty
1 parent 21dcb51 commit eb09439

File tree

17 files changed

+534
-228
lines changed

17 files changed

+534
-228
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3896,8 +3896,9 @@ class AdjointGenerator
38963896
Mode == DerivativeMode::ReverseModeCombined) {
38973897
if (called) {
38983898
subdata = &gutils->Logic.CreateAugmentedPrimal(
3899-
cast<Function>(called), subretType, argsInverted,
3900-
TR.analyzer.interprocedural, /*return is used*/ false,
3899+
RequestContext(&call, &BuilderZ), cast<Function>(called),
3900+
subretType, argsInverted, TR.analyzer.interprocedural,
3901+
/*return is used*/ false,
39013902
/*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false,
39023903
gutils->getWidth(),
39033904
/*AtomicAdd*/ true,
@@ -4096,6 +4097,7 @@ class AdjointGenerator
40964097
}
40974098

40984099
newcalled = gutils->Logic.CreatePrimalAndGradient(
4100+
RequestContext(&call, &Builder2),
40994101
(ReverseCacheKey){.todiff = cast<Function>(called),
41004102
.retType = subretType,
41014103
.constant_args = argsInverted,
@@ -6851,8 +6853,9 @@ class AdjointGenerator
68516853

68526854
if (called) {
68536855
newcalled = gutils->Logic.CreateForwardDiff(
6854-
cast<Function>(called), subretType, argsInverted,
6855-
TR.analyzer.interprocedural, /*returnValue*/ subretused, Mode,
6856+
RequestContext(&call, &BuilderZ), cast<Function>(called),
6857+
subretType, argsInverted, TR.analyzer.interprocedural,
6858+
/*returnValue*/ subretused, Mode,
68566859
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(),
68576860
tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args,
68586861
/*augmented*/ subdata);
@@ -7254,10 +7257,10 @@ class AdjointGenerator
72547257
if (Mode == DerivativeMode::ReverseModePrimal ||
72557258
Mode == DerivativeMode::ReverseModeCombined) {
72567259
subdata = &gutils->Logic.CreateAugmentedPrimal(
7257-
cast<Function>(called), subretType, argsInverted,
7258-
TR.analyzer.interprocedural, /*return is used*/ subretused,
7259-
shadowReturnUsed, nextTypeInfo, overwritten_args, false,
7260-
gutils->getWidth(), gutils->AtomicAdd);
7260+
RequestContext(&call, &BuilderZ), cast<Function>(called),
7261+
subretType, argsInverted, TR.analyzer.interprocedural,
7262+
/*return is used*/ subretused, shadowReturnUsed, nextTypeInfo,
7263+
overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd);
72617264
if (Mode == DerivativeMode::ReverseModePrimal) {
72627265
assert(augmentedReturn);
72637266
auto subaugmentations =
@@ -7639,6 +7642,7 @@ class AdjointGenerator
76397642
}
76407643

76417644
newcalled = gutils->Logic.CreatePrimalAndGradient(
7645+
RequestContext(&call, &Builder2),
76427646
(ReverseCacheKey){.todiff = cast<Function>(called),
76437647
.retType = subretType,
76447648
.constant_args = argsInverted,
@@ -10066,7 +10070,8 @@ class AdjointGenerator
1006610070
auto callval = call.getCalledOperand();
1006710071
if (!isa<Constant>(callval))
1006810072
callval = gutils->getNewFromOriginal(callval);
10069-
newCall->setCalledOperand(gutils->Logic.CreateNoFree(callval));
10073+
newCall->setCalledOperand(gutils->Logic.CreateNoFree(
10074+
RequestContext(&call, &BuilderZ), callval));
1007010075
}
1007110076
if (gutils->knownRecomputeHeuristic.find(&call) !=
1007210077
gutils->knownRecomputeHeuristic.end()) {

enzyme/Enzyme/CApi.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -541,11 +541,11 @@ void EnzymeGradientUtilsSubTransferHelper(
541541
}
542542

543543
LLVMValueRef EnzymeCreateForwardDiff(
544-
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
545-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
546-
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
547-
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
548-
CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
544+
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
545+
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
546+
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
547+
CDerivativeMode mode, uint8_t freeMemory, unsigned width,
548+
LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
549549
size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) {
550550
SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
551551
(DIFFE_TYPE *)constant_args +
@@ -556,16 +556,18 @@ LLVMValueRef EnzymeCreateForwardDiff(
556556
overwritten_args.push_back(_overwritten_args[i]);
557557
}
558558
return wrap(eunwrap(Logic).CreateForwardDiff(
559+
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
560+
unwrap(request_ip)),
559561
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
560562
eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width,
561563
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
562564
overwritten_args, eunwrap(augmented)));
563565
}
564566
LLVMValueRef EnzymeCreatePrimalAndGradient(
565-
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
566-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
567-
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
568-
CDerivativeMode mode, unsigned width, uint8_t freeMemory,
567+
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
568+
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
569+
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
570+
uint8_t dretUsed, CDerivativeMode mode, unsigned width, uint8_t freeMemory,
569571
LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
570572
uint8_t *_overwritten_args, size_t overwritten_args_size,
571573
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd) {
@@ -578,6 +580,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
578580
overwritten_args.push_back(_overwritten_args[i]);
579581
}
580582
return wrap(eunwrap(Logic).CreatePrimalAndGradient(
583+
RequestContext(cast<Instruction>(unwrap(request_req)),
584+
unwrap(request_ip)),
581585
(ReverseCacheKey){
582586
.todiff = cast<Function>(unwrap(todiff)),
583587
.retType = (DIFFE_TYPE)retType,
@@ -596,10 +600,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
596600
eunwrap(TA), eunwrap(augmented)));
597601
}
598602
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
599-
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
600-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
601-
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
602-
CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
603+
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
604+
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
605+
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed,
606+
uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
603607
size_t overwritten_args_size, uint8_t forceAnonymousTape, unsigned width,
604608
uint8_t AtomicAdd) {
605609

@@ -612,14 +616,31 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
612616
overwritten_args.push_back(_overwritten_args[i]);
613617
}
614618
return ewrap(eunwrap(Logic).CreateAugmentedPrimal(
619+
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
620+
unwrap(request_ip)),
615621
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
616622
eunwrap(TA), returnUsed, shadowReturnUsed,
617623
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), overwritten_args,
618624
forceAnonymousTape, width, AtomicAdd));
619625
}
620626

627+
LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req,
628+
LLVMBuilderRef request_ip, LLVMValueRef tobatch,
629+
unsigned width, CBATCH_TYPE *arg_types,
630+
size_t arg_types_size, CBATCH_TYPE retType) {
631+
632+
return wrap(eunwrap(Logic).CreateBatch(
633+
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
634+
unwrap(request_ip)),
635+
cast<Function>(unwrap(tobatch)), width,
636+
ArrayRef<BATCH_TYPE>((BATCH_TYPE *)arg_types,
637+
(BATCH_TYPE *)arg_types + arg_types_size),
638+
(BATCH_TYPE)retType));
639+
}
640+
621641
LLVMValueRef EnzymeCreateTrace(
622-
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions,
642+
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
643+
LLVMValueRef totrace, LLVMValueRef *sample_functions,
623644
size_t sample_functions_size, LLVMValueRef *observe_functions,
624645
size_t observe_functions_size, const char *active_random_variables[],
625646
size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
@@ -641,6 +662,8 @@ LLVMValueRef EnzymeCreateTrace(
641662
}
642663

643664
return wrap(eunwrap(Logic).CreateTrace(
665+
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
666+
unwrap(request_ip)),
644667
cast<Function>(unwrap(totrace)), SampleFunctions, ObserveFunctions,
645668
ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff,
646669
eunwrap(interface)));

enzyme/Enzyme/CApi.h

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ typedef enum {
119119
// but don't need the forward
120120
} CDIFFE_TYPE;
121121

122+
typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE;
123+
122124
typedef enum {
123125
DEM_ForwardMode = 0,
124126
DEM_ReverseModePrimal = 1,
@@ -132,40 +134,6 @@ typedef enum {
132134
DEM_Condition = 1,
133135
} CProbProgMode;
134136

135-
LLVMValueRef EnzymeCreateForwardDiff(
136-
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
137-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
138-
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
139-
uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg,
140-
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
141-
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented);
142-
143-
LLVMValueRef EnzymeCreatePrimalAndGradient(
144-
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
145-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
146-
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
147-
CDerivativeMode mode, unsigned width, uint8_t freeMemory,
148-
LLVMTypeRef additionalArg, uint8_t forceAnonymousTape,
149-
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
150-
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented,
151-
uint8_t AtomicAdd);
152-
153-
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
154-
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
155-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
156-
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
157-
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
158-
size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width,
159-
uint8_t AtomicAdd);
160-
161-
LLVMValueRef CreateTrace(
162-
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions,
163-
size_t sample_functions_size, LLVMValueRef *observe_functions,
164-
size_t observe_functions_size, LLVMValueRef *generative_functions,
165-
size_t generative_functions_size, const char *active_random_variables[],
166-
size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
167-
EnzymeTraceInterfaceRef interface);
168-
169137
typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
170138
CTypeTreeRef * /*args*/,
171139
struct IntList * /*knownValues*/,

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ DiffeGradientUtils::DiffeGradientUtils(
6262
: GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
6363
constantvalues_, returnvals_, ActiveReturn, constant_values,
6464
origToNew_, mode, width, omp) {
65+
if (oldFunc_->empty())
66+
return;
6567
assert(reverseBlocks.size() == 0);
6668
if (mode == DerivativeMode::ForwardMode ||
6769
mode == DerivativeMode::ForwardModeSplit) {
@@ -83,7 +85,6 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
8385
TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo,
8486
DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
8587
ReturnType returnValue, Type *additionalArg, bool omp) {
86-
assert(!todiff->empty());
8788
Function *oldFunc = todiff;
8889
assert(mode == DerivativeMode::ReverseModeGradient ||
8990
mode == DerivativeMode::ReverseModeCombined ||
@@ -149,7 +150,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
149150
}
150151

151152
TypeResults TR = TA.analyzeFunction(typeInfo);
152-
assert(TR.getFunction() == oldFunc);
153+
if (!oldFunc->empty())
154+
assert(TR.getFunction() == oldFunc);
153155

154156
auto res = new DiffeGradientUtils(Logic, newFunc, oldFunc, TLI, TA, TR,
155157
invertedPointers, constant_values,

enzyme/Enzyme/Enzyme.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,8 @@ class EnzymeBase {
13771377
? BATCH_TYPE::SCALAR
13781378
: BATCH_TYPE::VECTOR;
13791379

1380-
auto newFunc = Logic.CreateBatch(F, width, arg_types, ret_type);
1380+
auto newFunc = Logic.CreateBatch(RequestContext(CI, &Builder), F, width,
1381+
arg_types, ret_type);
13811382

13821383
if (!newFunc)
13831384
return false;
@@ -1432,6 +1433,7 @@ class EnzymeBase {
14321433
populate_overwritten_args(TA, fn, mode, overwritten_args);
14331434

14341435
IRBuilder Builder(CI);
1436+
RequestContext context(CI, &Builder);
14351437

14361438
// differentiate fn
14371439
Function *newFunc = nullptr;
@@ -1440,15 +1442,15 @@ class EnzymeBase {
14401442
switch (mode) {
14411443
case DerivativeMode::ForwardMode:
14421444
newFunc = Logic.CreateForwardDiff(
1443-
fn, retType, constants, TA,
1445+
context, fn, retType, constants, TA,
14441446
/*should return*/ primalReturn, mode, freeMemory, width,
14451447
/*addedType*/ nullptr, type_args, overwritten_args,
14461448
/*augmented*/ nullptr);
14471449
break;
14481450
case DerivativeMode::ForwardModeSplit: {
14491451
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
14501452
aug = &Logic.CreateAugmentedPrimal(
1451-
fn, retType, constants, TA,
1453+
context, fn, retType, constants, TA,
14521454
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
14531455
overwritten_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd);
14541456
auto &DL = fn->getParent()->getDataLayout();
@@ -1484,14 +1486,15 @@ class EnzymeBase {
14841486
tapeType = PointerType::getInt8PtrTy(fn->getContext());
14851487
}
14861488
newFunc = Logic.CreateForwardDiff(
1487-
fn, retType, constants, TA,
1489+
context, fn, retType, constants, TA,
14881490
/*should return*/ primalReturn, mode, freeMemory, width,
14891491
/*addedType*/ tapeType, type_args, overwritten_args, aug);
14901492
break;
14911493
}
14921494
case DerivativeMode::ReverseModeCombined:
14931495
assert(freeMemory);
14941496
newFunc = Logic.CreatePrimalAndGradient(
1497+
context,
14951498
(ReverseCacheKey){.todiff = fn,
14961499
.retType = retType,
14971500
.constant_args = constants,
@@ -1518,8 +1521,8 @@ class EnzymeBase {
15181521
bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
15191522
retType == DIFFE_TYPE::DUP_NONEED);
15201523
aug = &Logic.CreateAugmentedPrimal(
1521-
fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args,
1522-
overwritten_args, forceAnonymousTape, width,
1524+
context, fn, retType, constants, TA, returnUsed, shadowReturnUsed,
1525+
type_args, overwritten_args, forceAnonymousTape, width,
15231526
/*atomicAdd*/ AtomicAdd);
15241527
auto &DL = fn->getParent()->getDataLayout();
15251528
if (!forceAnonymousTape) {
@@ -1557,6 +1560,7 @@ class EnzymeBase {
15571560
newFunc = aug->fn;
15581561
else
15591562
newFunc = Logic.CreatePrimalAndGradient(
1563+
context,
15601564
(ReverseCacheKey){.todiff = fn,
15611565
.retType = retType,
15621566
.constant_args = constants,
@@ -1856,9 +1860,9 @@ class EnzymeBase {
18561860
constants.push_back(DIFFE_TYPE::CONSTANT);
18571861
}
18581862

1859-
auto newFunc = Logic.CreateTrace(F, sampleFunctions, observeFunctions,
1860-
opt->ActiveRandomVariables, mode, autodiff,
1861-
interface);
1863+
auto newFunc = Logic.CreateTrace(
1864+
RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions,
1865+
opt->ActiveRandomVariables, mode, autodiff, interface);
18621866

18631867
if (!autodiff) {
18641868
auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
@@ -2438,8 +2442,10 @@ class EnzymeBase {
24382442
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
24392443
Arch == Triple::amdgcn;
24402444

2445+
IRBuilder<> Builder(CI);
24412446
auto val = GradientUtils::GetOrCreateShadowConstant(
2442-
Logic, Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
2447+
RequestContext(CI, &Builder), Logic,
2448+
Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F), TA, fn,
24432449
pair.second, /*width*/ 1, AtomicAdd);
24442450
CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
24452451
CI->eraseFromParent();

0 commit comments

Comments
 (0)