Skip to content

Commit 6d8404a

Browse files
Fix tune gemm breaks (#1175)
1 parent f794ae4 commit 6d8404a

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

aiter/tuned_gemm.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,28 @@
4040
soltype = 0
4141

4242

43+
@torch_compile_guard()
44+
def create_ds_custom() -> None:
45+
global solids, bestsols, solMap
46+
df: pd.DataFrame = bestsols
47+
solids = {}
48+
for i in range(len(df)):
49+
ds = df.iloc[i]
50+
key = (
51+
ds["M"],
52+
ds["N"],
53+
ds["K"],
54+
ds["bias"],
55+
ds["dtype"],
56+
ds["outdtype"],
57+
ds["scaleAB"],
58+
)
59+
60+
if ds["libtype"] in ["hipblaslt", "rocblas", "asm"]:
61+
soltype_ = solMap.index(ds["libtype"])
62+
solids[key] = (soltype_, int(ds["solidx"]))
63+
64+
4365
@torch_compile_guard()
4466
def load_best_sols_custom(tune_path: str) -> bool:
4567
global bestsols
@@ -160,25 +182,7 @@ def load_best_sols(self):
160182
self.bestsols = bestsols
161183

162184
def create_ds(self):
163-
global solids
164-
df: pd.DataFrame = self.bestsols
165-
solds = {}
166-
for i in range(len(df)):
167-
ds = df.iloc[i]
168-
key = (
169-
ds["M"],
170-
ds["N"],
171-
ds["K"],
172-
ds["bias"],
173-
ds["dtype"],
174-
ds["outdtype"],
175-
ds["scaleAB"],
176-
)
177-
178-
if ds["libtype"] in ["hipblaslt", "rocblas", "asm"]:
179-
soltype = self.solMap.index(ds["libtype"])
180-
solds[key] = (soltype, int(ds["solidx"]))
181-
solids = solds
185+
create_ds_custom()
182186
self.solfuncs = [
183187
self.apply_torch_mm,
184188
self.apply_hipb_mm,

0 commit comments

Comments
 (0)