Skip to content

Commit 2250522

Browse files
authored
Add gsl_sf_legendre_array_e (#1869)
* Add gsl_sf_legendre_array_e * add TT * fixup
1 parent 5458b5f commit 2250522

File tree

3 files changed

+205
-0
lines changed

3 files changed

+205
-0
lines changed

enzyme/Enzyme/CallDerivatives.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,6 +2474,143 @@ bool AdjointGenerator::handleKnownCallDerivatives(
24742474
return true;
24752475
}
24762476

2477+
/*
2478+
* int gsl_sf_legendre_array_e(const gsl_sf_legendre_t norm,
2479+
const size_t lmax,
2480+
const double x,
2481+
const double csphase,
2482+
double result_array[]);
2483+
*/
2484+
// d L(n, x) / dx = L(n,x) * x * (n-1) + 1
2485+
if (funcName == "gsl_sf_legendre_array_e") {
2486+
if (gutils->isConstantValue(call.getArgOperand(4))) {
2487+
eraseIfUnused(call);
2488+
return true;
2489+
}
2490+
if (Mode == DerivativeMode::ReverseModePrimal) {
2491+
eraseIfUnused(call);
2492+
return true;
2493+
}
2494+
if (Mode == DerivativeMode::ReverseModeCombined ||
2495+
Mode == DerivativeMode::ReverseModeGradient) {
2496+
IRBuilder<> Builder2(&call);
2497+
getReverseBuilder(Builder2);
2498+
ValueType BundleTypes[5] = {ValueType::None, ValueType::None,
2499+
ValueType::None, ValueType::None,
2500+
ValueType::Shadow};
2501+
auto Defs = gutils->getInvertedBundles(&call, BundleTypes, Builder2,
2502+
/*lookup*/ true);
2503+
2504+
Type *types[6] = {
2505+
call.getOperand(0)->getType(), call.getOperand(1)->getType(),
2506+
call.getOperand(2)->getType(), call.getOperand(3)->getType(),
2507+
call.getOperand(4)->getType(), call.getOperand(4)->getType(),
2508+
};
2509+
FunctionType *FT = FunctionType::get(call.getType(), types, false);
2510+
auto F = called->getParent()->getOrInsertFunction(
2511+
"gsl_sf_legendre_deriv_array_e", FT);
2512+
2513+
llvm::Value *args[6] = {
2514+
gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(0)),
2515+
Builder2),
2516+
gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(1)),
2517+
Builder2),
2518+
gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(2)),
2519+
Builder2),
2520+
gutils->lookupM(gutils->getNewFromOriginal(call.getOperand(3)),
2521+
Builder2),
2522+
nullptr,
2523+
nullptr};
2524+
2525+
#if LLVM_VERSION_MAJOR >= 13
2526+
Type *stackTys[] = {getInt8PtrTy(Builder2.getContext())};
2527+
#else
2528+
ArrayRef<Type *> stackTys = {};
2529+
#endif
2530+
auto stack = Builder2.CreateIntrinsic(Intrinsic::stacksave,
2531+
ArrayRef<Type *>(stackTys),
2532+
ArrayRef<Value *>());
2533+
auto tmp = Builder2.CreateAlloca(types[2], args[1]);
2534+
auto dtmp = Builder2.CreateAlloca(types[2], args[1]);
2535+
Builder2.CreateLifetimeStart(tmp);
2536+
Builder2.CreateLifetimeStart(dtmp);
2537+
2538+
args[4] = Builder2.CreateBitCast(tmp, types[4]);
2539+
args[5] = Builder2.CreateBitCast(dtmp, types[5]);
2540+
2541+
Builder2.CreateCall(F, args, Defs);
2542+
Builder2.CreateLifetimeEnd(tmp);
2543+
2544+
BasicBlock *currentBlock = Builder2.GetInsertBlock();
2545+
2546+
BasicBlock *loopBlock = gutils->addReverseBlock(
2547+
currentBlock, currentBlock->getName() + "_loop");
2548+
BasicBlock *endBlock =
2549+
gutils->addReverseBlock(loopBlock, currentBlock->getName() + "_end",
2550+
/*fork*/ true, /*push*/ false);
2551+
2552+
Builder2.CreateCondBr(
2553+
Builder2.CreateICmpEQ(args[1], Constant::getNullValue(types[1])),
2554+
endBlock, loopBlock);
2555+
Builder2.SetInsertPoint(loopBlock);
2556+
2557+
auto idx = Builder2.CreatePHI(types[1], 2);
2558+
idx->addIncoming(ConstantInt::get(types[1], 0, false), currentBlock);
2559+
2560+
auto acc_idx = Builder2.CreatePHI(types[2], 2);
2561+
2562+
Value *inc = Builder2.CreateAdd(
2563+
idx, ConstantInt::get(types[1], 1, false), "", true, true);
2564+
idx->addIncoming(inc, loopBlock);
2565+
acc_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2566+
2567+
Value *idxs[] = {idx};
2568+
Value *dtmp_idx = Builder2.CreateInBoundsGEP(types[2], dtmp, idxs);
2569+
Value *d_req = Builder2.CreateInBoundsGEP(
2570+
types[2],
2571+
Builder2.CreatePointerCast(
2572+
gutils->invertPointerM(call.getOperand(4), Builder2),
2573+
PointerType::getUnqual(types[2])),
2574+
idxs);
2575+
2576+
auto acc = Builder2.CreateFAdd(
2577+
acc_idx,
2578+
Builder2.CreateFMul(Builder2.CreateLoad(types[2], dtmp_idx),
2579+
Builder2.CreateLoad(types[2], d_req)));
2580+
Builder2.CreateStore(Constant::getNullValue(types[2]), d_req);
2581+
2582+
acc_idx->addIncoming(acc, loopBlock);
2583+
2584+
Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, args[1]), endBlock,
2585+
loopBlock);
2586+
2587+
Builder2.SetInsertPoint(endBlock);
2588+
{
2589+
auto found = gutils->reverseBlockToPrimal.find(endBlock);
2590+
assert(found != gutils->reverseBlockToPrimal.end());
2591+
SmallVector<BasicBlock *, 4> &vec =
2592+
gutils->reverseBlocks[found->second];
2593+
assert(vec.size());
2594+
vec.push_back(endBlock);
2595+
}
2596+
2597+
auto fin_idx = Builder2.CreatePHI(types[2], 2);
2598+
fin_idx->addIncoming(Constant::getNullValue(types[2]), currentBlock);
2599+
fin_idx->addIncoming(acc, loopBlock);
2600+
2601+
Builder2.CreateLifetimeEnd(dtmp);
2602+
2603+
Builder2.CreateIntrinsic(Intrinsic::stackrestore,
2604+
ArrayRef<Type *>(stackTys),
2605+
ArrayRef<Value *>(stack));
2606+
2607+
((DiffeGradientUtils *)gutils)
2608+
->addToDiffe(call.getOperand(2), fin_idx, Builder2, types[2]);
2609+
2610+
return true;
2611+
}
2612+
}
2613+
24772614
// Functions that only modify pointers and don't allocate memory,
24782615
// needs to be run on shadow in primal
24792616
if (funcName == "_ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_"

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5333,6 +5333,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
53335333
TypeTree(BaseType::Integer).Only(-1, &call), &call);
53345334
return;
53355335
}
5336+
if (funcName == "gsl_sf_legendre_array_e") {
5337+
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
5338+
return;
5339+
}
53365340

