diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index dadaa8f..dcd400f 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -330,6 +330,7 @@ def __init__( dispatch_jobs: int = 2, allow_fallback: bool = True, target_platform: str | None = None, + ignore_router_config: bool = False, ) -> None: self.ka_model = ka_model self.ka_num_workers = ka_num_workers @@ -352,6 +353,7 @@ def __init__( self.dispatch_jobs = dispatch_jobs self.allow_fallback = allow_fallback self.platform_config = get_platform(target_platform) + self.ignore_router_config = ignore_router_config def _solve_with_kernelagent(self, problem_code: str) -> RouteResult: agent = TritonKernelAgent( @@ -503,8 +505,8 @@ def solve(self, problem_path: Path) -> RouteResult: # Confidence too low or invalid JSON; resort to heuristic strategy = "fuser" if heuristic_prefers_fuser else "kernelagent" - # Apply optional dynamic config from router - if isinstance(route_cfg, dict): + # Apply optional dynamic config from router (skip if ignore requested) + if isinstance(route_cfg, dict) and not self.ignore_router_config: # KernelAgent tuning self.ka_max_rounds = int(route_cfg.get("ka_max_rounds", self.ka_max_rounds)) self.ka_num_workers = int( @@ -704,6 +706,11 @@ def main(argv: list[str] | None = None) -> int: p.add_argument("--verify", action="store_true") p.add_argument("--dispatch-jobs", type=int, default=2) p.add_argument("--no-fallback", action="store_true") + p.add_argument( + "--ignore-router-config", + action="store_true", + help="Ignore router config. Use CLI-provided model/config arguments", + ) p.add_argument( "--target-platform", default="cuda", @@ -741,6 +748,7 @@ def main(argv: list[str] | None = None) -> int: dispatch_jobs=args.dispatch_jobs, allow_fallback=(not args.no_fallback), target_platform=args.target_platform, + ignore_router_config=args.ignore_router_config, ) try: diff --git a/README.md b/README.md index d39fb4b..339b8d9 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`. ## Component Details -- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. +- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--ignore-router-config` to manually specify routed execution configs. - **Fuser Orchestrator (`Fuser/orchestrator.py`)**: rewrites the PyTorch module into fusable modules, executes them for validation, and packages a tarball of the fused code. Run IDs and directories are managed via `Fuser/paths.py`.