@@ -541,11 +541,11 @@ void EnzymeGradientUtilsSubTransferHelper(
541541}
542542
543543LLVMValueRef EnzymeCreateForwardDiff (
544- EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType ,
545- CDIFFE_TYPE *constant_args, size_t constant_args_size ,
546- EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode ,
547- uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg ,
548- CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
544+ EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip ,
545+ LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args ,
546+ size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
547+ CDerivativeMode mode, uint8_t freeMemory, unsigned width,
548+ LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
549549 size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) {
550550 SmallVector<DIFFE_TYPE, 4 > nconstant_args ((DIFFE_TYPE *)constant_args,
551551 (DIFFE_TYPE *)constant_args +
@@ -556,16 +556,18 @@ LLVMValueRef EnzymeCreateForwardDiff(
556556 overwritten_args.push_back (_overwritten_args[i]);
557557 }
558558 return wrap (eunwrap (Logic).CreateForwardDiff (
559+ RequestContext (cast_or_null<Instruction>(unwrap (request_req)),
560+ unwrap (request_ip)),
559561 cast<Function>(unwrap (todiff)), (DIFFE_TYPE)retType, nconstant_args,
560562 eunwrap (TA), returnValue, (DerivativeMode)mode, freeMemory, width,
561563 unwrap (additionalArg), eunwrap (typeInfo, cast<Function>(unwrap (todiff))),
562564 overwritten_args, eunwrap (augmented)));
563565}
564566LLVMValueRef EnzymeCreatePrimalAndGradient (
565- EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType ,
566- CDIFFE_TYPE *constant_args, size_t constant_args_size ,
567- EnzymeTypeAnalysisRef TA, uint8_t returnValue , uint8_t dretUsed ,
568- CDerivativeMode mode, unsigned width, uint8_t freeMemory,
567+ EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip ,
568+ LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args ,
569+ size_t constant_args_size, EnzymeTypeAnalysisRef TA , uint8_t returnValue ,
570+ uint8_t dretUsed, CDerivativeMode mode, unsigned width, uint8_t freeMemory,
569571 LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
570572 uint8_t *_overwritten_args, size_t overwritten_args_size,
571573 EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd) {
@@ -578,6 +580,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
578580 overwritten_args.push_back (_overwritten_args[i]);
579581 }
580582 return wrap (eunwrap (Logic).CreatePrimalAndGradient (
583+ RequestContext (cast<Instruction>(unwrap (request_req)),
584+ unwrap (request_ip)),
581585 (ReverseCacheKey){
582586 .todiff = cast<Function>(unwrap (todiff)),
583587 .retType = (DIFFE_TYPE)retType,
@@ -596,10 +600,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
596600 eunwrap (TA), eunwrap (augmented)));
597601}
598602EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal (
599- EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType ,
600- CDIFFE_TYPE *constant_args, size_t constant_args_size ,
601- EnzymeTypeAnalysisRef TA, uint8_t returnUsed , uint8_t shadowReturnUsed ,
602- CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
603+ EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip ,
604+ LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args ,
605+ size_t constant_args_size, EnzymeTypeAnalysisRef TA , uint8_t returnUsed ,
606+ uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
603607 size_t overwritten_args_size, uint8_t forceAnonymousTape, unsigned width,
604608 uint8_t AtomicAdd) {
605609
@@ -612,14 +616,31 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
612616 overwritten_args.push_back (_overwritten_args[i]);
613617 }
614618 return ewrap (eunwrap (Logic).CreateAugmentedPrimal (
619+ RequestContext (cast_or_null<Instruction>(unwrap (request_req)),
620+ unwrap (request_ip)),
615621 cast<Function>(unwrap (todiff)), (DIFFE_TYPE)retType, nconstant_args,
616622 eunwrap (TA), returnUsed, shadowReturnUsed,
617623 eunwrap (typeInfo, cast<Function>(unwrap (todiff))), overwritten_args,
618624 forceAnonymousTape, width, AtomicAdd));
619625}
620626
627+ LLVMValueRef EnzymeCreateBatch (EnzymeLogicRef Logic, LLVMValueRef request_req,
628+ LLVMBuilderRef request_ip, LLVMValueRef tobatch,
629+ unsigned width, CBATCH_TYPE *arg_types,
630+ size_t arg_types_size, CBATCH_TYPE retType) {
631+
632+ return wrap (eunwrap (Logic).CreateBatch (
633+ RequestContext (cast_or_null<Instruction>(unwrap (request_req)),
634+ unwrap (request_ip)),
635+ cast<Function>(unwrap (tobatch)), width,
636+ ArrayRef<BATCH_TYPE>((BATCH_TYPE *)arg_types,
637+ (BATCH_TYPE *)arg_types + arg_types_size),
638+ (BATCH_TYPE)retType));
639+ }
640+
621641LLVMValueRef EnzymeCreateTrace (
622- EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions,
642+ EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
643+ LLVMValueRef totrace, LLVMValueRef *sample_functions,
623644 size_t sample_functions_size, LLVMValueRef *observe_functions,
624645 size_t observe_functions_size, const char *active_random_variables[],
625646 size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff,
@@ -641,6 +662,8 @@ LLVMValueRef EnzymeCreateTrace(
641662 }
642663
643664 return wrap (eunwrap (Logic).CreateTrace (
665+ RequestContext (cast_or_null<Instruction>(unwrap (request_req)),
666+ unwrap (request_ip)),
644667 cast<Function>(unwrap (totrace)), SampleFunctions, ObserveFunctions,
645668 ActiveRandomVariables, (ProbProgMode)mode, (bool )autodiff,
646669 eunwrap (interface)));
0 commit comments