Skip to content

Commit a4bff8d

Browse files
authored
Updating typehints (Optional, Dict, List, Tuple) to Python 3.10+ standard (#66)
* Replace Optional with | None * Fix optional typos * Update typing for dict/list/tuple * Whitespace * Fix rebase misses * apply to platform config
1 parent 9cece75 commit a4bff8d

28 files changed

+204
-212
lines changed

Fuser/auto_agent.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
import sys
5151
from dataclasses import dataclass
5252
from pathlib import Path
53-
from typing import Any, Dict, Optional, Tuple
53+
from typing import Any
5454

5555
from dotenv import load_dotenv
5656
from Fuser.pipeline import run_pipeline
@@ -109,7 +109,7 @@ def _file_sha256_text(txt: str) -> str:
109109
return hashlib.sha256(txt.encode("utf-8")).hexdigest()
110110

111111

112-
def _load_router_cache() -> Dict[str, Any]:
112+
def _load_router_cache() -> dict[str, Any]:
113113
try:
114114
if _ROUTER_CACHE_PATH.is_file():
115115
return json.loads(_ROUTER_CACHE_PATH.read_text(encoding="utf-8"))
@@ -118,7 +118,7 @@ def _load_router_cache() -> Dict[str, Any]:
118118
return {}
119119

120120

121-
def _save_router_cache(cache: Dict[str, Any]) -> None:
121+
def _save_router_cache(cache: dict[str, Any]) -> None:
122122
try:
123123
_ensure_dir(_ROUTER_CACHE_PATH)
124124
_ROUTER_CACHE_PATH.write_text(json.dumps(cache, indent=2), encoding="utf-8")
@@ -148,7 +148,7 @@ class Complexity:
148148
pool_ops: int
149149
act_ops: int
150150
chain_len_estimate: int
151-
raw_op_names: Dict[str, int]
151+
raw_op_names: dict[str, int]
152152

153153
def route_to_fuser(self) -> bool:
154154
# Primary triggers
@@ -217,7 +217,7 @@ def analyze_problem_code(code: str) -> Complexity:
217217

218218
# AST path: inspect Model.forward for ops and control flow
219219
has_control_flow = False
220-
raw_op_counts: Dict[str, int] = {}
220+
raw_op_counts: dict[str, int] = {}
221221
has_attention_like = False
222222
has_conv_transpose = False
223223
has_group_norm = False
@@ -302,19 +302,19 @@ def visit_Assign(self, node: ast.Assign) -> Any:
302302
class RouteResult:
303303
route: str # "kernelagent" or "fuser"
304304
success: bool
305-
details: Dict[str, Any]
306-
kernel_code: Optional[str] = None
305+
details: dict[str, Any]
306+
kernel_code: str | None = None
307307

308308

309309
class AutoKernelRouter:
310310
def __init__(
311311
self,
312-
ka_model: Optional[str] = None,
312+
ka_model: str | None = None,
313313
ka_num_workers: int = 4,
314314
ka_max_rounds: int = 10,
315315
ka_high_reasoning: bool = True,
316316
# Router LLM
317-
router_model: Optional[str] = "gpt-5",
317+
router_model: str | None = "gpt-5",
318318
router_high_reasoning: bool = True,
319319
router_temperature: float = 0.2,
320320
router_max_tokens: int = 700,
@@ -329,7 +329,7 @@ def __init__(
329329
verify: bool = True,
330330
dispatch_jobs: int = 2,
331331
allow_fallback: bool = True,
332-
target_platform: Optional[str] = None,
332+
target_platform: str | None = None,
333333
) -> None:
334334
self.ka_model = ka_model
335335
self.ka_num_workers = ka_num_workers
@@ -439,7 +439,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult:
439439

440440
comp = res.get("composition", {}) or {}
441441
ok = bool(comp.get("verify_passed", not self.verify))
442-
kernel_code: Optional[str] = None
442+
kernel_code: str | None = None
443443
try:
444444
composed_path = comp.get("composed_path")
445445
if composed_path and Path(composed_path).is_file():
@@ -465,9 +465,9 @@ def solve(self, problem_path: Path) -> RouteResult:
465465
cache = _load_router_cache()
466466
cached = cache.get(code_hash)
467467

468-
strategy: Optional[str] = None
469-
route_conf: Optional[float] = None
470-
route_cfg: Dict[str, Any] = {}
468+
strategy: str | None = None
469+
route_conf: float | None = None
470+
route_cfg: dict[str, Any] = {}
471471

472472
if isinstance(cached, dict):
473473
strategy = (
@@ -553,7 +553,7 @@ def solve(self, problem_path: Path) -> RouteResult:
553553
# -------- LLM decision helper --------
554554
def _llm_decide_route(
555555
self, problem_path: Path, code: str, cx: Complexity
556-
) -> Tuple[Optional[str], Optional[float], Dict[str, Any]]:
556+
) -> tuple[str | None, float | None, dict[str, Any]]:
557557
"""Ask an LLM to choose a routing STRATEGY and optional budgets.
558558
559559
The LLM must return JSON with keys:
@@ -629,7 +629,7 @@ def _llm_decide_route(
629629
f"Features:\n```json\n{json.dumps(feats, indent=2)}\n```\n\n"
630630
"Problem code:\n```python\n" + code + "\n```\n"
631631
)
632-
kwargs: Dict[str, Any] = {
632+
kwargs: dict[str, Any] = {
633633
"max_tokens": self.router_max_tokens,
634634
"temperature": self.router_temperature,
635635
}
@@ -644,7 +644,7 @@ def _llm_decide_route(
644644
# Best-effort JSON parse
645645
route = None
646646
conf = None
647-
raw_info: Dict[str, Any] = {"raw": txt}
647+
raw_info: dict[str, Any] = {"raw": txt}
648648
try:
649649
# If model returned extra text, try to locate JSON object
650650
first = txt.find("{")
@@ -676,7 +676,7 @@ def _llm_decide_route(
676676
# ------------------------
677677

678678

679-
def main(argv: Optional[list[str]] = None) -> int:
679+
def main(argv: list[str] | None = None) -> int:
680680
p = argparse.ArgumentParser(
681681
description="Auto-router for KernelBench problems (KernelAgent vs Fuser)"
682682
)
@@ -763,7 +763,7 @@ def main(argv: Optional[list[str]] = None) -> int:
763763
)
764764
return 1
765765

766-
out: Dict[str, Any] = {
766+
out: dict[str, Any] = {
767767
"route": res.route,
768768
"success": res.success,
769769
"details": res.details,

Fuser/compose_end_to_end.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
import textwrap
4646
from dataclasses import dataclass
4747
from pathlib import Path
48-
from typing import Any, Dict, List, Optional, Tuple
48+
from typing import Any
4949

5050
from dotenv import load_dotenv
5151

@@ -79,11 +79,11 @@ def _read_text(path: Path) -> str:
7979
return path.read_text(encoding="utf-8")
8080

8181

82-
def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]:
82+
def _load_kernels_from_summary(summary_path: Path) -> list[KernelItem]:
8383
data = json.loads(_read_text(summary_path))
8484
if not isinstance(data, list):
8585
raise SystemExit("kernels summary must be a JSON array (from dispatch step)")
86-
items: List[KernelItem] = []
86+
items: list[KernelItem] = []
8787
for it in data:
8888
if not isinstance(it, dict):
8989
continue
@@ -104,8 +104,8 @@ def _load_kernels_from_summary(summary_path: Path) -> List[KernelItem]:
104104
return items
105105

106106

107-
def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str:
108-
lines: List[str] = []
107+
def _summarize_subgraphs_for_prompt(subgraphs: list[dict[str, Any]]) -> str:
108+
lines: list[str] = []
109109
for it in subgraphs:
110110
sid = str(it.get("id", "unknown"))
111111
typ = str(it.get("type", ""))
@@ -132,8 +132,8 @@ def _summarize_subgraphs_for_prompt(subgraphs: List[Dict[str, Any]]) -> str:
132132

133133
def _build_composition_prompt(
134134
problem_code: str,
135-
subgraphs: List[Dict[str, Any]],
136-
kernel_items: List[KernelItem],
135+
subgraphs: list[dict[str, Any]],
136+
kernel_items: list[KernelItem],
137137
target_platform: PlatformConfig,
138138
) -> str:
139139
"""Create a single user message to instruct composition by the LLM."""
@@ -142,7 +142,7 @@ def _build_composition_prompt(
142142

143143
# Include only essential snippets from each kernel to keep token usage sane
144144
# We include full files for now; callers can trim by model limits.
145-
kernels_section_parts: List[str] = []
145+
kernels_section_parts: list[str] = []
146146
for ki in kernel_items:
147147
kernels_section_parts.append(
148148
f"### Subgraph {ki.subgraph_id}\n```python\n" + ki.code + "\n```\n"
@@ -208,7 +208,7 @@ def _build_composition_prompt(
208208
"""
209209
).strip()
210210

211-
user_lines: List[str] = []
211+
user_lines: list[str] = []
212212
user_lines.append(guidance)
213213
user_lines.append("")
214214
user_lines.append("SUBGRAPHS (summary):")
@@ -230,10 +230,10 @@ def _build_composition_prompt(
230230

231231
def _build_refinement_prompt(
232232
problem_code: str,
233-
subgraphs: List[Dict[str, Any]],
234-
kernel_items: List[KernelItem],
233+
subgraphs: list[dict[str, Any]],
234+
kernel_items: list[KernelItem],
235235
previous_code: str,
236-
error_info: Dict[str, str],
236+
error_info: dict[str, str],
237237
target_platform: PlatformConfig,
238238
) -> str:
239239
"""Prompt the LLM to refine the previously produced code based on errors."""
@@ -262,7 +262,7 @@ def _build_refinement_prompt(
262262
"""
263263
).strip()
264264

265-
lines: List[str] = []
265+
lines: list[str] = []
266266
lines.append(guidance)
267267
lines.append("")
268268
lines.append("ERROR_CONTEXT (stderr tail):\n```\n" + err_tail + "\n```")
@@ -284,7 +284,7 @@ def _build_refinement_prompt(
284284

285285
def _auto_patch_common_triton_issues(
286286
code: str, target_platform: PlatformConfig
287-
) -> Tuple[str, bool]:
287+
) -> tuple[str, bool]:
288288
"""Apply tiny safe textual patches for known Triton pitfalls.
289289
290290
- Replace tl.broadcast(0.0, ...) or tl.broadcast(1.0, ...) with scalar constants.
@@ -336,7 +336,7 @@ def compose(
336336
verify: bool = False,
337337
max_iters: int = 5,
338338
target_platform: str = "cuda",
339-
) -> Dict[str, Any]:
339+
) -> dict[str, Any]:
340340
if get_model_provider is None:
341341
raise SystemExit(
342342
"KernelAgent providers unavailable; ensure package import and dependencies"
@@ -360,7 +360,7 @@ def compose(
360360

361361
last_usage = None
362362
last_code = None
363-
verify_info: Dict[str, Any] = {}
363+
verify_info: dict[str, Any] = {}
364364

365365
for i in range(1, max_iters + 1):
366366
if i == 1 or last_code is None:
@@ -441,7 +441,7 @@ def compose(
441441
composed_path = out_dir / "composed_kernel.py"
442442
composed_path.write_text(last_code or "", encoding="utf-8")
443443

444-
result: Dict[str, Any] = {
444+
result: dict[str, Any] = {
445445
"success": bool(verify_info.get("verify_passed", not verify)),
446446
"composed_path": str(composed_path.resolve()),
447447
"model": model_name,
@@ -458,7 +458,7 @@ def compose(
458458
return result
459459

460460

461-
def main(argv: Optional[List[str]] = None) -> int:
461+
def main(argv: list[str] | None = None) -> int:
462462
load_dotenv()
463463
p = argparse.ArgumentParser(
464464
description="Compose end-to-end Triton kernel from subgraphs + generated kernels"

Fuser/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from __future__ import annotations
1515
from dataclasses import dataclass, asdict
1616
from pathlib import Path
17-
from typing import Optional
1817
import json
1918
import time
2019
import uuid
@@ -83,8 +82,8 @@ def platform_config(self) -> "PlatformConfig":
8382
@dataclass
8483
class ResultSummary:
8584
run_id: str
86-
winner_worker_id: Optional[str]
87-
artifact_path: Optional[str]
85+
winner_worker_id: str | None
86+
artifact_path: str | None
8887
reason: str
8988

9089

Fuser/dedup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
import json
1616
import time
1717
from pathlib import Path
18-
from typing import Tuple, Optional
1918

2019

2120
def register_digest(
2221
shared_digests_dir: Path, sha256: str, worker_id: str, iter_index: int
23-
) -> Tuple[str, Optional[str]]:
22+
) -> tuple[str, str | None]:
2423
"""
2524
Atomically register a digest in shared_digests_dir.
2625
Returns (status, owner_worker_id or None), where status is one of:

0 commit comments

Comments
 (0)