53375341
// CONSIDER(__lgamma_r_finite)
53385342

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
3+
4+
declare dso_local i32 @gsl_sf_legendre_array_e(i32, i32, double, double, double*) local_unnamed_addr #1
5+
6+
7+
; Function Attrs: noinline nounwind readnone uwtable
8+
define dso_local void @tester(i32 %a0, i32 %a1, double %x, double %a3, double* %a4) {
9+
entry:
10+
%c = call i32 @gsl_sf_legendre_array_e(i32 %a0, i32 %a1, double %x, double %a3, double* %a4)
11+
ret void
12+
}
13+
14+
define double @test_derivative(double %x, double %y) {
15+
entry:
16+
%0 = tail call double (...) @__enzyme_autodiff(void (i32, i32, double, double, double*)* @tester, i32 0, i32 10, double %x, metadata !"enzyme_const", double %y, double* null, double* null)
17+
ret double %0
18+
}
19+
20+
; Function Attrs: nounwind readnone speculatable
21+
declare double @llvm.pow.f64(double, double)
22+
23+
; Function Attrs: nounwind
24+
declare double @__enzyme_autodiff(...)
25+
26+
; CHECK: define internal { double } @diffetester(i32 %a0, i32 %a1, double %x, double %a3, double* %a4, double* %"a4'")
27+
; CHECK-NEXT: entry:
28+
; CHECK-NEXT: %c = call i32 @gsl_sf_legendre_array_e(i32 %a0, i32 %a1, double %x, double %a3, double* %a4)
29+
; CHECK-NEXT: %[[l0:.+]] = call i8* @llvm.stacksave
30+
; CHECK-NEXT: %[[i0:.+]] = alloca double, i32 %a1, align 8
31+
; CHECK-NEXT: %[[i1:.+]] = alloca double, i32 %a1, align 8
32+
; CHECK-NEXT: %[[l3:.+]] = bitcast double* %[[i0]] to i8*
33+
; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 -1, i8* %[[l3]])
34+
; CHECK-NEXT: %[[l4:.+]] = bitcast double* %[[i1]] to i8*
35+
; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 -1, i8* %[[l4]])
36+
; CHECK-NEXT: %[[i2:.+]] = call i32 @gsl_sf_legendre_deriv_array_e(i32 %a0, i32 %a1, double %x, double %a3, double* %[[i0]], double* %[[i1]])
37+
; CHECK-NEXT: %[[l6:.+]] = bitcast double* %[[i0]] to i8*
38+
; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 -1, i8* %[[l6]])
39+
; CHECK-NEXT: %[[i3:.+]] = icmp eq i32 %a1, 0
40+
; CHECK-NEXT: br i1 %[[i3]], label %invertentry_end, label %invertentry_loop
41+
42+
; CHECK: invertentry_loop:
43+
; CHECK-NEXT: %[[i4:.+]] = phi i32 [ 0, %entry ], [ %[[i5:.+]], %invertentry_loop ]
44+
; CHECK-NEXT: %[[p5:.+]] = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %[[p12:.+]], %invertentry_loop ]
45+
; CHECK-NEXT: %[[i5]] = add nuw nsw i32 %[[i4]], 1
46+
; CHECK-NEXT: %[[i6:.+]] = getelementptr inbounds double, double* %[[i1]], i32 %[[i4]]
47+
; CHECK-NEXT: %[[i7:.+]] = getelementptr inbounds double, double* %"a4'", i32 %[[i4]]
48+
; CHECK-NEXT: %[[i8:.+]] = load double, double* %[[i7]], align 8
49+
; CHECK-NEXT: %[[i9:.+]] = load double, double* %[[i6]], align 8
50+
; CHECK-NEXT: %[[i10:.+]] = fmul fast double %[[i9]], %[[i8]]
51+
; CHECK-NEXT: %[[p12]] = fadd fast double %[[p5]], %[[i10]]
52+
; CHECK-NEXT: store double 0.000000e+00, double* %[[i7]], align 8
53+
; CHECK-NEXT: %[[c17:.+]] = icmp eq i32 %[[i5]], %a1
54+
; CHECK-NEXT: br i1 %[[c17]], label %invertentry_end, label %invertentry_loop
55+
56+
; CHECK: invertentry_end:
57+
; CHECK-NEXT: %[[res:.+]] = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %[[p12]], %invertentry_loop ]
58+
; CHECK-NEXT: %[[l19:.+]] = bitcast double* %[[i1]] to i8*
59+
; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 -1, i8* %[[l19]])
60+
; CHECK-NEXT: call void @llvm.stackrestore
61+
; CHECK-NEXT: %[[i11:.+]] = insertvalue { double } {{(undef|poison)}}, double %[[res]], 0
62+
; CHECK-NEXT: ret { double } %[[i11]]
63+
; CHECK-NEXT: }
64+

0 commit comments

Comments
 (0)