Skip to content

Commit 1ada437

Browse files
authored
Add collect offset c api function (#1465)
1 parent 08288a0 commit 1ada437

File tree

4 files changed

+89
-1
lines changed

4 files changed

+89
-1
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,24 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC,
11881188
LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) {
11891189
return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType());
11901190
}
1191+
LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r,
1192+
LLVMTypeRef T_r) {
1193+
IRBuilder<> &B = *unwrap(B_r);
1194+
auto T = cast<IntegerType>(unwrap(T_r));
1195+
auto width = T->getBitWidth();
1196+
auto gep = cast<GetElementPtrInst>(unwrap(V_r));
1197+
auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout();
1198+
1199+
MapVector<Value *, APInt> VariableOffsets;
1200+
APInt Offset(width, 0);
1201+
bool success = collectOffset(gep, DL, width, VariableOffsets, Offset);
1202+
assert(success);
1203+
Value *start = ConstantInt::get(T, Offset);
1204+
for (auto &pair : VariableOffsets)
1205+
start = B.CreateAdd(
1206+
start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second)));
1207+
return wrap(start);
1208+
}
11911209
}
11921210

11931211
static size_t num_rooting(llvm::Type *T, llvm::Function *F) {

enzyme/Enzyme/Utils.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "llvm/IR/BasicBlock.h"
3838
#include "llvm/IR/DerivedTypes.h"
3939
#include "llvm/IR/Function.h"
40+
#include "llvm/IR/GetElementPtrTypeIterator.h"
4041
#include "llvm/IR/IRBuilder.h"
4142
#include "llvm/IR/InlineAsm.h"
4243
#include "llvm/IR/Module.h"
@@ -2648,3 +2649,67 @@ CountTrackedPointers::CountTrackedPointers(Type *T) {
26482649
if (count == 0)
26492650
all = false;
26502651
}
2652+
2653+
bool collectOffset(GetElementPtrInst *gep, const DataLayout &DL,
2654+
unsigned BitWidth,
2655+
MapVector<Value *, APInt> &VariableOffsets,
2656+
APInt &ConstantOffset) {
2657+
#if LLVM_VERSION_MAJOR >= 13
2658+
return cast<GEPOperator>(gep)->collectOffset(DL, BitWidth, VariableOffsets,
2659+
ConstantOffset);
2660+
#else
2661+
assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) &&
2662+
"The offset bit width does not match DL specification.");
2663+
2664+
auto CollectConstantOffset = [&](APInt Index, uint64_t Size) {
2665+
Index = Index.sextOrTrunc(BitWidth);
2666+
APInt IndexedSize = APInt(BitWidth, Size);
2667+
ConstantOffset += Index * IndexedSize;
2668+
};
2669+
2670+
for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep);
2671+
GTI != GTE; ++GTI) {
2672+
// Scalable vectors are multiplied by a runtime constant.
2673+
bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType());
2674+
2675+
Value *V = GTI.getOperand();
2676+
StructType *STy = GTI.getStructTypeOrNull();
2677+
// Handle ConstantInt if possible.
2678+
if (auto ConstOffset = dyn_cast<ConstantInt>(V)) {
2679+
if (ConstOffset->isZero())
2680+
continue;
2681+
// If the type is scalable and the constant is not zero (vscale * n * 0 =
2682+
// 0) bailout.
2683+
// TODO: If the runtime value is accessible at any point before DWARF
2684+
// emission, then we could potentially keep a forward reference to it
2685+
// in the debug value to be filled in later.
2686+
if (ScalableType)
2687+
return false;
2688+
// Handle a struct index, which adds its field offset to the pointer.
2689+
if (STy) {
2690+
unsigned ElementIdx = ConstOffset->getZExtValue();
2691+
const StructLayout *SL = DL.getStructLayout(STy);
2692+
// Element offset is in bytes.
2693+
CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)),
2694+
1);
2695+
continue;
2696+
}
2697+
CollectConstantOffset(ConstOffset->getValue(),
2698+
DL.getTypeAllocSize(GTI.getIndexedType()));
2699+
continue;
2700+
}
2701+
2702+
if (STy || ScalableType)
2703+
return false;
2704+
APInt IndexedSize =
2705+
APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
2706+
// Insert an initial offset of 0 for V iff none exists already, then
2707+
// increment the offset by IndexedSize.
2708+
if (IndexedSize != 0) {
2709+
VariableOffsets.insert({V, APInt(BitWidth, 0)});
2710+
VariableOffsets[V] += IndexedSize;
2711+
}
2712+
}
2713+
return true;
2714+
#endif
2715+
}

enzyme/Enzyme/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#ifndef ENZYME_UTILS_H
2626
#define ENZYME_UTILS_H
2727

28+
#include "llvm/ADT/MapVector.h"
2829
#include "llvm/ADT/STLExtras.h"
2930
#include "llvm/ADT/SmallPtrSet.h"
3031

@@ -1754,4 +1755,8 @@ static inline bool isSpecialPtr(llvm::Type *Ty) {
17541755
return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial;
17551756
}
17561757

1758+
bool collectOffset(llvm::GetElementPtrInst *gep, const llvm::DataLayout &DL,
1759+
unsigned BitWidth,
1760+
llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets,
1761+
llvm::APInt &ConstantOffset);
17571762
#endif

enzyme/test/test_find_package/main.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <stdio.h>
1+
int printf(const char*, ...);
22

33
extern double __enzyme_autodiff(void*, double);
44

0 commit comments

Comments
 (0)