Skip to content

Commit 927ba3a

Browse files
ZuseZ4wsmoses
andauthored
Dbg tblgen (#1190)
* fix fortran abi, support scal, WIP * scal support, fix fortran abi, fix other things * fix build * fix format --------- Co-authored-by: William S. Moses <[email protected]>
1 parent 5487147 commit 927ba3a

File tree

9 files changed

+108
-77
lines changed

9 files changed

+108
-77
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ class AdjointGenerator
799799

800800
if (prediff)
801801
((DiffeGradientUtils *)gutils)
802-
->addToInvertedPtrDiffe(&I, vd, LoadSize, I.getOperand(0),
802+
->addToInvertedPtrDiffe(&I, &I, vd, LoadSize, I.getOperand(0),
803803
prediff, Builder2, alignment, premask);
804804

805805
unsigned start = 0;

enzyme/Enzyme/CApi.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -393,36 +393,39 @@ void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val,
393393
}
394394

395395
void EnzymeGradientUtilsAddToInvertedPointerDiffe(
396-
DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMTypeRef addingType,
397-
unsigned start, unsigned size, LLVMValueRef origptr, LLVMValueRef dif,
398-
LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef mask) {
396+
DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal,
397+
LLVMTypeRef addingType, unsigned start, unsigned size, LLVMValueRef origptr,
398+
LLVMValueRef dif, LLVMBuilderRef BuilderM, unsigned align,
399+
LLVMValueRef mask) {
399400
#if LLVM_VERSION_MAJOR >= 10
400401
MaybeAlign align2;
401402
if (align)
402403
align2 = MaybeAlign(align);
403404
#else
404405
auto align2 = align;
405406
#endif
406-
gutils->addToInvertedPtrDiffe(
407-
cast_or_null<Instruction>(unwrap(orig)), unwrap(addingType), start, size,
408-
unwrap(origptr), unwrap(dif), *unwrap(BuilderM), align2, unwrap(mask));
407+
auto inst = cast_or_null<Instruction>(unwrap(orig));
408+
gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), unwrap(addingType),
409+
start, size, unwrap(origptr), unwrap(dif),
410+
*unwrap(BuilderM), align2, unwrap(mask));
409411
}
410412

411413
void EnzymeGradientUtilsAddToInvertedPointerDiffeTT(
412-
DiffeGradientUtils *gutils, LLVMValueRef orig, CTypeTreeRef vd,
413-
unsigned LoadSize, LLVMValueRef origptr, LLVMValueRef prediff,
414-
LLVMBuilderRef BuilderM, unsigned align, LLVMValueRef premask) {
414+
DiffeGradientUtils *gutils, LLVMValueRef orig, LLVMValueRef origVal,
415+
CTypeTreeRef vd, unsigned LoadSize, LLVMValueRef origptr,
416+
LLVMValueRef prediff, LLVMBuilderRef BuilderM, unsigned align,
417+
LLVMValueRef premask) {
415418
#if LLVM_VERSION_MAJOR >= 10
416419
MaybeAlign align2;
417420
if (align)
418421
align2 = MaybeAlign(align);
419422
#else
420423
auto align2 = align;
421424
#endif
422-
gutils->addToInvertedPtrDiffe(cast_or_null<Instruction>(unwrap(orig)),
423-
*(TypeTree *)vd, LoadSize, unwrap(origptr),
424-
unwrap(prediff), *unwrap(BuilderM), align2,
425-
unwrap(premask));
425+
auto inst = cast_or_null<Instruction>(unwrap(orig));
426+
gutils->addToInvertedPtrDiffe(inst, unwrap(origVal), *(TypeTree *)vd,
427+
LoadSize, unwrap(origptr), unwrap(prediff),
428+
*unwrap(BuilderM), align2, unwrap(premask));
426429
}
427430

428431
void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val,

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -584,16 +584,16 @@ CallInst *DiffeGradientUtils::freeCache(BasicBlock *forwardPreheader,
584584

585585
#if LLVM_VERSION_MAJOR >= 10
586586
void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
587-
Type *addingType, unsigned start,
588-
unsigned size, Value *origptr,
589-
Value *dif,
587+
Value *origVal, Type *addingType,
588+
unsigned start, unsigned size,
589+
Value *origptr, Value *dif,
590590
IRBuilder<> &BuilderM,
591591
MaybeAlign align, Value *mask)
592592
#else
593593
void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
594-
Type *addingType, unsigned start,
595-
unsigned size, Value *origptr,
596-
Value *dif,
594+
Value *origVal, Type *addingType,
595+
unsigned start, unsigned size,
596+
Value *origptr, Value *dif,
597597
IRBuilder<> &BuilderM,
598598
unsigned align, Value *mask)
599599
#endif
@@ -892,8 +892,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
892892

