5050import sys
5151from dataclasses import dataclass
5252from pathlib import Path
53- from typing import Any , Dict , Optional , Tuple
53+ from typing import Any
5454
5555from dotenv import load_dotenv
5656from Fuser .pipeline import run_pipeline
@@ -109,7 +109,7 @@ def _file_sha256_text(txt: str) -> str:
109109 return hashlib .sha256 (txt .encode ("utf-8" )).hexdigest ()
110110
111111
112- def _load_router_cache () -> Dict [str , Any ]:
112+ def _load_router_cache () -> dict [str , Any ]:
113113 try :
114114 if _ROUTER_CACHE_PATH .is_file ():
115115 return json .loads (_ROUTER_CACHE_PATH .read_text (encoding = "utf-8" ))
@@ -118,7 +118,7 @@ def _load_router_cache() -> Dict[str, Any]:
118118 return {}
119119
120120
121- def _save_router_cache (cache : Dict [str , Any ]) -> None :
121+ def _save_router_cache (cache : dict [str , Any ]) -> None :
122122 try :
123123 _ensure_dir (_ROUTER_CACHE_PATH )
124124 _ROUTER_CACHE_PATH .write_text (json .dumps (cache , indent = 2 ), encoding = "utf-8" )
@@ -148,7 +148,7 @@ class Complexity:
148148 pool_ops : int
149149 act_ops : int
150150 chain_len_estimate : int
151- raw_op_names : Dict [str , int ]
151+ raw_op_names : dict [str , int ]
152152
153153 def route_to_fuser (self ) -> bool :
154154 # Primary triggers
@@ -217,7 +217,7 @@ def analyze_problem_code(code: str) -> Complexity:
217217
218218 # AST path: inspect Model.forward for ops and control flow
219219 has_control_flow = False
220- raw_op_counts : Dict [str , int ] = {}
220+ raw_op_counts : dict [str , int ] = {}
221221 has_attention_like = False
222222 has_conv_transpose = False
223223 has_group_norm = False
@@ -302,19 +302,19 @@ def visit_Assign(self, node: ast.Assign) -> Any:
302302class RouteResult :
303303 route : str # "kernelagent" or "fuser"
304304 success : bool
305- details : Dict [str , Any ]
306- kernel_code : Optional [ str ] = None
305+ details : dict [str , Any ]
306+ kernel_code : str | None = None
307307
308308
309309class AutoKernelRouter :
310310 def __init__ (
311311 self ,
312- ka_model : Optional [ str ] = None ,
312+ ka_model : str | None = None ,
313313 ka_num_workers : int = 4 ,
314314 ka_max_rounds : int = 10 ,
315315 ka_high_reasoning : bool = True ,
316316 # Router LLM
317- router_model : Optional [ str ] = "gpt-5" ,
317+ router_model : str | None = "gpt-5" ,
318318 router_high_reasoning : bool = True ,
319319 router_temperature : float = 0.2 ,
320320 router_max_tokens : int = 700 ,
@@ -329,7 +329,7 @@ def __init__(
329329 verify : bool = True ,
330330 dispatch_jobs : int = 2 ,
331331 allow_fallback : bool = True ,
332- target_platform : Optional [ str ] = None ,
332+ target_platform : str | None = None ,
333333 ) -> None :
334334 self .ka_model = ka_model
335335 self .ka_num_workers = ka_num_workers
@@ -439,7 +439,7 @@ def _solve_with_fuser(self, problem_path: Path) -> RouteResult:
439439
440440 comp = res .get ("composition" , {}) or {}
441441 ok = bool (comp .get ("verify_passed" , not self .verify ))
442- kernel_code : Optional [ str ] = None
442+ kernel_code : str | None = None
443443 try :
444444 composed_path = comp .get ("composed_path" )
445445 if composed_path and Path (composed_path ).is_file ():
@@ -465,9 +465,9 @@ def solve(self, problem_path: Path) -> RouteResult:
465465 cache = _load_router_cache ()
466466 cached = cache .get (code_hash )
467467
468- strategy : Optional [ str ] = None
469- route_conf : Optional [ float ] = None
470- route_cfg : Dict [str , Any ] = {}
468+ strategy : str | None = None
469+ route_conf : float | None = None
470+ route_cfg : dict [str , Any ] = {}
471471
472472 if isinstance (cached , dict ):
473473 strategy = (
@@ -553,7 +553,7 @@ def solve(self, problem_path: Path) -> RouteResult:
553553 # -------- LLM decision helper --------
554554 def _llm_decide_route (
555555 self , problem_path : Path , code : str , cx : Complexity
556- ) -> Tuple [ Optional [ str ], Optional [ float ], Dict [str , Any ]]:
556+ ) -> tuple [ str | None , float | None , dict [str , Any ]]:
557557 """Ask an LLM to choose a routing STRATEGY and optional budgets.
558558
559559 The LLM must return JSON with keys:
@@ -629,7 +629,7 @@ def _llm_decide_route(
629629 f"Features:\n ```json\n { json .dumps (feats , indent = 2 )} \n ```\n \n "
630630 "Problem code:\n ```python\n " + code + "\n ```\n "
631631 )
632- kwargs : Dict [str , Any ] = {
632+ kwargs : dict [str , Any ] = {
633633 "max_tokens" : self .router_max_tokens ,
634634 "temperature" : self .router_temperature ,
635635 }
@@ -644,7 +644,7 @@ def _llm_decide_route(
644644 # Best-effort JSON parse
645645 route = None
646646 conf = None
647- raw_info : Dict [str , Any ] = {"raw" : txt }
647+ raw_info : dict [str , Any ] = {"raw" : txt }
648648 try :
649649 # If model returned extra text, try to locate JSON object
650650 first = txt .find ("{" )
@@ -676,7 +676,7 @@ def _llm_decide_route(
676676# ------------------------
677677
678678
679- def main (argv : Optional [ list [str ]] = None ) -> int :
679+ def main (argv : list [str ] | None = None ) -> int :
680680 p = argparse .ArgumentParser (
681681 description = "Auto-router for KernelBench problems (KernelAgent vs Fuser)"
682682 )
@@ -763,7 +763,7 @@ def main(argv: Optional[list[str]] = None) -> int:
763763 )
764764 return 1
765765
766- out : Dict [str , Any ] = {
766+ out : dict [str , Any ] = {
767767 "route" : res .route ,
768768 "success" : res .success ,
769769 "details" : res .details ,
0 commit comments