Skip to content

Commit 159a4c2

Browse files
authored
Add memcpy for 32 (#1133)
1 parent f6c8c5d commit 159a4c2

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5887,7 +5887,8 @@ class AdjointGenerator
58875887

58885888
auto dmemcpy = getOrInsertDifferentialFloatMemcpy(
58895889
*Builder2.GetInsertBlock()->getParent()->getParent(), secretty,
5890-
/*dstalign*/ 1, /*srcalign*/ 1, dstaddr, srcaddr);
5890+
/*dstalign*/ 1, /*srcalign*/ 1, dstaddr, srcaddr,
5891+
cast<IntegerType>(length->getType())->getBitWidth());
58915892

58925893
Builder2.CreateCall(dmemcpy, args, ReverseDefs);
58935894
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8207,7 +8207,8 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
82078207
? getOrInsertDifferentialFloatMemcpy
82088208
: getOrInsertDifferentialFloatMemmove)(
82098209
*MTI->getParent()->getParent()->getParent(), secretty, dstalign,
8210-
srcalign, dstaddr, srcaddr);
8210+
srcalign, dstaddr, srcaddr,
8211+
cast<IntegerType>(length->getType())->getBitWidth());
82118212
Builder2.CreateCall(dmemcpy, args);
82128213
}
82138214
}

enzyme/Enzyme/Utils.cpp

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -539,20 +539,24 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
539539
Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
540540
unsigned dstalign,
541541
unsigned srcalign,
542-
unsigned dstaddr,
543-
unsigned srcaddr) {
542+
unsigned dstaddr, unsigned srcaddr,
543+
unsigned bitwidth) {
544544
assert(elementType->isFloatingPointTy());
545-
std::string name = "__enzyme_memcpyadd_" + tofltstr(elementType) + "da" +
546-
std::to_string(dstalign) + "sa" + std::to_string(srcalign);
545+
std::string name = "__enzyme_memcpy";
546+
if (bitwidth != 64)
547+
name += std::to_string(bitwidth);
548+
name += "add_" + tofltstr(elementType) + "da" + std::to_string(dstalign) +
549+
"sa" + std::to_string(srcalign);
547550
if (dstaddr)
548551
name += "dadd" + std::to_string(dstaddr);
549552
if (srcaddr)
550553
name += "sadd" + std::to_string(srcaddr);
551-
FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
552-
{PointerType::get(elementType, dstaddr),
553-
PointerType::get(elementType, srcaddr),
554-
Type::getInt64Ty(M.getContext())},
555-
false);
554+
FunctionType *FT =
555+
FunctionType::get(Type::getVoidTy(M.getContext()),
556+
{PointerType::get(elementType, dstaddr),
557+
PointerType::get(elementType, srcaddr),
558+
IntegerType::get(M.getContext(), bitwidth)},
559+
false);
556560

557561
#if LLVM_VERSION_MAJOR >= 9
558562
Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
@@ -742,15 +746,14 @@ Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, Type *IT,
742746
}
743747

744748
// TODO implement differential memmove
745-
Function *getOrInsertDifferentialFloatMemmove(Module &M, Type *T,
746-
unsigned dstalign,
747-
unsigned srcalign,
748-
unsigned dstaddr,
749-
unsigned srcaddr) {
749+
Function *
750+
getOrInsertDifferentialFloatMemmove(Module &M, Type *T, unsigned dstalign,
751+
unsigned srcalign, unsigned dstaddr,
752+
unsigned srcaddr, unsigned bitwidth) {
750753
llvm::errs() << "warning: didn't implement memmove, using memcpy as fallback "
751754
"which can result in errors\n";
752755
return getOrInsertDifferentialFloatMemcpy(M, T, dstalign, srcalign, dstaddr,
753-
srcaddr);
756+
srcaddr, bitwidth);
754757
}
755758

756759
Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty,

enzyme/Enzyme/Utils.h

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -603,10 +603,9 @@ static inline bool isCertainPrint(const llvm::StringRef name) {
603603

604604
/// Create function for type that performs the derivative memcpy on floating
605605
/// point memory
606-
llvm::Function *
607-
getOrInsertDifferentialFloatMemcpy(llvm::Module &M, llvm::Type *T,
608-
unsigned dstalign, unsigned srcalign,
609-
unsigned dstaddr, unsigned srcaddr);
606+
llvm::Function *getOrInsertDifferentialFloatMemcpy(
607+
llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign,
608+
unsigned dstaddr, unsigned srcaddr, unsigned bitwidth);
610609

611610
/// Create function for type that performs memcpy with a stride
612611
llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
@@ -615,10 +614,9 @@ llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
615614

616615
/// Create function for type that performs the derivative memmove on floating
617616
/// point memory
618-
llvm::Function *
619-
getOrInsertDifferentialFloatMemmove(llvm::Module &M, llvm::Type *T,
620-
unsigned dstalign, unsigned srcalign,
621-
unsigned dstaddr, unsigned srcaddr);
617+
llvm::Function *getOrInsertDifferentialFloatMemmove(
618+
llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign,
619+
unsigned dstaddr, unsigned srcaddr, unsigned bitwidth);
622620

623621
llvm::Function *getOrInsertCheckedFree(llvm::Module &M, llvm::CallInst *call,
624622
llvm::Type *Type, unsigned width);

0 commit comments

Comments
 (0)