893893
SmallVector<Metadata *, 1> scopeMD = {
894894
getDerivativeAliasScope(origptr, idx)};
895-
if (orig)
896-
if (auto MD = orig->getMetadata(LLVMContext::MD_alias_scope)) {
895+
if (auto origValI = dyn_cast_or_null<Instruction>(origVal))
896+
if (auto MD = origValI->getMetadata(LLVMContext::MD_alias_scope)) {
897897
auto MDN = cast<MDNode>(MD);
898898
for (auto &o : MDN->operands())
899899
scopeMD.push_back(o);
@@ -907,8 +907,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
907907
if (j != (ssize_t)idx)
908908
MDs.push_back(getDerivativeAliasScope(origptr, j));
909909
}
910-
if (orig)
911-
if (auto MD = orig->getMetadata(LLVMContext::MD_noalias)) {
910+
if (auto origValI = dyn_cast_or_null<Instruction>(origVal))
911+
if (auto MD = origValI->getMetadata(LLVMContext::MD_noalias)) {
912912
auto MDN = cast<MDNode>(MD);
913913
for (auto &o : MDN->operands())
914914
MDs.push_back(o);
@@ -918,17 +918,19 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
918918
LI->setMetadata(LLVMContext::MD_noalias, noscope);
919919
st->setMetadata(LLVMContext::MD_noalias, noscope);
920920

921-
if (orig && start == 0 &&
922-
size == (DL.getTypeSizeInBits(orig->getType()) + 7) / 8) {
923-
LI->copyMetadata(*orig, MD_ToCopy);
924-
LI->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
921+
if (origVal && isa<Instruction>(origVal) && start == 0 &&
922+
size == (DL.getTypeSizeInBits(origVal->getType()) + 7) / 8) {
923+
auto origValI = cast<Instruction>(origVal);
924+
LI->copyMetadata(*origValI, MD_ToCopy);
925925
unsigned int StoreData[] = {LLVMContext::MD_tbaa,
926926
LLVMContext::MD_tbaa_struct};
927927
for (auto MD : StoreData)
928-
st->setMetadata(MD, orig->getMetadata(MD));
929-
st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
928+
st->setMetadata(MD, origValI->getMetadata(MD));
930929
}
931930

931+
LI->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
932+
st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc()));
933+
932934
if (align) {
933935
#if LLVM_VERSION_MAJOR >= 10
934936
auto alignv = align ? align.getValue().value() : 0;
@@ -987,14 +989,14 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
987989

988990
#if LLVM_VERSION_MAJOR >= 10
989991
void DiffeGradientUtils::addToInvertedPtrDiffe(
990-
llvm::Instruction *orig, TypeTree vd, unsigned LoadSize,
991-
llvm::Value *origptr, llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
992-
MaybeAlign alignment, llvm::Value *premask)
992+
llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd,
993+
unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
994+
llvm::IRBuilder<> &Builder2, MaybeAlign alignment, llvm::Value *premask)
993995
#else
994996
void DiffeGradientUtils::addToInvertedPtrDiffe(
995-
llvm::Instruction *orig, TypeTree vd, unsigned LoadSize,
996-
llvm::Value *origptr, llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
997-
unsigned alignment, llvm::Value *premask)
997+
llvm::Instruction *orig, llvm::Value *origVal, TypeTree vd,
998+
unsigned LoadSize, llvm::Value *origptr, llvm::Value *prediff,
999+
llvm::IRBuilder<> &Builder2, unsigned alignment, llvm::Value *premask)
9981000
#endif
9991001
{
10001002

@@ -1026,12 +1028,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
10261028

10271029
if (Type *isfloat = dt.isFloat()) {
10281030

1029-
if (orig) {
1031+
if (origVal) {
10301032
if (start == 0 && nextStart == LoadSize) {
1031-
setDiffe(orig, Constant::getNullValue(getShadowType(orig->getType())),
1033+
setDiffe(origVal,
1034+
Constant::getNullValue(getShadowType(origVal->getType())),
10321035
Builder2);
10331036
} else {
1034-
Value *tostore = getDifferential(orig);
1037+
Value *tostore = getDifferential(origVal);
10351038

10361039
auto i8 = Type::getInt8Ty(tostore->getContext());
10371040
if (start != 0) {
@@ -1074,8 +1077,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
10741077
// Masked partial type is unhanled.
10751078
if (premask)
10761079
assert(start == 0 && nextStart == LoadSize);
1077-
addToInvertedPtrDiffe(orig, isfloat, start, nextStart - start, origptr,
1078-
prediff, Builder2, alignment, premask);
1080+
addToInvertedPtrDiffe(orig, origVal, isfloat, start, nextStart - start,
1081+
origptr, prediff, Builder2, alignment, premask);
10791082
}
10801083
}
10811084

enzyme/Enzyme/DiffeGradientUtils.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,31 +99,32 @@ class DiffeGradientUtils final : public GradientUtils {
9999

100100
/// align is the alignment that should be specified for load/store to pointer
101101
#if LLVM_VERSION_MAJOR >= 10
102-
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Type *addingType,
103-
unsigned start, unsigned size,
104-
llvm::Value *origptr, llvm::Value *dif,
105-
llvm::IRBuilder<> &BuilderM,
106-
llvm::MaybeAlign align,
102+
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
103+
llvm::Type *addingType, unsigned start,
104+
unsigned size, llvm::Value *origptr,
105+
llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
106+
llvm::MaybeAlign align = llvm::MaybeAlign(),
107107
llvm::Value *mask = nullptr);
108108
#else
109-
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Type *addingType,
110-
unsigned start, unsigned size,
111-
llvm::Value *origptr, llvm::Value *dif,
112-
llvm::IRBuilder<> &BuilderM, unsigned align,
113-
llvm::Value *mask = nullptr);
109+
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
110+
llvm::Type *addingType, unsigned start,
111+
unsigned size, llvm::Value *origptr,
112+
llvm::Value *dif, llvm::IRBuilder<> &BuilderM,
113+
unsigned align = 0, llvm::Value *mask = nullptr);
114114
#endif
115115

116116
#if LLVM_VERSION_MAJOR >= 10
117-
void addToInvertedPtrDiffe(llvm::Instruction *orig, TypeTree vd,
118-
unsigned size, llvm::Value *origptr,
117+
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
118+
TypeTree vd, unsigned size, llvm::Value *origptr,
119119
llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
120-
llvm::MaybeAlign align,
120+
llvm::MaybeAlign align = llvm::MaybeAlign(),
121121
llvm::Value *premask = nullptr);
122122
#else
123-
void addToInvertedPtrDiffe(llvm::Instruction *orig, TypeTree vd,
124-
unsigned size, llvm::Value *origptr,
123+
void addToInvertedPtrDiffe(llvm::Instruction *orig, llvm::Value *origVal,
124+
TypeTree vd, unsigned size, llvm::Value *origptr,
125125
llvm::Value *prediff, llvm::IRBuilder<> &Builder2,
126-
unsigned align, llvm::Value *premask = nullptr);
126+
unsigned align = 0,
127+
llvm::Value *premask = nullptr);
127128
#endif
128129
};
129130

enzyme/Enzyme/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,7 +1844,7 @@ Function *GetFunctionFromValue(Value *fn) {
18441844
}
18451845

18461846
size_t getFirstLenOrIncPosition(BlasInfo blas) {
1847-
if (blas.function == "dot") {
1847+
if (blas.function == "dot" || blas.function == "scal") {
18481848
return 0;
18491849
} else {
18501850
llvm::errs() << "unsuported BLAS fnc\n";
@@ -1854,7 +1854,7 @@ size_t getFirstLenOrIncPosition(BlasInfo blas) {
18541854

18551855
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in) {
18561856
llvm::Twine floatType[] = {"s", "d"}; // c, z
1857-
llvm::Twine extractable[] = {"dot"};
1857+
llvm::Twine extractable[] = {"dot", "scal"};
18581858
llvm::Twine prefixes[] = {"" /*Fortran*/, "cblas_", "cublas_"};
18591859
llvm::Twine suffixes[] = {"", "_", "64_", "_64_"};
18601860
for (auto t : floatType) {

enzyme/Enzyme/targets/BlasDerivatives.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class Constant<string _value> {
7272
def scal : CallBlasPattern<(Op $n, $alpha, $x, $incx),
7373
["x"],[len, fp, vinc],
7474
[
75-
(b<"asum"> $n, input<"x">, $incx),
75+
// dot must proceed scal, because scal modifies adj<"x">
76+
(b<"dot"> $n, $x, $incx, adj<"x">, $incx),
7677
(b<"scal"> $n, $alpha, adj<"x">, $incx)
7778
]
7879
>;

enzyme/tools/enzyme-tblgen/blasTAUpdater.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,7 @@ void emit_BLASTypes(raw_ostream &os) {
22

33
os << "size_t firstIntPos = getFirstLenOrIncPosition(blas);\n";
44

5-
os << "#if LLVM_VERSION_MAJOR >= 10\n"
6-
<< " const bool byRef = !call.getArgOperand(firstIntPos\n"
7-
<< ")->getType()->isIntegerTy() && blas.prefix == \"\";\n"
8-
<< "#else\n"
9-
<< " const bool byRef = !call.getOperand(firstIntPos\n"
10-
<< ")->getType()->isIntegerTy() && blas.prefix == \"\";\n"
11-
<< "#endif\n";
5+
os << " const bool byRef = blas.prefix == \"\";\n";
126

137
os << "TypeTree ttFloat;\n"
148
<< "llvm::Type *floatType; \n"
@@ -17,7 +11,12 @@ void emit_BLASTypes(raw_ostream &os) {
1711
<< "} else {\n"
1812
<< " floatType = Type::getDoubleTy(call.getContext());\n"
1913
<< "}\n"
20-
<< "ttFloat.insert({-1},floatType);\n";
14+
<< "if (byRef) {\n"
15+
<< " ttFloat.insert({-1},BaseType::Pointer);\n"
16+
<< " ttFloat.insert({-1,0},floatType);\n"
17+
<< "} else { \n"
18+
<< " ttFloat.insert({-1},floatType);\n"
19+
<< "}\n";
2120

2221
os << "TypeTree ttInt;\n"
2322
<< "if (byRef) {\n"

enzyme/tools/enzyme-tblgen/datastructures.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ class Rule {
8686
size_t getHandledArgIdx() { return activeArg; }
8787
StringMap<size_t> getArgNameMap() { return argNameToPos; }
8888
DenseMap<size_t, argType> getArgTypeMap() { return argTypes; }
89+
//std::string to_string(Rule const&r) {
90+
// std::string res = "function: " + r.blasName + "\n";
91+
// res += "handling
92+
93+
// for (auto rule : r.rules) {
94+
// }
95+
//}
8996
};
9097

9198
void fillActiveArgSet(const Record *pattern,
@@ -197,6 +204,11 @@ class TGPattern {
197204
public:
198205
TGPattern(Record &r) {
199206
blasName = r.getNameInitAsString();
207+
// if (blasName != "scal") {
208+
// llvm::errs() << blasName << " skipped!\n";
209+
// return;
210+
// }
211+
// llvm::errs() << blasName << "\n";
200212

201213
args = llvm::SmallVector<std::string, 6>();
202214
argNameToPos = StringMap<size_t>{};
@@ -216,6 +228,7 @@ class TGPattern {
216228
rules = llvm::SmallVector<Rule, 3>{};
217229
ListInit *derivOps = r.getValueAsListInit("ArgDerivatives");
218230
for (auto derivOp : llvm::enumerate(*derivOps)) {
231+
// llvm::errs() << derivOp.index() << ": \n";
219232
DagInit *derivRule = cast<DagInit>(derivOp.value());
220233
size_t actIdx = posActArgs[derivOp.index()];
221234
rules.push_back(
@@ -225,6 +238,14 @@ class TGPattern {
225238

226239
argUsers = DenseMap<size_t, DenseSet<size_t>>();
227240
fillArgUserMap(rules, args, posActArgs, argUsers);
241+
// for (auto key : argUsers) {
242+
// DenseSet<size_t> users = key.second; // argUsers.lookup(key);
243+
// llvm::errs() << "\nKey " << key.first << ": ";
244+
// for (auto user: users) {
245+
// llvm::errs() << user << " ";
246+
// }
247+
// llvm::errs() << "\n";
248+
// }
228249
}
229250
DenseMap<size_t, DenseSet<size_t>> getArgUsers() { return argUsers; }
230251
std::string getName() { return blasName; }

0 commit comments

Comments
 (0)