# KernelAgent-Oink

KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for
**NVIDIA Blackwell (SM100 / GB200 / B200-class)**, bundled as a lightweight
Python package that can be used standalone or as a **vLLM general plugin**.

At the moment, the vLLM integration exposes the following `torch.library.custom_op`
entrypoints under the `oink::` namespace:

- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor`
- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place)

The package also includes additional SM100 kernels used by the benchmark suite:
LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd).

## Requirements

- GPU: **SM100** for the fast CuTeDSL paths. On other GPUs, Oink falls back to
  reference PyTorch implementations for correctness (see the quick check after
  this list).
- Python dependencies:
- `nvidia-cutlass-dsl` (CuTeDSL)
- `cuda-python`
- `torch` (provided by your environment / vLLM)
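
A quick way to check which path you will get (the same capability test the
benchmark README documents; SM100 reports `(10, 0)`):

```python
import torch

# SM100 reports compute capability (10, 0); anything else takes the
# reference PyTorch fallback path.
on_sm100 = torch.cuda.get_device_capability() == (10, 0)
print("fast CuTeDSL path" if on_sm100 else "PyTorch fallback")
```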

Recommended env vars:

```bash
export CUTE_DSL_ARCH=sm_100a
export PYTORCH_ALLOC_CONF=expandable_segments:True
```

## Install (editable)

From the `KernelAgent` repo root:

```bash
pip install -e ./oink
```

For running the in-repo benchmark suite / plots:

```bash
pip install -e "./oink[bench]"
```

## Usage

### vLLM (general plugin)

1) Enable the plugin:

```bash
export VLLM_USE_OINK_RMSNORM=1
```

2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs:

```python
from vllm import LLM

llm = LLM(
model=...,
tensor_parallel_size=...,
enforce_eager=False,
compilation_config={"custom_ops": ["none", "+rms_norm"]},
)
```

Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels, in which
case neither vLLM’s CUDA RMSNorm nor Oink’s op will run.

### Direct PyTorch usage (manual op registration)

For standalone use (outside vLLM), register the custom ops once:

```python
import kernelagent_oink
import torch

kernelagent_oink.register(force=True)

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
y = torch.ops.oink.rmsnorm(x, w, 1e-6)
```
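
The in-place fused op follows the same pattern. A minimal sketch continuing the
snippet above (per the signature above, it mutates `x` and `residual` and
returns nothing):

```python
residual = torch.randn_like(x)
# In-place: x and residual are updated; the op returns None.
torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)
```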

## Benchmarks

The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare
Oink against Quack on SM100 and to reproduce the reported speedups.

- How to run + methodology: `oink/benchmarks/README.md`
- Pre-generated plots: `oink/benchmarks/media/`

<div align="center">
<img src="benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg" alt="SM100 BF16: Oink vs Quack (Quack-suite)">
</div>

<div align="center">
<img src="benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg" alt="SM100 BF16: Oink vs Quack (DSv3-like shapes)">
</div>

## Links

| What | Link |
|---|---|
| Quack (expert baseline) | https://github.com/Dao-AILab/quack |
| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent |
| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 |
# SM100 Benchmarks (KernelAgent-Oink vs Quack)

This folder contains SM100 (GB200 / Blackwell) microbenchmarks for the Oink
CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100
kernels where Quack provides an equivalent API.

## Prereqs

- GPU: **SM100** (`torch.cuda.get_device_capability() == (10, 0)`).
- Python deps in your environment:
- `torch`
- `nvidia-cutlass-dsl` (CuTeDSL)
- `cuda-python`
- `triton` (only for `triton.testing.do_bench`)
- `quack` (optional; only needed for Oink-vs-Quack comparisons)

Recommended env vars:

```bash
export PYTORCH_ALLOC_CONF=expandable_segments:True
export CUTE_DSL_ARCH=sm_100a
```

## Shape suites

- **Quack-suite**: `(batch, seq) ∈ {1,4,8,16,32} × {8192,16384,32768,65536,131072}`,
  with `hidden = 4096` so `M = batch * seq`, `N = 4096` (expanded in the sketch
  after this list).
- **DeepSeek-V3-like (DSv3)**
- RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}`
- Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}`
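
For concreteness, a minimal sketch of how the Quack-suite grid expands into
`(M, N)` pairs (assuming a plain Cartesian product; the scripts may construct
it differently):

```python
from itertools import product

HIDDEN = 4096
quack_suite = [
    (batch * seq, HIDDEN)
    for batch, seq in product((1, 4, 8, 16, 32),
                              (8192, 16384, 32768, 65536, 131072))
]
assert len(quack_suite) == 25  # 5 batch sizes x 5 sequence lengths, N fixed at 4096
```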

## Correctness gates

By default, each script runs a per-shape `torch.testing.assert_close` check
vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack
is available for that op/path, the script also validates Quack vs the *same*
reference (so speedups can’t come from looser numerics).

Disable with `--skip-verify` only for quick smoke tests.
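
As an illustration, a gate of this kind for RMSNorm, assuming a pure-PyTorch
fp32 reference (the helper name and tolerances here are illustrative, not the
scripts' actual values):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Pure-PyTorch reference: accumulate in fp32, cast back to the input dtype.
    xf = x.float()
    y = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (y * weight.float()).to(x.dtype)

# Assumes the oink ops are registered (see oink/README.md).
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(
    torch.ops.oink.rmsnorm(x, w, 1e-6), rmsnorm_ref(x, w), rtol=2e-2, atol=2e-2
)
```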

## Running benchmarks

All scripts support:

- `--quack-suite` or `--dsv3` (or `--configs MxN,...`)
- `--dtype {bf16,fp16,fp32}`
- `--iters <n>` and `--warmup-ms <ms>` for kernel-only timing
- `--json <path>` and/or `--csv <path>` outputs (meta + rows)

### One-command suite

Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts
to a timestamped directory:

```bash
python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16
```

Turn the JSON artifacts into Markdown tables (with geomean speedups):

```bash
python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_<timestamp> \
--out /tmp/kernelagent_oink_sm100_suite_summary.md
```

### Measured HBM roofline (STREAM-like)

To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth
ceiling (rather than a theoretical spec), run:

```bash
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \
--json /tmp/hbm_roofline_sm100_bf16.json
```
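
The measurement idea, as a minimal sketch (a device-to-device copy timed with
`triton.testing.do_bench`; the real script covers more access patterns and
reporting):

```python
import torch
from triton.testing import do_bench

nbytes = 2 * 1024**3                    # 2 GiB per buffer, as with --gb 2
src = torch.empty(nbytes // 2, device="cuda", dtype=torch.bfloat16)
dst = torch.empty_like(src)

ms = do_bench(lambda: dst.copy_(src))   # kernel time in milliseconds
tbps = 2 * nbytes / (ms * 1e-3) / 1e12  # read + write traffic, in TB/s
print(f"measured copy bandwidth: {tbps:.2f} TB/s")
```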

### RMSNorm forward

```bash
python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_quack_suite.json

python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_dsv3.json

# vLLM-style inference weights (weight dtype == activation dtype)
python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json
```

### Fused Add + RMSNorm (vLLM-style, in-place)

This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math):

```bash
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
--json /tmp/fused_add_rmsnorm_sm100_bf16.json
```

Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`).
Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark
times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`)
to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only
the Quack fused kernel with preallocated outputs.
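
Schematically, the default baseline times a callable like the one below, where
`quack_fused_rmsnorm` is a hypothetical stand-in for Quack's out-of-place fused
kernel (the real entry point and signature live in Quack):

```python
def quack_inplace_equivalent(x, residual, weight, eps):
    # Hypothetical stand-in: Quack's fused kernel returns (out, residual_out)
    # rather than updating its inputs.
    out, residual_out = quack_fused_rmsnorm(x, residual, weight, eps)
    # Two explicit copies restore Oink's in-place semantics.
    x.copy_(out)
    residual.copy_(residual_out)
```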

### RMSNorm backward

```bash
python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \
--csv /tmp/oink_rmsnorm_bwd_quack_suite.csv

python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \
--csv /tmp/oink_rmsnorm_bwd_dsv3.csv
```

### Softmax (forward + backward)

```bash
python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
--json /tmp/oink_softmax_fwd_bwd_quack_suite.json

python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
--json /tmp/oink_softmax_fwd_bwd_dsv3.json
```

### Cross-entropy (forward + backward)

```bash
python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
--json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json

python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
--json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json
```

### LayerNorm forward

```bash
python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \
--json /tmp/oink_layernorm_fwd_quack_suite.json

python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \
--json /tmp/oink_layernorm_fwd_dsv3.json
```

## Notes

- These scripts intentionally avoid importing any external Oink checkout so the
results reflect the in-tree KernelAgent Oink kernels.
- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that
is only used when the pointer-based fast path cannot be used (e.g. when
`weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You
can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`.
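
  For example, to A/B the stage-2 fallback on the RMSNorm forward benchmark:

  ```bash
  KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1 \
    python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --quack-suite
  ```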