Skip to content

Commit 7c7bb60

Browse files
ZuseZ4 and wsmoses authored
Tg blas3 (#1117)
* removing manual blas support
* minimal blas generation
* add blas sanity checks
* emit handle_blas
* emit more parts
* emit caching
* finish rebasing of tblgen-blas
* finish refactor on latest main
* fix fwd-mode
* update tests
* fix Attribute settings
* fix tg for older llvm
* add more const and clarify comments
* add inactive fwd test
* disable dot fallback
* make BLAS fallback optional
* cleanup
* fixup
* cleanup

---------

Co-authored-by: William S. Moses <[email protected]>
1 parent eeea6e6 commit 7c7bb60

File tree

14 files changed

+2089
-951
lines changed

14 files changed

+2089
-951
lines changed

enzyme/BCLoad/BCLoader.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using namespace llvm;
1616
#include "blas_headers.h"
1717
#undef DATA
1818

19-
bool provideDefinitions(Module &M) {
19+
bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
2020
std::vector<StringRef> todo;
2121
bool seen32 = false;
2222
bool seen64 = false;
@@ -27,11 +27,14 @@ bool provideDefinitions(Module &M) {
2727
int index = 0;
2828
for (auto postfix : {"", "_", "_64_"}) {
2929
std::string str;
30-
if (strlen(postfix) == 0)
30+
if (strlen(postfix) == 0) {
3131
str = F.getName().str();
32-
else if (F.getName().endswith(postfix)) {
33-
str = "cblas_" +
34-
F.getName().substr(0, F.getName().size() - strlen(postfix)).str();
32+
if (ignoreFunctions.count(str)) continue;
33+
} else if (F.getName().endswith(postfix)) {
34+
auto blasName =
35+
F.getName().substr(0, F.getName().size() - strlen(postfix)).str();
36+
if (ignoreFunctions.count(blasName)) continue;
37+
str = "cblas_" + blasName;
3538
}
3639

3740
auto found = EnzymeBlasBC.find(str);
@@ -96,8 +99,13 @@ bool provideDefinitions(Module &M) {
9699
}
97100

98101
extern "C" {
99-
uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M) {
100-
return provideDefinitions(*unwrap(M));
102+
uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore,
103+
size_t numFncNames) {
104+
std::set<std::string> ignoreFunctions = {};
105+
for (size_t i = 0; i < numFncNames; i++) {
106+
ignoreFunctions.insert(std::string(FncsNamesToIgnore[i]));
107+
}
108+
return provideDefinitions(*unwrap(M), ignoreFunctions);
101109
}
102110
}
103111

@@ -107,7 +115,7 @@ class BCLoader final : public ModulePass {
107115
static char ID;
108116
BCLoader() : ModulePass(ID) {}
109117

110-
bool runOnModule(Module &M) override { return provideDefinitions(M); }
118+
bool runOnModule(Module &M) override { return provideDefinitions(M, {}); }
111119
};
112120
} // namespace
113121

0 commit comments

Comments
 (0)