Skip to content

Commit 82f8cdf

Browse files
authored
BCLoad: Still ignore if cblas lowering required (#1709)
* BCLoad: Still ignore if cblas lowering required * inform which were replaced
1 parent e014a85 commit 82f8cdf

File tree

1 file changed

+39
-7
lines changed

1 file changed

+39
-7
lines changed

enzyme/BCLoad/BCLoader.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,30 @@ static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) {
2424
#endif // LLVM_VERSION_MAJOR
2525
}
2626

27-
bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
27+
bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions,
28+
std::vector<std::string> &replaced) {
2829
std::vector<StringRef> todo;
2930
bool seen32 = false;
3031
bool seen64 = false;
3132
for (auto &F : M) {
3233
if (!F.empty())
3334
continue;
35+
if (ignoreFunctions.count(F.getName().str()))
36+
continue;
3437
int index = 0;
3538
for (auto postfix : {"", "_", "_64_"}) {
3639
std::string str;
3740
if (strlen(postfix) == 0) {
3841
str = F.getName().str();
39-
if (ignoreFunctions.count(str)) continue;
4042
} else if (endsWith(F.getName(), postfix)) {
4143
auto blasName =
4244
F.getName().substr(0, F.getName().size() - strlen(postfix)).str();
43-
if (ignoreFunctions.count(blasName)) continue;
4445
str = "cblas_" + blasName;
4546
}
4647

4748
auto found = EnzymeBlasBC.find(str);
4849
if (found != EnzymeBlasBC.end()) {
50+
replaced.push_back(F.getName().str());
4951
todo.push_back(found->second);
5052
if (index == 1)
5153
seen32 = true;
@@ -81,13 +83,23 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
8183
});
8284
#endif
8385

84-
if (!BC)
86+
if (!BC) {
8587
Err.print("bcloader", llvm::errs());
88+
continue;
89+
}
8690
assert(BC);
8791
SmallVector<std::string, 1> toReplace;
8892
for (auto &F : *BC) {
8993
if (F.empty())
9094
continue;
95+
if (ignoreFunctions.count(F.getName().str())) {
96+
#if LLVM_VERSION_MAJOR >= 16
97+
F.erase(F.begin(), F.end());
98+
#else
99+
F.getBasicBlockList().erase(F.begin(), F.end());
100+
#endif
101+
continue;
102+
}
91103
toReplace.push_back(F.getName().str());
92104
}
93105
BC->setTargetTriple("");
@@ -106,12 +118,29 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
106118

107119
extern "C" {
108120
uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore,
109-
size_t numFncNames) {
121+
size_t numFncNames, const char ***foundP,
122+
size_t *foundLen) {
110123
std::set<std::string> ignoreFunctions = {};
111124
for (size_t i = 0; i < numFncNames; i++) {
112125
ignoreFunctions.insert(std::string(FncsNamesToIgnore[i]));
113126
}
114-
return provideDefinitions(*unwrap(M), ignoreFunctions);
127+
std::vector<std::string> replaced;
128+
auto res = provideDefinitions(*unwrap(M), ignoreFunctions, replaced);
129+
130+
const char **found = nullptr;
131+
if (replaced.size()) {
132+
found = (const char **)malloc(replaced.size() * sizeof(const char **));
133+
for (size_t i = 0; i < replaced.size(); i++) {
134+
char *data = (char *)malloc(replaced[i].size() + 1);
135+
memcpy(data, replaced[i].data(), replaced[i].size());
136+
data[replaced[i].size()] = 0;
137+
found[i] = data;
138+
}
139+
}
140+
*foundP = found;
141+
*foundLen = replaced.size();
142+
143+
return res;
115144
}
116145
}
117146

@@ -121,7 +150,10 @@ class BCLoader final : public ModulePass {
121150
static char ID;
122151
BCLoader() : ModulePass(ID) {}
123152

124-
bool runOnModule(Module &M) override { return provideDefinitions(M, {}); }
153+
bool runOnModule(Module &M) override {
154+
std::vector<std::string> replaced;
155+
return provideDefinitions(M, {}, replaced);
156+
}
125157
};
126158
} // namespace
127159

0 commit comments

Comments
 (0)