@@ -2474,6 +2474,143 @@ bool AdjointGenerator::handleKnownCallDerivatives(
24742474 return true ;
24752475 }
24762476
2477+ /*
2478+ * int gsl_sf_legendre_array_e(const gsl_sf_legendre_t norm,
2479+ const size_t lmax,
2480+ const double x,
2481+ const double csphase,
2482+ double result_array[]);
2483+ */
2484+ // d L(n, x) / dx = L(n,x) * x * (n-1) + 1
// Custom reverse-mode handling for calls matched by name against the
// gsl_sf_legendre_array_e literal below. Operand indices used in this block
// follow the signature comment that precedes it: 0 = norm, 1 = lmax, 2 = x,
// 3 = csphase, 4 = result_array.
// Every path that handles the call returns true; any Mode other than the
// three tested below falls out of this `if` and continues with the code
// that follows it.
2485+ if (funcName == " gsl_sf_legendre_array_e" ) {
// If the output array (operand 4) is a constant value according to gutils,
// no adjoint is emitted: erase the call if unused and report it handled.
2486+ if (gutils->isConstantValue (call.getArgOperand (4 ))) {
2487+ eraseIfUnused (call);
2488+ return true ;
2489+ }
// In ReverseModePrimal this handler emits nothing extra either.
2490+ if (Mode == DerivativeMode::ReverseModePrimal) {
2491+ eraseIfUnused (call);
2492+ return true ;
2493+ }
2494+ if (Mode == DerivativeMode::ReverseModeCombined ||
2495+ Mode == DerivativeMode::ReverseModeGradient) {
// Builder2 is repositioned by getReverseBuilder; everything below is emitted
// through it.
2496+ IRBuilder<> Builder2 (&call);
2497+ getReverseBuilder (Builder2);
// Only the fifth argument (result_array) is marked Shadow for the inverted
// operand bundles; the other four are marked None.
2498+ ValueType BundleTypes[5 ] = {ValueType::None, ValueType::None,
2499+ ValueType::None, ValueType::None,
2500+ ValueType::Shadow};
2501+ auto Defs = gutils->getInvertedBundles (&call, BundleTypes, Builder2,
2502+ /* lookup*/ true );
2503+
// Parameter types of the helper to call: the original five parameter types
// plus a second copy of the array-pointer type (operand 4) for a sixth
// parameter.
2504+ Type *types[6 ] = {
2505+ call.getOperand (0 )->getType (), call.getOperand (1 )->getType (),
2506+ call.getOperand (2 )->getType (), call.getOperand (3 )->getType (),
2507+ call.getOperand (4 )->getType (), call.getOperand (4 )->getType (),
2508+ };
// Declare (or reuse) the helper in the callee's module, with the same return
// type as the original call.
2509+ FunctionType *FT = FunctionType::get (call.getType (), types, false );
2510+ auto F = called->getParent ()->getOrInsertFunction (
2511+ " gsl_sf_legendre_deriv_array_e" , FT);
2512+
// The first four arguments are the original operands, mapped to the new
// function and looked up for use in the reverse pass. Slots 4 and 5 are
// filled in below once the scratch buffers exist.
2513+ llvm::Value *args[6 ] = {
2514+ gutils->lookupM (gutils->getNewFromOriginal (call.getOperand (0 )),
2515+ Builder2),
2516+ gutils->lookupM (gutils->getNewFromOriginal (call.getOperand (1 )),
2517+ Builder2),
2518+ gutils->lookupM (gutils->getNewFromOriginal (call.getOperand (2 )),
2519+ Builder2),
2520+ gutils->lookupM (gutils->getNewFromOriginal (call.getOperand (3 )),
2521+ Builder2),
2522+ nullptr ,
2523+ nullptr };
2524+
// Overload type list for the stacksave/stackrestore intrinsics: an i8
// pointer type from LLVM 13 on, empty for older versions.
2525+ #if LLVM_VERSION_MAJOR >= 13
2526+ Type *stackTys[] = {getInt8PtrTy (Builder2.getContext ())};
2527+ #else
2528+ ArrayRef<Type *> stackTys = {};
2529+ #endif
// Save the stack pointer so the dynamically sized allocas below can be
// released by the stackrestore emitted at the end of this block.
2530+ auto stack = Builder2.CreateIntrinsic (Intrinsic::stacksave,
2531+ ArrayRef<Type *>(stackTys),
2532+ ArrayRef<Value *>());
// Two scratch arrays whose element type is that of operand 2 (x) and whose
// element count is the looked-up operand 1 (lmax).
// NOTE(review): whether an element count of exactly `lmax` matches what the
// callee writes is not visible here; confirm against the GSL documentation
// for the required array size.
2533+ auto tmp = Builder2.CreateAlloca (types[2 ], args[1 ]);
2534+ auto dtmp = Builder2.CreateAlloca (types[2 ], args[1 ]);
2535+ Builder2.CreateLifetimeStart (tmp);
2536+ Builder2.CreateLifetimeStart (dtmp);
2537+
// Cast the scratch buffers to the array-pointer parameter type and pass them
// as the fifth and sixth arguments (presumably the value and derivative
// output arrays of the helper; confirm against the GSL signature).
2538+ args[4 ] = Builder2.CreateBitCast (tmp, types[4 ]);
2539+ args[5 ] = Builder2.CreateBitCast (dtmp, types[5 ]);
2540+
2541+ Builder2.CreateCall (F, args, Defs);
// tmp is not read anywhere below, so its lifetime ends right after the call;
// only dtmp is consumed by the loop.
2542+ Builder2.CreateLifetimeEnd (tmp);
2543+
2544+ BasicBlock *currentBlock = Builder2.GetInsertBlock ();
2545+
// Create the loop body and the exit block as additional reverse blocks.
// endBlock is created with push = false and is appended to the reverse-block
// list by hand further down.
2546+ BasicBlock *loopBlock = gutils->addReverseBlock (
2547+ currentBlock, currentBlock->getName () + " _loop" );
2548+ BasicBlock *endBlock =
2549+ gutils->addReverseBlock (loopBlock, currentBlock->getName () + " _end" ,
2550+ /* fork*/ true , /* push*/ false );
2551+
// Skip the loop entirely when the element count (operand 1) is zero.
2552+ Builder2.CreateCondBr (
2553+ Builder2.CreateICmpEQ (args[1 ], Constant::getNullValue (types[1 ])),
2554+ endBlock, loopBlock);
2555+ Builder2.SetInsertPoint (loopBlock);
2556+
// idx: loop counter, starting at 0 when entered from currentBlock.
2557+ auto idx = Builder2.CreatePHI (types[1 ], 2 );
2558+ idx->addIncoming (ConstantInt::get (types[1 ], 0 , false ), currentBlock);
2559+
// acc_idx: running sum carried around the loop, starting from the null
// (zero) value of the element type.
2560+ auto acc_idx = Builder2.CreatePHI (types[2 ], 2 );
2561+
// inc = idx + 1; the two trailing `true` arguments are the builder's wrap
// flags (see IRBuilder::CreateAdd).
2562+ Value *inc = Builder2.CreateAdd (
2563+ idx, ConstantInt::get (types[1 ], 1 , false ), " " , true , true );
2564+ idx->addIncoming (inc, loopBlock);
2565+ acc_idx->addIncoming (Constant::getNullValue (types[2 ]), currentBlock);
2566+
// Address of element idx in dtmp, and of element idx in the shadow of
// result_array (the inverted pointer of operand 4, cast to a pointer to the
// element type).
2567+ Value *idxs[] = {idx};
2568+ Value *dtmp_idx = Builder2.CreateInBoundsGEP (types[2 ], dtmp, idxs);
2569+ Value *d_req = Builder2.CreateInBoundsGEP (
2570+ types[2 ],
2571+ Builder2.CreatePointerCast (
2572+ gutils->invertPointerM (call.getOperand (4 ), Builder2),
2573+ PointerType::getUnqual (types[2 ])),
2574+ idxs);
2575+
// acc = acc_idx + dtmp[idx] * shadow[idx]
2576+ auto acc = Builder2.CreateFAdd (
2577+ acc_idx,
2578+ Builder2.CreateFMul (Builder2.CreateLoad (types[2 ], dtmp_idx),
2579+ Builder2.CreateLoad (types[2 ], d_req)));
// Zero the shadow element after it has been read.
2580+ Builder2.CreateStore (Constant::getNullValue (types[2 ]), d_req);
2581+
2582+ acc_idx->addIncoming (acc, loopBlock);
2583+
// Leave the loop once the incremented counter reaches the element count.
2584+ Builder2.CreateCondBr (Builder2.CreateICmpEQ (inc, args[1 ]), endBlock,
2585+ loopBlock);
2586+
2587+ Builder2.SetInsertPoint (endBlock);
// Register endBlock in the reverse-block list of the primal block it maps
// to; this is the manual counterpart of the push = false used when endBlock
// was created.
2588+ {
2589+ auto found = gutils->reverseBlockToPrimal .find (endBlock);
2590+ assert (found != gutils->reverseBlockToPrimal .end ());
2591+ SmallVector<BasicBlock *, 4 > &vec =
2592+ gutils->reverseBlocks [found->second ];
2593+ assert (vec.size ());
2594+ vec.push_back (endBlock);
2595+ }
2596+
// fin_idx merges the two ways of reaching endBlock: zero when the loop was
// skipped (from currentBlock), or the final accumulated sum (from
// loopBlock).
2597+ auto fin_idx = Builder2.CreatePHI (types[2 ], 2 );
2598+ fin_idx->addIncoming (Constant::getNullValue (types[2 ]), currentBlock);
2599+ fin_idx->addIncoming (acc, loopBlock);
2600+
2601+ Builder2.CreateLifetimeEnd (dtmp);
2602+
// Restore the stack pointer saved before the allocas.
2603+ Builder2.CreateIntrinsic (Intrinsic::stackrestore,
2604+ ArrayRef<Type *>(stackTys),
2605+ ArrayRef<Value *>(stack));
2606+
// Add the accumulated sum to the derivative of operand 2 (x).
2607+ ((DiffeGradientUtils *)gutils)
2608+ ->addToDiffe (call.getOperand (2 ), fin_idx, Builder2, types[2 ]);
2609+
2610+ return true ;
2611+ }
2612+ }
2613+
24772614 // Functions that only modify pointers and don't allocate memory,
24782615 // needs to be run on shadow in primal
24792616 if (funcName == " _ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_"
0 commit comments