Fuser/auto_agent.py (26 additions, 14 deletions)
@@ -330,6 +330,7 @@ def __init__(
dispatch_jobs: int = 2,
allow_fallback: bool = True,
target_platform: str | None = None,
use_router_cache: bool = True,
) -> None:
self.ka_model = ka_model
self.ka_num_workers = ka_num_workers
@@ -352,6 +353,7 @@ def __init__(
self.dispatch_jobs = dispatch_jobs
self.allow_fallback = allow_fallback
self.platform_config = get_platform(target_platform)
self.use_router_cache = use_router_cache

def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
agent = TritonKernelAgent(
@@ -461,20 +463,23 @@ def solve(self, problem_path: Path) -> RouteResult:
heuristic_prefers_fuser = cx.route_to_fuser()

# Cache lookup by content hash to avoid repeated router calls
cache = {}
code_hash = _file_sha256_text(code)
cache = _load_router_cache()
cached = cache.get(code_hash)

strategy: str | None = None
route_conf: float | None = None
route_cfg: dict[str, Any] = {}

if isinstance(cached, dict):
strategy = (
str(cached.get("route_strategy") or cached.get("route") or "") or None
)
route_conf = cached.get("confidence")
route_cfg = cached.get("config") or {}
if self.use_router_cache:
cache = _load_router_cache()
cached = cache.get(code_hash)

if isinstance(cached, dict):
strategy = (
str(cached.get("route_strategy") or cached.get("route") or "")
or None
)
route_conf = cached.get("confidence")
route_cfg = cached.get("config") or {}

if strategy is None:
# Try LLM-driven decision
@@ -483,11 +488,12 @@ def solve(self, problem_path: Path) -> RouteResult:
problem_path, code, cx
)
# Persist in cache for future runs
cache[code_hash] = info.get("parsed") or {
"route_strategy": strategy,
"confidence": route_conf,
}
_save_router_cache(cache)
if self.use_router_cache:
cache[code_hash] = info.get("parsed") or {
"route_strategy": strategy,
"confidence": route_conf,
}
_save_router_cache(cache)
except Exception:
# No provider or failure; fall back later
pass

Review thread on the new cache write (cache[code_hash] = info.get("parsed") or ...):

Contributor:
Router-emitted config.llm_models / config.ka_model can be non-configured/inaccessible, gets cached verbatim, and later runs apply it unvalidated. Please consider validating/intersecting against the local registry (utils/providers/available_models.py) + provider availability before (a) applying and/or (b) writing to cache.

Contributor Author:
Agreed, I have a local fix I'll push in a separate PR since there are actually some other bugs that get tackled over there.

Contributor Author:
#79

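To make the suggestion in the thread above concrete, here is a minimal sketch of that validation. The helper name sanitize_route_config is hypothetical, the set of locally available model names is assumed to come from whatever utils/providers/available_models.py exposes, and the llm_models / ka_model keys follow the reviewer's comment rather than a confirmed schema:

from typing import Any

def sanitize_route_config(route_cfg: dict[str, Any], available: set[str]) -> dict[str, Any]:
    """Drop router-emitted model names that are not locally available.

    Intended to run both before applying a cached config and before writing a
    freshly parsed router decision to the cache, so inaccessible models never
    persist.
    """
    cfg = dict(route_cfg)

    # Keep only LLM models that the local registry actually exposes.
    llm_models = [m for m in cfg.get("llm_models") or [] if m in available]
    if llm_models:
        cfg["llm_models"] = llm_models
    else:
        cfg.pop("llm_models", None)  # fall back to the agent's defaults

    # Same check for the single KernelAgent model.
    if cfg.get("ka_model") not in available:
        cfg.pop("ka_model", None)

    return cfg

Calling something like this both before applying a cached entry and before _save_router_cache(cache) would cover both points (a) and (b) raised in the review.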
@@ -704,6 +710,11 @@ def main(argv: list[str] | None = None) -> int:
p.add_argument("--verify", action="store_true")
p.add_argument("--dispatch-jobs", type=int, default=2)
p.add_argument("--no-fallback", action="store_true")
p.add_argument(
"--no-router-cache",
action="store_true",
help="Disable router cache (do not read from or write to cache)",
)
p.add_argument(
"--target-platform",
default="cuda",
@@ -741,6 +752,7 @@ def main(argv: list[str] | None = None) -> int:
dispatch_jobs=args.dispatch_jobs,
allow_fallback=(not args.no_fallback),
target_platform=args.target_platform,
use_router_cache=(not args.no_router_cache),
)

try:
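For reference, a hypothetical invocation exercising the new flag; only the option flags shown in this diff are confirmed, and the positional problem-path argument is an assumption about the CLI:

python Fuser/auto_agent.py path/to/problem.py --no-router-cache --target-platform cuda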