Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def __init__(
dispatch_jobs: int = 2,
allow_fallback: bool = True,
target_platform: str | None = None,
ignore_router_config: bool = False,
) -> None:
self.ka_model = ka_model
self.ka_num_workers = ka_num_workers
Expand All @@ -352,6 +353,7 @@ def __init__(
self.dispatch_jobs = dispatch_jobs
self.allow_fallback = allow_fallback
self.platform_config = get_platform(target_platform)
self.ignore_router_config = ignore_router_config

def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
agent = TritonKernelAgent(
Expand Down Expand Up @@ -503,8 +505,8 @@ def solve(self, problem_path: Path) -> RouteResult:
# Confidence too low or invalid JSON; resort to heuristic
strategy = "fuser" if heuristic_prefers_fuser else "kernelagent"

# Apply optional dynamic config from router
if isinstance(route_cfg, dict):
# Apply optional dynamic config from router (skip if ignore requested)
if isinstance(route_cfg, dict) and not self.ignore_router_config:
# KernelAgent tuning
self.ka_max_rounds = int(route_cfg.get("ka_max_rounds", self.ka_max_rounds))
self.ka_num_workers = int(
Expand Down Expand Up @@ -704,6 +706,11 @@ def main(argv: list[str] | None = None) -> int:
p.add_argument("--verify", action="store_true")
p.add_argument("--dispatch-jobs", type=int, default=2)
p.add_argument("--no-fallback", action="store_true")
p.add_argument(
"--ignore-router-config",
action="store_true",
help="Ignore router config. Use CLI-provided model/config arguments",
)
p.add_argument(
"--target-platform",
default="cuda",
Expand Down Expand Up @@ -741,6 +748,7 @@ def main(argv: list[str] | None = None) -> int:
dispatch_jobs=args.dispatch_jobs,
allow_fallback=(not args.no_fallback),
target_platform=args.target_platform,
ignore_router_config=args.ignore_router_config,
)

try:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.

## Component Details

- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails.
- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--ignore-router-config` to manually specify routed execution configs.

- **Fuser Orchestrator (`Fuser/orchestrator.py`)**: rewrites the PyTorch module into fusable modules, executes them for validation, and packages a tarball of the fused code. Run IDs and directories are managed via `Fuser/paths.py`.

Expand Down