@@ -539,20 +539,24 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
539539Function *getOrInsertDifferentialFloatMemcpy (Module &M, Type *elementType,
540540 unsigned dstalign,
541541 unsigned srcalign,
542- unsigned dstaddr,
543- unsigned srcaddr ) {
542+ unsigned dstaddr, unsigned srcaddr,
543+ unsigned bitwidth ) {
544544 assert (elementType->isFloatingPointTy ());
545- std::string name = " __enzyme_memcpyadd_" + tofltstr (elementType) + " da" +
546- std::to_string (dstalign) + " sa" + std::to_string (srcalign);
545+ std::string name = " __enzyme_memcpy" ;
546+ if (bitwidth != 64 )
547+ name += std::to_string (bitwidth);
548+ name += " add_" + tofltstr (elementType) + " da" + std::to_string (dstalign) +
549+ " sa" + std::to_string (srcalign);
547550 if (dstaddr)
548551 name += " dadd" + std::to_string (dstaddr);
549552 if (srcaddr)
550553 name += " sadd" + std::to_string (srcaddr);
551- FunctionType *FT = FunctionType::get (Type::getVoidTy (M.getContext ()),
552- {PointerType::get (elementType, dstaddr),
553- PointerType::get (elementType, srcaddr),
554- Type::getInt64Ty (M.getContext ())},
555- false );
554+ FunctionType *FT =
555+ FunctionType::get (Type::getVoidTy (M.getContext ()),
556+ {PointerType::get (elementType, dstaddr),
557+ PointerType::get (elementType, srcaddr),
558+ IntegerType::get (M.getContext (), bitwidth)},
559+ false );
556560
557561#if LLVM_VERSION_MAJOR >= 9
558562 Function *F = cast<Function>(M.getOrInsertFunction (name, FT).getCallee ());
@@ -742,15 +746,14 @@ Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, Type *IT,
742746}
743747
744748// TODO implement differential memmove
745- Function *getOrInsertDifferentialFloatMemmove (Module &M, Type *T,
746- unsigned dstalign,
747- unsigned srcalign,
748- unsigned dstaddr,
749- unsigned srcaddr) {
749+ Function *
750+ getOrInsertDifferentialFloatMemmove (Module &M, Type *T, unsigned dstalign,
751+ unsigned srcalign, unsigned dstaddr,
752+ unsigned srcaddr, unsigned bitwidth) {
750753 llvm::errs () << " warning: didn't implement memmove, using memcpy as fallback "
751754 " which can result in errors\n " ;
752755 return getOrInsertDifferentialFloatMemcpy (M, T, dstalign, srcalign, dstaddr,
753- srcaddr);
756+ srcaddr, bitwidth );
754757}
755758
756759Function *getOrInsertCheckedFree (Module &M, CallInst *call, Type *Ty,
0 commit comments