@@ -24,28 +24,30 @@ static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) {
2424#endif // LLVM_VERSION_MAJOR
2525}
2626
27- bool provideDefinitions (Module &M, std::set<std::string> ignoreFunctions = {}) {
27+ bool provideDefinitions (Module &M, std::set<std::string> ignoreFunctions,
28+ std::vector<std::string> &replaced) {
2829 std::vector<StringRef> todo;
2930 bool seen32 = false ;
3031 bool seen64 = false ;
3132 for (auto &F : M) {
3233 if (!F.empty ())
3334 continue ;
35+ if (ignoreFunctions.count (F.getName ().str ()))
36+ continue ;
3437 int index = 0 ;
3538 for (auto postfix : {" " , " _" , " _64_" }) {
3639 std::string str;
3740 if (strlen (postfix) == 0 ) {
3841 str = F.getName ().str ();
39- if (ignoreFunctions.count (str)) continue ;
4042 } else if (endsWith (F.getName (), postfix)) {
4143 auto blasName =
4244 F.getName ().substr (0 , F.getName ().size () - strlen (postfix)).str ();
43- if (ignoreFunctions.count (blasName)) continue ;
4445 str = " cblas_" + blasName;
4546 }
4647
4748 auto found = EnzymeBlasBC.find (str);
4849 if (found != EnzymeBlasBC.end ()) {
50+ replaced.push_back (F.getName ().str ());
4951 todo.push_back (found->second );
5052 if (index == 1 )
5153 seen32 = true ;
@@ -81,13 +83,23 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
8183 });
8284#endif
8385
84- if (!BC)
86+ if (!BC) {
8587 Err.print (" bcloader" , llvm::errs ());
88+ continue ;
89+ }
8690 assert (BC);
8791 SmallVector<std::string, 1 > toReplace;
8892 for (auto &F : *BC) {
8993 if (F.empty ())
9094 continue ;
95+ if (ignoreFunctions.count (F.getName ().str ())) {
96+ #if LLVM_VERSION_MAJOR >= 16
97+ F.erase (F.begin (), F.end ());
98+ #else
99+ F.getBasicBlockList ().erase (F.begin (), F.end ());
100+ #endif
101+ continue ;
102+ }
91103 toReplace.push_back (F.getName ().str ());
92104 }
93105 BC->setTargetTriple (" " );
@@ -106,12 +118,29 @@ bool provideDefinitions(Module &M, std::set<std::string> ignoreFunctions = {}) {
106118
107119extern " C" {
108120uint8_t EnzymeBitcodeReplacement (LLVMModuleRef M, char **FncsNamesToIgnore,
109- size_t numFncNames) {
121+ size_t numFncNames, const char ***foundP,
122+ size_t *foundLen) {
110123 std::set<std::string> ignoreFunctions = {};
111124 for (size_t i = 0 ; i < numFncNames; i++) {
112125 ignoreFunctions.insert (std::string (FncsNamesToIgnore[i]));
113126 }
114- return provideDefinitions (*unwrap (M), ignoreFunctions);
127+ std::vector<std::string> replaced;
128+ auto res = provideDefinitions (*unwrap (M), ignoreFunctions, replaced);
129+
130+ const char **found = nullptr ;
131+ if (replaced.size ()) {
132+ found = (const char **)malloc (replaced.size () * sizeof (const char **));
133+ for (size_t i = 0 ; i < replaced.size (); i++) {
134+ char *data = (char *)malloc (replaced[i].size () + 1 );
135+ memcpy (data, replaced[i].data (), replaced[i].size ());
136+ data[replaced[i].size ()] = 0 ;
137+ found[i] = data;
138+ }
139+ }
140+ *foundP = found;
141+ *foundLen = replaced.size ();
142+
143+ return res;
115144}
116145}
117146
@@ -121,7 +150,10 @@ class BCLoader final : public ModulePass {
121150 static char ID;
122151 BCLoader () : ModulePass(ID) {}
123152
124- bool runOnModule (Module &M) override { return provideDefinitions (M, {}); }
153+ bool runOnModule (Module &M) override {
154+ std::vector<std::string> replaced;
155+ return provideDefinitions (M, {}, replaced);
156+ }
125157};
126158} // namespace
127159
0 commit comments