-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathextract.py
More file actions
648 lines (552 loc) · 22.7 KB
/
extract.py
File metadata and controls
648 lines (552 loc) · 22.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
#!/usr/bin/env python3
"""
AutoKernel Kernel Extractor -- Generate baseline kernels from profiling results.
Usage:
uv run extract.py # extract from workspace/profile_report.json
uv run extract.py --top 5 # extract only top-5 kernels
uv run extract.py --kernel-type matmul # extract only matmul kernels
uv run extract.py --report path/to/report.json
uv run extract.py --backend cuda # use CUDA C++ starter kernels instead of Triton
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKSPACE_DIR = os.path.join(SCRIPT_DIR, "workspace")
KERNELS_DIR = os.path.join(SCRIPT_DIR, "kernels")
DEFAULT_REPORT_PATH = os.path.join(WORKSPACE_DIR, "profile_report.json")
OPTIMIZATION_PLAN_PATH = os.path.join(WORKSPACE_DIR, "optimization_plan.json")
# ---------------------------------------------------------------------------
# Shape key mappings per kernel type
# ---------------------------------------------------------------------------
# Each entry maps op_type -> list of (shape_key_aliases...) so we can parse
# various shape_info string formats from profile_report.json.
SHAPE_KEYS: Dict[str, List[str]] = {
"matmul": ["M", "N", "K"],
"flash_attention": ["B", "H", "N", "D"],
"layernorm": ["M", "N"],
"softmax": ["M", "N"],
"cross_entropy": ["batch", "vocab"],
"fused_mlp": ["M", "N", "K"],
"rmsnorm": ["M", "N"],
"reduce": ["M", "N"],
"rotary_embedding": ["B", "H", "N", "D"],
}
# Aliases: profile_report.json may use different key names than bench.py
# Map from alias -> canonical bench.py key, per op_type.
SHAPE_ALIAS_MAP: Dict[str, Dict[str, str]] = {
"matmul": {},
"flash_attention": {
"B": "batch", "H": "heads", "N": "seq_len", "S": "seq_len", "D": "head_dim",
"batch": "batch", "heads": "heads", "seq_len": "seq_len", "head_dim": "head_dim",
},
"layernorm": {
"M": "batch", "N": "dim", "rows": "batch", "cols": "dim",
"batch": "batch", "dim": "dim",
},
"softmax": {
"M": "rows", "N": "cols", "rows": "rows", "cols": "cols",
},
"cross_entropy": {
"batch": "batch", "vocab": "vocab",
},
"fused_mlp": {
"M": "batch", "N": "hidden", "K": "dim",
"batch": "batch", "dim": "dim", "hidden": "hidden",
},
"rmsnorm": {
"M": "M", "N": "N",
},
"reduce": {
"M": "M", "N": "N",
},
"rotary_embedding": {
"B": "batch", "H": "heads", "N": "seq_len", "S": "seq_len", "D": "head_dim",
"batch": "batch", "heads": "heads", "seq_len": "seq_len", "head_dim": "head_dim",
},
}
# Default tolerances per op_type (matching bench.py structure, serialized for template)
TOLERANCES_MAP: Dict[str, Dict[str, Dict[str, float]]] = {
"matmul": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 2e-2, "rtol": 2e-2},
"float32": {"atol": 1e-4, "rtol": 1e-4},
},
"flash_attention": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 2e-2, "rtol": 2e-2},
"float32": {"atol": 1e-4, "rtol": 1e-4},
},
"layernorm": {
"float16": {"atol": 1e-3, "rtol": 1e-3},
"bfloat16": {"atol": 2e-3, "rtol": 2e-3},
"float32": {"atol": 1e-5, "rtol": 1e-5},
},
"softmax": {
"float16": {"atol": 1e-3, "rtol": 1e-3},
"bfloat16": {"atol": 2e-3, "rtol": 2e-3},
"float32": {"atol": 1e-5, "rtol": 1e-5},
},
"cross_entropy": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 2e-2, "rtol": 2e-2},
"float32": {"atol": 1e-5, "rtol": 1e-5},
},
"fused_mlp": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 2e-2, "rtol": 2e-2},
"float32": {"atol": 1e-4, "rtol": 1e-4},
},
"rmsnorm": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 1e-1, "rtol": 5e-2},
},
"reduce": {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 1e-1, "rtol": 5e-2},
},
"rotary_embedding": {
"float16": {"atol": 1e-3, "rtol": 1e-3},
"bfloat16": {"atol": 2e-3, "rtol": 2e-3},
"float32": {"atol": 1e-5, "rtol": 1e-5},
},
}
# FLOPS formulas as source strings, per op_type
FLOPS_FN_SRC: Dict[str, str] = {
"matmul": 'return 2 * s["M"] * s["N"] * s["K"]',
"flash_attention": 'return 4 * s["batch"] * s["heads"] * (s["seq_len"] ** 2) * s["head_dim"]',
"layernorm": 'return 8 * s["batch"] * s["dim"]',
"softmax": 'return 5 * s["rows"] * s["cols"]',
"cross_entropy": 'return 4 * s["batch"] * s["vocab"]',
"fused_mlp": 'return 2 * s["batch"] * s["dim"] * s["hidden"] * 3',
"rmsnorm": 'return 6 * s["M"] * s["N"]',
"reduce": 'return s["M"] * s["N"]',
"rotary_embedding": 'return 6 * s["batch"] * s["heads"] * s["seq_len"] * s["head_dim"]',
}
# BYTES formulas as source strings, per op_type (dt_bytes is passed in)
BYTES_FN_SRC: Dict[str, str] = {
"matmul": 'return (s["M"] * s["K"] + s["K"] * s["N"] + s["M"] * s["N"]) * dt_bytes',
"flash_attention": 'return 4 * s["batch"] * s["heads"] * s["seq_len"] * s["head_dim"] * dt_bytes',
"layernorm": 'return (2 * s["batch"] * s["dim"] + 2 * s["dim"]) * dt_bytes',
"softmax": 'return 2 * s["rows"] * s["cols"] * dt_bytes',
"cross_entropy": 'return (s["batch"] * s["vocab"] + s["batch"]) * dt_bytes',
"fused_mlp": 'return (s["batch"] * s["dim"] + s["hidden"] * s["dim"] * 3 + s["batch"] * s["dim"]) * dt_bytes',
"rmsnorm": 'return (2 * s["M"] * s["N"] + s["N"]) * dt_bytes',
"reduce": 'return (s["M"] * s["N"] + s["M"]) * dt_bytes',
"rotary_embedding": 'return (s["batch"] * s["heads"] * s["seq_len"] * s["head_dim"] * 2 + s["seq_len"] * s["head_dim"]) * dt_bytes',
}
# Speedup potential heuristic per op_type
SPEEDUP_ESTIMATES: Dict[str, str] = {
"matmul": "2-3x",
"flash_attention": "2-4x",
"layernorm": "1.5-3x",
"softmax": "1.5-3x",
"cross_entropy": "1.5-2x",
"fused_mlp": "2-3x",
"rmsnorm": "1.5-3x",
"reduce": "1.5-2x",
"rotary_embedding": "1.5-2x",
}
# ---------------------------------------------------------------------------
# Shape parsing
# ---------------------------------------------------------------------------
def parse_shape_info(shape_info_str: str, op_type: str) -> Optional[Dict[str, int]]:
"""
Parse a shape_info string like "M=4096, N=4096, K=4096" into a dict.
Handles various formats:
- "M=4096, N=4096, K=4096"
- "B=1, H=32, N=4096, D=128"
- "batch=4096, vocab=32000"
- "rows=4096, cols=4096"
Returns None if parsing fails.
"""
if not shape_info_str or not isinstance(shape_info_str, str):
return None
# Match key=value pairs
pairs = re.findall(r"([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(\d+)", shape_info_str)
if not pairs:
return None
raw = {k: int(v) for k, v in pairs}
# Map to canonical bench.py keys using alias map
alias_map = SHAPE_ALIAS_MAP.get(op_type, {})
if alias_map:
canonical = {}
for k, v in raw.items():
mapped_key = alias_map.get(k, k)
canonical[mapped_key] = v
return canonical
else:
return raw
def shape_to_display(shape: Dict[str, int]) -> str:
"""Convert a shape dict to a display string like 'M=4096, N=4096, K=4096'."""
return ", ".join(f"{k}={v}" for k, v in shape.items())
def scale_shape(shape: Dict[str, int], factor: float) -> Dict[str, int]:
"""
Scale all shape dimensions by a factor, rounding to nearest integer.
Ensures all values are at least 1.
"""
return {k: max(1, int(round(v * factor))) for k, v in shape.items()}
def get_default_shape(op_type: str) -> Dict[str, int]:
"""
Return a reasonable default shape for a given op_type when parsing fails.
Based on the 'large' size from bench.py KERNEL_CONFIGS.
"""
defaults: Dict[str, Dict[str, int]] = {
"matmul": {"M": 2048, "N": 2048, "K": 2048},
"flash_attention": {"batch": 2, "heads": 32, "seq_len": 1024, "head_dim": 64},
"layernorm": {"batch": 4096, "dim": 2048},
"softmax": {"rows": 4096, "cols": 4096},
"cross_entropy": {"batch": 4096, "vocab": 32000},
"fused_mlp": {"batch": 2048, "dim": 2048, "hidden": 5504},
"rmsnorm": {"M": 4096, "N": 4096},
"reduce": {"M": 4096, "N": 4096},
"rotary_embedding": {"batch": 2, "heads": 32, "seq_len": 1024, "head_dim": 128},
}
return defaults.get(op_type, {"M": 2048, "N": 2048})
# ---------------------------------------------------------------------------
# Kernel file generation
# ---------------------------------------------------------------------------
def read_starter_kernel(op_type: str, backend: str = "triton") -> Optional[str]:
"""Read the starter kernel file. Returns None if not found.
For backend='triton': reads from kernels/{op_type}.py
For backend='cuda': reads from kernels/cuda/{op_type}.py
"""
if backend == "cuda":
path = os.path.join(KERNELS_DIR, "cuda", f"{op_type}.py")
else:
path = os.path.join(KERNELS_DIR, f"{op_type}.py")
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return f.read()
def extract_kernel_body(starter_code: str) -> str:
"""
Extract the Triton kernel code from a starter file, stripping the
original module docstring and KERNEL_TYPE declaration (which we replace
in the template header).
Returns everything from the first 'import' statement onward.
"""
lines = starter_code.split("\n")
# Find the first import line
import_idx = None
for i, line in enumerate(lines):
stripped = line.strip()
if stripped.startswith("import ") or stripped.startswith("from "):
import_idx = i
break
if import_idx is not None:
return "\n".join(lines[import_idx:])
else:
# Fallback: return everything after KERNEL_TYPE line
for i, line in enumerate(lines):
if line.strip().startswith("KERNEL_TYPE"):
return "\n".join(lines[i + 1:])
return starter_code
def generate_kernel_file(
op_type: str,
rank: int,
pct_total: float,
model_shape: Dict[str, int],
model_name: str,
gpu_time_ms: float,
starter_code: str,
backend: str = "triton",
) -> str:
"""Generate the complete kernel file content for extraction."""
half_shape = scale_shape(model_shape, 0.5)
double_shape = scale_shape(model_shape, 2.0)
shape_display = shape_to_display(model_shape)
half_display = shape_to_display(half_shape)
double_display = shape_to_display(double_shape)
tolerances = TOLERANCES_MAP.get(op_type, {
"float16": {"atol": 1e-2, "rtol": 1e-2},
"bfloat16": {"atol": 2e-2, "rtol": 2e-2},
"float32": {"atol": 1e-4, "rtol": 1e-4},
})
flops_fn_body = FLOPS_FN_SRC.get(op_type, 'return 0')
bytes_fn_body = BYTES_FN_SRC.get(op_type, 'return 0')
# Extract the kernel code body (imports + jit functions + kernel_fn)
kernel_body = extract_kernel_body(starter_code)
# Build the file
lines = []
# Header docstring
lines.append('"""')
lines.append(f"AutoKernel -- Extracted kernel from model profiling.")
lines.append(f"Op type: {op_type}")
lines.append(f"Rank: {rank} ({pct_total}% of GPU time)")
lines.append(f"Model shape: {shape_display}")
lines.append(f"")
lines.append(f"This kernel was extracted from profiling {model_name}.")
lines.append(f"The agent optimizes this to maximize throughput at the model-specific shapes.")
lines.append('"""')
lines.append("")
# KERNEL_TYPE and BACKEND
lines.append(f'KERNEL_TYPE = "{op_type}"')
if backend == "cuda":
lines.append(f'BACKEND = "cuda"')
lines.append("")
# Model-specific shapes
lines.append("# Model-specific shapes (the shapes that matter for THIS model)")
lines.append(f"MODEL_SHAPES = {repr(model_shape)}")
lines.append("")
# Benchmark config
lines.append("# Benchmark config (self-describing -- bench.py can load this dynamically)")
lines.append("TEST_SIZES = [")
lines.append(f' ("model_primary", {repr(model_shape)}),')
lines.append(f" # Also test nearby sizes for robustness")
lines.append(f' ("model_half", {repr(half_shape)}),')
lines.append(f' ("model_double", {repr(double_shape)}),')
lines.append("]")
lines.append("")
# Tolerances
lines.append(f"TOLERANCES = {repr(tolerances)}")
lines.append("")
# FLOPS function
lines.append("")
lines.append("def FLOPS_FN(s):")
lines.append(f" {flops_fn_body}")
lines.append("")
# BYTES function
lines.append("")
lines.append("def BYTES_FN(s, dt_bytes):")
lines.append(f" {bytes_fn_body}")
lines.append("")
# Separator
lines.append("")
lines.append(f"# {'=' * 70}")
backend_label = "CUDA C++" if backend == "cuda" else "Triton"
backend_dir = f"kernels/cuda/{op_type}.py" if backend == "cuda" else f"kernels/{op_type}.py"
lines.append(f"# {backend_label} kernel code (from {backend_dir})")
lines.append(f"# {'=' * 70}")
lines.append("")
# Kernel body
lines.append(kernel_body)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Profile reading and validation
# ---------------------------------------------------------------------------
def load_profile_report(path: str) -> Optional[Dict[str, Any]]:
"""Load and validate the profile report JSON. Returns None on failure."""
if not os.path.exists(path):
return None
try:
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
return data
except (json.JSONDecodeError, IOError) as e:
print(f"ERROR: Failed to read profile report: {e}")
return None
def get_supported_kernels(report: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Extract the list of supported (autokernel_supported=True) kernels from
the profile report, sorted by rank.
"""
kernels = report.get("top_kernels", report.get("kernels", report.get("bottleneck_kernels", [])))
supported = []
for k in kernels:
if k.get("autokernel_supported", False):
supported.append(k)
# Sort by rank if available, otherwise by gpu_time_ms descending
supported.sort(key=lambda x: x.get("rank", x.get("gpu_time_ms", 0)))
# Ensure rank ordering (lower rank = higher priority)
for i, k in enumerate(supported):
if "rank" not in k:
k["rank"] = i + 1
return supported
# ---------------------------------------------------------------------------
# Optimization plan generation
# ---------------------------------------------------------------------------
def generate_optimization_plan(
extracted: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Build the optimization_plan.json data structure."""
kernels_to_optimize = []
total_pct = 0.0
for entry in extracted:
total_pct += entry["pct_total"]
kernels_to_optimize.append({
"rank": entry["rank"],
"file": entry["output_file"],
"op_type": entry["op_type"],
"model_shape": entry["model_shape"],
"gpu_time_ms": entry["gpu_time_ms"],
"pct_total": entry["pct_total"],
"estimated_speedup_potential": SPEEDUP_ESTIMATES.get(
entry["op_type"], "1.5-2x"
),
})
return {
"kernels_to_optimize": kernels_to_optimize,
"total_optimization_targets": len(kernels_to_optimize),
"covered_gpu_time_pct": round(total_pct, 1),
}
# ---------------------------------------------------------------------------
# Main extraction logic
# ---------------------------------------------------------------------------
def extract_kernels(
report_path: str,
top_n: Optional[int] = None,
kernel_type_filter: Optional[str] = None,
backend: str = "triton",
) -> None:
"""Main extraction pipeline."""
backend_label = "CUDA C++" if backend == "cuda" else "Triton"
print(f"=== AutoKernel Kernel Extractor ({backend_label}) ===")
print()
# -- Load profile report --
print(f"Reading profile from {report_path}...")
report = load_profile_report(report_path)
if report is None:
print(f"ERROR: Profile report not found at {report_path}")
print(f" Run the profiler first: uv run profile.py")
sys.exit(1)
# -- Get model name --
model_name = report.get("model_name", report.get("model", "unknown model"))
# -- Get supported kernels --
supported = get_supported_kernels(report)
if not supported:
print("ERROR: No supported kernels found in profile report.")
print(" Ensure the profiler marks kernels with autokernel_supported=True.")
sys.exit(1)
# -- Apply filters --
if kernel_type_filter:
supported = [k for k in supported if k.get("op_type") == kernel_type_filter]
if not supported:
print(f"WARNING: No kernels of type '{kernel_type_filter}' found in profile report.")
sys.exit(1)
if top_n is not None:
supported = supported[:top_n]
print(f"Found {len(supported)} supported kernels to extract.")
print()
# -- Ensure workspace directory exists --
os.makedirs(WORKSPACE_DIR, exist_ok=True)
# -- Extract each kernel --
print("Extracting kernels:")
extracted = []
skipped = 0
for idx, kernel_info in enumerate(supported):
rank = kernel_info.get("rank", idx + 1)
op_type = kernel_info.get("op_type", "unknown")
pct_total = kernel_info.get("pct_total", kernel_info.get("pct_gpu_time", 0.0))
gpu_time_ms = kernel_info.get("gpu_time_ms", kernel_info.get("total_gpu_time_ms", 0.0))
shape_info_str = kernel_info.get("shape_info", kernel_info.get("shape", ""))
# Parse model shape
model_shape = parse_shape_info(shape_info_str, op_type)
if model_shape is None:
# Try to use a "shapes" dict directly if provided
if isinstance(kernel_info.get("shapes"), dict):
model_shape = kernel_info["shapes"]
else:
print(f" WARNING: Could not parse shape for {op_type} (rank {rank}), "
f"using default shapes.")
model_shape = get_default_shape(op_type)
# Read starter kernel
starter_code = read_starter_kernel(op_type, backend=backend)
if starter_code is None:
starter_dir = "kernels/cuda" if backend == "cuda" else "kernels"
print(f" WARNING: No starter kernel found at {starter_dir}/{op_type}.py -- skipping.")
skipped += 1
continue
# Generate output filename
output_filename = f"kernel_{op_type}_{rank}.py"
output_path = os.path.join(WORKSPACE_DIR, output_filename)
# Relative path for display and plan
output_relpath = f"workspace/{output_filename}"
# Generate the customized kernel file
kernel_content = generate_kernel_file(
op_type=op_type,
rank=rank,
pct_total=pct_total,
model_shape=model_shape,
model_name=model_name,
gpu_time_ms=gpu_time_ms,
starter_code=starter_code,
backend=backend,
)
# Write to workspace
with open(output_path, "w", encoding="utf-8") as f:
f.write(kernel_content)
# Print progress
position = idx + 1
total = len(supported)
shape_display = shape_to_display(model_shape)
print(f" [{position}/{total}] {op_type} (rank {rank}, {pct_total}%) "
f"-> {output_relpath}")
print(f" Model shape: {shape_display}")
starter_dir = "kernels/cuda" if backend == "cuda" else "kernels"
print(f" Based on: {starter_dir}/{op_type}.py")
print()
extracted.append({
"rank": rank,
"op_type": op_type,
"pct_total": pct_total,
"gpu_time_ms": gpu_time_ms,
"model_shape": model_shape,
"output_file": output_relpath,
})
if not extracted:
print("ERROR: No kernels were successfully extracted.")
if skipped > 0:
print(f" {skipped} kernel(s) skipped due to missing starter files.")
sys.exit(1)
# -- Generate optimization plan --
plan = generate_optimization_plan(extracted)
with open(OPTIMIZATION_PLAN_PATH, "w", encoding="utf-8") as f:
json.dump(plan, f, indent=4)
print(f"Optimization plan saved to workspace/optimization_plan.json")
# -- Print next steps --
print()
top_kernel = extracted[0]
top_file = top_kernel["output_file"]
print("Next steps:")
print(f" 1. Copy a kernel to kernel.py: cp {top_file} kernel.py")
print(f" 2. Run benchmark: uv run bench.py")
print(f" 3. Start optimizing (or let the agent do it via program.md)")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
parser = argparse.ArgumentParser(
description="AutoKernel Kernel Extractor -- Generate baseline kernels from profiling results.",
)
parser.add_argument(
"--report",
type=str,
default=DEFAULT_REPORT_PATH,
help="Path to profile_report.json (default: workspace/profile_report.json)",
)
parser.add_argument(
"--top",
type=int,
default=None,
help="Extract only the top-N kernels by rank",
)
parser.add_argument(
"--kernel-type",
type=str,
default=None,
help="Extract only kernels of this type (e.g., matmul, flash_attention)",
)
parser.add_argument(
"--backend",
type=str,
choices=["triton", "cuda"],
default="triton",
help="Backend for starter kernels: 'triton' (default) or 'cuda' (native CUDA C++)",
)
args = parser.parse_args()
extract_kernels(
report_path=args.report,
top_n=args.top,
kernel_type_filter=args.kernel_type,
backend=args.backend,
)
if __name__ == "__main__":
main()