diff --git a/oink/README.md b/oink/README.md
new file mode 100644
index 0000000..aeb0c09
--- /dev/null
+++ b/oink/README.md
@@ -0,0 +1,109 @@
+# KernelAgent-Oink
+
+KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for
+**NVIDIA Blackwell (SM100 / GB200 / B200-class)**, bundled as a lightweight
+Python package that can be used standalone or as a **vLLM general plugin**.
+
+At the moment, the vLLM integration exposes the following `torch.library.custom_op`
+entrypoints under the `oink::` namespace:
+
+- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor`
+- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place)
+
+The package also includes additional SM100 kernels used by the benchmark suite:
+LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd).
+
+## Requirements
+
+- GPU: **SM100** for the fast CuTeDSL paths. On other GPUs, Oink falls back to
+ reference PyTorch implementations for correctness.
+- Python dependencies:
+ - `nvidia-cutlass-dsl` (CuTeDSL)
+ - `cuda-python`
+ - `torch` (provided by your environment / vLLM)
+
+Recommended env vars:
+
+```bash
+export CUTE_DSL_ARCH=sm_100a
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+```
+
+## Install (editable)
+
+From the `KernelAgent` repo root:
+
+```bash
+pip install -e ./oink
+```
+
+For running the in-repo benchmark suite / plots:
+
+```bash
+pip install -e "./oink[bench]"
+```
+
+## Usage
+
+### vLLM (general plugin)
+
+1) Enable the plugin:
+
+```bash
+export VLLM_USE_OINK_RMSNORM=1
+```
+
+2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs:
+
+```python
+from vllm import LLM
+
+llm = LLM(
+ model=...,
+ tensor_parallel_size=...,
+ enforce_eager=False,
+ compilation_config={"custom_ops": ["none", "+rms_norm"]},
+)
+```
+
+Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither
+vLLM’s CUDA RMSNorm nor Oink will run.
+
+### Direct PyTorch usage (manual op registration)
+
+For standalone use (outside vLLM), register the custom ops once:
+
+```python
+import kernelagent_oink
+import torch
+
+kernelagent_oink.register(force=True)
+
+x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+```
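+
+The in-place fused op follows the same pattern. A minimal sketch, continuing from
+the example above (it mutates `x` and `residual` and returns `None`):
+
+```python
+residual = torch.randn_like(x)
+
+# In-place: residual <- x + residual, then x <- RMSNorm(residual) * w.
+torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)
+```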
+
+## Benchmarks
+
+The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare
+Oink against Quack on SM100 and to reproduce the reported speedups.
+
+- How to run + methodology: `oink/benchmarks/README.md`
+- Pre-generated plots: `oink/benchmarks/media/`
+
+## Links
+
+| What | Link |
+|---|---|
+| Quack (expert baseline) | https://github.com/Dao-AILab/quack |
+| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent |
+| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 |
diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md
new file mode 100644
index 0000000..a5c4676
--- /dev/null
+++ b/oink/benchmarks/README.md
@@ -0,0 +1,152 @@
+# SM100 Benchmarks (KernelAgent-Oink vs Quack)
+
+This folder contains SM100 (GB200 / Blackwell) microbenchmarks for the Oink
+CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100
+kernels where Quack provides an equivalent API.
+
+## Prereqs
+
+- GPU: **SM100** (`torch.cuda.get_device_capability() == (10, 0)`).
+- Python deps in your environment:
+ - `torch`
+ - `nvidia-cutlass-dsl` (CuTeDSL)
+ - `cuda-python`
+ - `triton` (only for `triton.testing.do_bench`)
+ - `quack` (optional; only needed for Oink-vs-Quack comparisons)
+
+Recommended env vars:
+
+```bash
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+export CUTE_DSL_ARCH=sm_100a
+```
+
+## Shape suites
+
+- **Quack-suite**: `(batch, seq) ∈ {1,4,8,16,32} × {8192,16384,32768,65536,131072}`,
+ with `hidden = 4096` so `M = batch * seq`, `N = 4096`.
+- **DeepSeek-V3-like (DSv3)**
+ - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}`
+ - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}`
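+
+The sketch below shows how these suites expand into the `(M, N)` pairs that are
+actually benchmarked. It mirrors `quack_suite_configs()` in
+`benchmark/bench_utils.py` (including its `M * N <= 2**31` element cap); the
+DSv3 sets are plain Cartesian products of the lists above.
+
+```python
+# Expand the Quack-suite grid into (M, N) pairs (hidden = 4096).
+hidden = 4096
+batch_sizes = [1, 4, 8, 16, 32]
+seq_lengths = [8192, 16384, 32768, 65536, 131072]
+
+quack_suite = [
+    (bs * sl, hidden)
+    for bs in batch_sizes
+    for sl in seq_lengths
+    if bs * sl * hidden <= 2**31  # skip shapes above the element cap
+]
+
+# DSv3-like sets are plain Cartesian products, e.g. for RMSNorm/LayerNorm/Softmax:
+dsv3 = [(m, n) for m in (4096, 16384, 65536) for n in (6144, 7168, 8192)]
+```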
+
+## Correctness gates
+
+By default, each script runs a per-shape `torch.testing.assert_close` check
+vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack
+is available for that op/path, the script also validates Quack vs the *same*
+reference (so speedups can’t come from looser numerics).
+
+Disable with `--skip-verify` only for quick smoke tests.
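+
+A minimal sketch of the pattern (RMSNorm forward shown; the exact reference and
+tolerances differ per op, and `y_oink` / `y_quack` stand for the outputs under test):
+
+```python
+import torch
+
+def rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    # Pure-PyTorch reference with float32 accumulation.
+    xf = x.float()
+    rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
+    return (xf * rstd * w.float()).to(x.dtype)
+
+x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+y_ref = rmsnorm_ref(x, w)
+
+# Both implementations are checked against the *same* reference, e.g. for bf16:
+# torch.testing.assert_close(y_oink, y_ref, atol=1e-1, rtol=1e-2)
+# torch.testing.assert_close(y_quack, y_ref, atol=1e-1, rtol=1e-2)
+```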
+
+## Running benchmarks
+
+Most benchmark scripts support:
+
+- `--quack-suite` or `--dsv3` (or `--configs MxN,...`)
+- `--dtype {bf16,fp16,fp32}`
+- `--iters` and `--warmup-ms` for kernel-only timing
+- `--json` and/or `--csv` outputs (meta + rows)
+
+### One-command suite
+
+Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts
+to a timestamped directory:
+
+```bash
+python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16
+```
+
+Turn the JSON artifacts into Markdown tables (with geomean speedups):
+
+```bash
+python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_ \
+ --out /tmp/kernelagent_oink_sm100_suite_summary.md
+```
+
+### Measured HBM roofline (STREAM-like)
+
+To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth
+ceiling (rather than a theoretical spec), run:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \
+ --json /tmp/hbm_roofline_sm100_bf16.json
+```
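+
+To express an op's reported bandwidth as a fraction of the measured ceiling, a
+small sketch that joins two of the JSON artifacts produced by commands in this
+README (field names match what the scripts write: `tbps` for roofline rows,
+`ours_tbps` for op rows):
+
+```python
+import json
+
+with open("/tmp/hbm_roofline_sm100_bf16.json") as f:
+    roofline = json.load(f)
+with open("/tmp/fused_add_rmsnorm_sm100_bf16.json") as f:  # produced below
+    op_results = json.load(f)
+
+# Best measured STREAM-like bandwidth (TB/s) across the copy/triad sweep.
+ceiling_tbps = max(row["tbps"] for row in roofline["rows"])
+
+for row in op_results["rows"]:
+    frac = row["ours_tbps"] / ceiling_tbps
+    print(f"M={row['M']:>8} N={row['N']:>6}  {row['ours_tbps']:.3f} TB/s "
+          f"({frac:.1%} of measured peak)")
+```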
+
+### RMSNorm forward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_dsv3.json
+
+# vLLM-style inference weights (weight dtype == activation dtype)
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json
+```
+
+### Fused Add + RMSNorm (vLLM-style, in-place)
+
+This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math):
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
+ --json /tmp/fused_add_rmsnorm_sm100_bf16.json
+```
+
+Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`).
+Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark
+times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`)
+to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only
+the Quack fused kernel with preallocated outputs.
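+
+A self-contained sketch of what the default `kernel_inplace` baseline times; it
+mirrors the benchmark's use of `quack.rmsnorm._rmsnorm_fwd` with preallocated
+outputs (drop the two `copy_` calls to get the `kernel` baseline):
+
+```python
+import torch
+from quack.rmsnorm import _rmsnorm_fwd  # Quack's low-level mutating fused kernel
+
+M, N, eps = 4096, 6144, 1e-6
+x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
+residual = torch.randn_like(x)
+w = torch.randn(N, device="cuda", dtype=torch.bfloat16)
+out = torch.empty_like(x)
+residual_out = torch.empty_like(residual)
+
+# Quack fused kernel (out-of-place): args are (x, w, out, bias, rstd, mean,
+# residual, residual_out, eps, is_layernorm), as used by the benchmark script.
+_rmsnorm_fwd(x, w, out, None, None, None, residual, residual_out, eps, False)
+
+# Apply the in-place semantics vLLM / Oink expect (this is the extra cost timed
+# by --quack-baseline kernel_inplace):
+x.copy_(out)                  # x <- RMSNorm(x + residual) * w
+residual.copy_(residual_out)  # residual <- x + residual
+```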
+
+### RMSNorm backward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \
+ --csv /tmp/oink_rmsnorm_bwd_quack_suite.csv
+
+python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \
+ --csv /tmp/oink_rmsnorm_bwd_dsv3.csv
+```
+
+### Softmax (forward + backward)
+
+```bash
+python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_softmax_fwd_bwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_softmax_fwd_bwd_dsv3.json
+```
+
+### Cross-entropy (forward + backward)
+
+```bash
+python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json
+```
+
+### LayerNorm forward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_layernorm_fwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_layernorm_fwd_dsv3.json
+```
+
+## Notes
+
+- These scripts intentionally avoid importing any external Oink checkout so the
+ results reflect the in-tree KernelAgent Oink kernels.
+- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that
+ is only used when the pointer-based fast path cannot be used (e.g. when
+ `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You
+ can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`.
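+
+  For example, an A/B run against the vLLM-style (`--weight-dtype same`) command
+  above, with the fallback forced (the output path suffix is just a suggestion):
+
+```bash
+KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1 \
+  python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite \
+  --iters 200 --warmup-ms 25 --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame_stage2.json
+```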
diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py
new file mode 100644
index 0000000..ef996ec
--- /dev/null
+++ b/oink/benchmarks/benchmark/bench_utils.py
@@ -0,0 +1,289 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import csv
+import json
+import math
+import os
+import subprocess
+import sys
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+
+import torch
+from triton.testing import do_bench as triton_do_bench
+
+
+@dataclass(frozen=True)
+class DeviceMeta:
+ device: str
+ capability: Tuple[int, int]
+ torch: str
+ cuda: str
+ cute_dsl_arch: str
+ git_sha: str
+ timestamp: str
+
+
+def _try_git_sha() -> str:
+ here = os.path.dirname(os.path.abspath(__file__))
+ repo_root = os.path.abspath(os.path.join(here, "..", ".."))
+ try:
+ out = subprocess.check_output(
+ ["git", "rev-parse", "HEAD"],
+ cwd=repo_root,
+ stderr=subprocess.DEVNULL,
+ text=True,
+ )
+ return out.strip()
+ except Exception:
+ return ""
+
+
+def collect_device_meta(device: Optional[torch.device] = None) -> DeviceMeta:
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ timestamp = datetime.now().isoformat(timespec="seconds")
+ return DeviceMeta(
+ device=str(props.name),
+ capability=(int(props.major), int(props.minor)),
+ torch=str(torch.__version__),
+ cuda=str(getattr(torch.version, "cuda", "unknown")),
+ cute_dsl_arch=os.environ.get("CUTE_DSL_ARCH", ""),
+ git_sha=_try_git_sha(),
+ timestamp=timestamp,
+ )
+
+
+def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
+ """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ if sm >= 100:
+ return 8000.0
+ return 2000.0
+
+
+def do_bench_triton(
+ fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100
+) -> float:
+ """Kernel-only timing consistent with the Oink benchmark harnesses."""
+ return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean"))
+
+
+def parse_dtype(s: str) -> torch.dtype:
+ s = s.lower()
+ if s == "bf16":
+ return torch.bfloat16
+ if s == "fp16":
+ return torch.float16
+ if s == "fp32":
+ return torch.float32
+ raise ValueError(f"Unsupported dtype: {s}")
+
+
+def parse_configs(s: str) -> List[Tuple[int, int]]:
+ out: List[Tuple[int, int]] = []
+ for part in s.split(","):
+ m, n = part.lower().split("x")
+ out.append((int(m), int(n)))
+ return out
+
+
+def quack_suite_configs() -> List[Tuple[int, int, int]]:
+ """Return (batch, seq, hidden) triples following Quack's common grid (hidden=4096)."""
+ batch_sizes = [1, 4, 8, 16, 32]
+ seq_lengths = [8192, 16384, 32768, 65536, 131072]
+ hidden = 4096
+ cfgs: List[Tuple[int, int, int]] = []
+ for bs in batch_sizes:
+ for sl in seq_lengths:
+ M = bs * sl
+ if M * hidden > (2**31):
+ continue
+ cfgs.append((bs, sl, hidden))
+ return cfgs
+
+
+def ensure_oink_src_on_path() -> None:
+ """Make the in-repo KernelAgent Oink package importable without an editable install."""
+ here = os.path.dirname(os.path.abspath(__file__))
+ oink_src = os.path.abspath(os.path.join(here, "..", "..", "src"))
+ if oink_src not in sys.path:
+ sys.path.insert(0, oink_src)
+
+
+def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+ file_exists = os.path.exists(path)
+ with open(path, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=sorted(rows[0].keys()))
+ if not file_exists:
+ writer.writeheader()
+ for row in rows:
+ writer.writerow(row)
+
+
+def write_json(
+ path: str,
+ meta: DeviceMeta,
+ rows: Sequence[Dict[str, Any]],
+ *,
+ extra: Dict[str, Any] | None = None,
+) -> None:
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+ payload: Dict[str, Any] = {
+ "meta": {**asdict(meta), **(extra or {})},
+ "rows": list(rows),
+ }
+ with open(path, "w") as f:
+ json.dump(payload, f, indent=2)
+
+
+def iter_row_blocks(M: int, block_rows: int) -> Iterable[Tuple[int, int]]:
+ """Yield (start, end) row index ranges for a 2D (M, N) matrix.
+
+ The intent is to make correctness references for large tensors tractable
+ without materializing full float32 intermediates.
+ """
+ if M < 0:
+ raise ValueError(f"M must be non-negative, got {M}")
+ if block_rows <= 0:
+ raise ValueError(f"block_rows must be > 0, got {block_rows}")
+ for start in range(0, M, block_rows):
+ yield start, min(M, start + block_rows)
+
+
+@dataclass
+class ErrorStats:
+ """Numerical error stats between an output and a reference.
+
+ Notes:
+ - `max_abs` and `rel_l2` are computed exactly (streamed).
+ - `p99_abs` is computed over a deterministic strided sample of abs error
+ values (to keep very large tensors tractable).
+ """
+
+ max_abs: float
+ p99_abs: float
+ rel_l2: float
+ p99_sample_elems: int
+ p99_sample_stride: int
+
+
+class ErrorStatsAccumulator:
+ """Stream error stats over (output_block, ref_block) pairs.
+
+ This is intended for large 2D tensors where we compute reference results
+ block-by-block to avoid materializing full float32 intermediates.
+ """
+
+ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000):
+ if total_elems <= 0:
+ raise ValueError(f"total_elems must be > 0, got {total_elems}")
+ if p99_target_samples <= 0:
+ raise ValueError(
+ f"p99_target_samples must be > 0, got {p99_target_samples}"
+ )
+ self.total_elems = int(total_elems)
+ self.p99_target_samples = int(p99_target_samples)
+ # Deterministic strided sampling across the flattened tensor order.
+ self.sample_stride = max(1, self.total_elems // self.p99_target_samples)
+ self._global_offset = 0
+
+ self._max_abs = 0.0
+ self._err_sq_sum = 0.0
+ self._ref_sq_sum = 0.0
+ self._abs_err_samples: List[torch.Tensor] = []
+
+ def update(self, out: torch.Tensor, ref: torch.Tensor) -> None:
+ if out.shape != ref.shape:
+ raise ValueError(
+ f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}"
+ )
+
+ # Compute error in float32 for stable reductions.
+ err_f32 = (out - ref).to(torch.float32)
+ abs_err = err_f32.abs()
+
+ # Exact reductions.
+ self._max_abs = max(self._max_abs, float(abs_err.max().item()))
+ self._err_sq_sum += float((err_f32 * err_f32).sum(dtype=torch.float64).item())
+ ref_f32 = ref.to(torch.float32)
+ self._ref_sq_sum += float((ref_f32 * ref_f32).sum(dtype=torch.float64).item())
+
+ # Deterministic strided sample for p99_abs.
+ flat = abs_err.flatten()
+ block_elems = int(flat.numel())
+ if block_elems <= 0:
+ return
+
+ stride = int(self.sample_stride)
+ first = (-int(self._global_offset)) % stride
+ if first < block_elems:
+ idx = torch.arange(
+ first, block_elems, step=stride, device=flat.device, dtype=torch.int64
+ )
+ # Gather a modest number of values (≈ block_elems/stride).
+ vals = (
+ flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32)
+ )
+ self._abs_err_samples.append(vals)
+
+ self._global_offset += block_elems
+
+ def finalize(self) -> ErrorStats:
+ if self._abs_err_samples:
+ samples = torch.cat(self._abs_err_samples, dim=0)
+ if samples.numel() > self.p99_target_samples:
+ samples = samples[: self.p99_target_samples]
+ p99 = (
+ float(torch.quantile(samples, 0.99).item())
+ if samples.numel() > 0
+ else 0.0
+ )
+ sample_elems = int(samples.numel())
+ else:
+ p99 = 0.0
+ sample_elems = 0
+
+ denom = math.sqrt(self._ref_sq_sum) if self._ref_sq_sum > 0 else 0.0
+ rel_l2 = (math.sqrt(self._err_sq_sum) / denom) if denom > 0 else 0.0
+
+ return ErrorStats(
+ max_abs=float(self._max_abs),
+ p99_abs=float(p99),
+ rel_l2=float(rel_l2),
+ p99_sample_elems=int(sample_elems),
+ p99_sample_stride=int(self.sample_stride),
+ )
+
+
+def error_stats_to_row(prefix: str, stats: ErrorStats) -> Dict[str, Any]:
+ """Flatten ErrorStats into JSON-friendly row fields."""
+ return {
+ f"{prefix}_max_abs": float(stats.max_abs),
+ f"{prefix}_p99_abs": float(stats.p99_abs),
+ f"{prefix}_rel_l2": float(stats.rel_l2),
+ f"{prefix}_p99_sample_elems": int(stats.p99_sample_elems),
+ f"{prefix}_p99_sample_stride": int(stats.p99_sample_stride),
+ }
diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
new file mode 100644
index 0000000..3c8bf44
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
@@ -0,0 +1,498 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import cross_entropy as oink_ce # noqa: E402
+
+try:
+ from quack.cross_entropy import cross_entropy_bwd as quack_ce_bwd # type: ignore
+ from quack.cross_entropy import cross_entropy_fwd as quack_ce_fwd # type: ignore
+except Exception:
+ quack_ce_fwd = None
+ quack_ce_bwd = None
+
+
+# Match Quack's unit-test defaults (tests/test_cross_entropy.py).
+_VERIFY_TOL_LOSS = dict(atol=5e-5, rtol=1e-5) # float32 outputs (loss/lse)
+_VERIFY_TOL_DX = {
+ torch.float32: dict(atol=5e-5, rtol=1e-5),
+ # FP16 `dx` is low-precision; allow ~1 ulp at typical magnitudes.
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ # BF16 `dx` is low-precision; allow ~1 ulp at typical magnitudes.
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+
+def bytes_io_model_ce(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ target_dtype: torch.dtype = torch.int64,
+ mode: str,
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ t_elem = torch.tensor(0, dtype=target_dtype).element_size()
+ # Forward:
+ # read logits (M*N) + read target (M) + write loss (M fp32) + write lse (M fp32)
+ fwd = M * N * elem + M * t_elem + 2 * M * 4
+ # Backward (reduction="none" path):
+ # read logits (M*N) + read target (M) + read dloss (M fp32) + read lse (M fp32) + write dx (M*N)
+ bwd = 2 * M * N * elem + M * t_elem + 2 * M * 4
+
+ if mode == "fwd":
+ return int(fwd)
+ if mode == "bwd":
+ return int(bwd)
+ if mode == "fwd_bwd":
+ # Logical IO for dx given (logits, target, dloss): read logits + read target
+ # + read dloss + write dx. (Intermediate lse/loss are implementation details.)
+ return int(2 * M * N * elem + M * t_elem + M * 4)
+ raise ValueError(f"Unsupported mode: {mode}")
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [3072, 6144, 8192, 12288]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int
+) -> dict[str, object]:
+ dtype = logits.dtype
+ ref_block_rows = 512
+ dloss = torch.randn(
+ logits.size(0), device=logits.device, dtype=torch.float32
+ ) # upstream grad
+
+ with torch.no_grad():
+ loss_o, lse_o = oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=ignore_index, reduction="none"
+ )
+ dx_o = oink_ce.cross_entropy_backward(
+ dloss, logits, target, lse_o, ignore_index=ignore_index
+ )
+ dx_fused_o = oink_ce.cross_entropy_fwd_bwd(
+ dloss,
+ logits,
+ target,
+ ignore_index=ignore_index,
+ )
+
+ loss_q = None
+ lse_q = None
+ dx_q = None
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+ loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=ignore_index,
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ dx_q = quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ lse_q,
+ ignore_index=ignore_index,
+ inplace_backward=False,
+ )
+
+ M = int(logits.shape[0])
+ N = int(logits.shape[1])
+ loss_acc_ours = ErrorStatsAccumulator(
+ total_elems=M, p99_target_samples=min(M, 1_000_000)
+ )
+ lse_acc_ours = ErrorStatsAccumulator(
+ total_elems=M, p99_target_samples=min(M, 1_000_000)
+ )
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ loss_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+ lse_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+
+ # Match Quack tests: compare to a PyTorch reference computed on float32 logits.
+ # Chunk over rows so we don't materialize a full (M, N) float32 tensor.
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ logits_f32 = logits[start:end].float().requires_grad_(True)
+ target_blk = target[start:end]
+ dloss_blk = dloss[start:end]
+
+ loss_ref = torch.nn.functional.cross_entropy(
+ logits_f32,
+ target_blk,
+ reduction="none",
+ ignore_index=ignore_index,
+ )
+ lse_ref = torch.logsumexp(logits_f32, dim=-1)
+ (dx_ref_f32,) = torch.autograd.grad(
+ loss_ref, logits_f32, grad_outputs=dloss_blk
+ )
+ dx_ref = dx_ref_f32.to(dtype)
+
+ torch.testing.assert_close(
+ loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(
+ lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ torch.testing.assert_close(
+ dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]
+ )
+ loss_acc_ours.update(loss_o[start:end], loss_ref.detach())
+ lse_acc_ours.update(lse_o[start:end], lse_ref.detach())
+ dx_acc_ours.update(dx_o[start:end], dx_ref)
+ dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref)
+
+ if loss_q is not None and lse_q is not None and dx_q is not None:
+ torch.testing.assert_close(
+ loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(
+ lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ assert (
+ loss_acc_quack is not None
+ and lse_acc_quack is not None
+ and dx_acc_quack is not None
+ )
+ loss_acc_quack.update(loss_q[start:end], loss_ref.detach())
+ lse_acc_quack.update(lse_q[start:end], lse_ref.detach())
+ dx_acc_quack.update(dx_q[start:end], dx_ref)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_loss", loss_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize()))
+ if (
+ loss_acc_quack is not None
+ and lse_acc_quack is not None
+ and dx_acc_quack is not None
+ ):
+ stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ warmup_ms: int,
+ iters_ms: int,
+ mode: str,
+ verify: bool,
+ ignore_index: int,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype)
+ target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
+ # Sprinkle some ignore_index entries for robustness (and to match reduction semantics).
+ if ignore_index is not None:
+ mask = torch.rand(M, device=device) < 0.01
+ target[mask] = int(ignore_index)
+ dloss = torch.randn(M, device=device, dtype=torch.float32)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(logits, target, ignore_index=int(ignore_index))
+
+ bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode)
+
+ if mode == "fwd":
+
+ def fn_oink():
+ return oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=int(ignore_index), reduction="none"
+ )
+
+ fn_quack = None
+ if quack_ce_fwd is not None:
+
+ def fn_quack():
+ return quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+
+ elif mode == "bwd":
+ with torch.no_grad():
+ _loss_o, lse_o = oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=int(ignore_index), reduction="none"
+ )
+ if quack_ce_fwd is not None:
+ _loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ else:
+ lse_q = None
+
+ def fn_oink():
+ return oink_ce.cross_entropy_backward(
+ dloss, logits, target, lse_o, ignore_index=int(ignore_index)
+ )
+
+ fn_quack = None
+ if quack_ce_bwd is not None and lse_q is not None:
+
+ def fn_quack():
+ return quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ lse_q,
+ ignore_index=int(ignore_index),
+ inplace_backward=False,
+ )
+
+ elif mode == "fwd_bwd":
+
+ def fn_oink():
+ return oink_ce.cross_entropy_fwd_bwd(
+ dloss,
+ logits,
+ target,
+ ignore_index=int(ignore_index),
+ )
+
+ fn_quack = None
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+
+ def fn_quack():
+ _loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ return quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ lse_q,
+ ignore_index=int(ignore_index),
+ inplace_backward=False,
+ )
+
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if fn_quack is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]
+ )
+ p.add_argument("--ignore-index", type=int, default=-100)
+ p.add_argument(
+ "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)."
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument(
+ "--quack-suite",
+ action="store_true",
+ help="Run Quack-style batch/seq grid (vocab=4096)",
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch float32-logits cross entropy)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...",
+ flush=True,
+ )
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ mode=str(args.mode),
+ verify=not args.skip_verify,
+ ignore_index=int(args.ignore_index),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "mode": args.mode,
+ "ignore_index": int(args.ignore_index),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "mode-dependent; see bytes_io_model_ce in script",
+ },
+ )
+
+ headers = ["M", "N", "mode", "ours_ms", "ours_tbps"]
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
new file mode 100644
index 0000000..1787d7d
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -0,0 +1,376 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Benchmark fused_add_rmsnorm (in-place) on SM100.
+
+This matches vLLM's fused_add_rms_norm semantics:
+ z = x + residual (stored into residual)
+ y = RMSNorm(z, w) (stored into x)
+
+Why this exists:
+- It is a common inference hot path (vLLM).
+- It is strongly memory-bound (reads/writes two MxN tensors), making it a good
+ roofline case study for Blackwell.
+
+Example:
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \\
+ --json /tmp/fused_add_rmsnorm_sm100_bf16.json
+
+DSv3 suite (Oink vs Quack, multi-shape):
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\
+ --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
+
+Quack baseline note:
+- Oink exposes an **in-place** fused op (writes `x` and `residual` in-place).
+- Quack provides an equivalent fused kernel, but typically returns `out` and
+ `residual_out` (out-of-place) and does not expose a public "update my input
+ buffers in-place" API.
+- For integration realism (vLLM-style semantics) we default to timing:
+ Quack fused kernel + 2 explicit copies to apply the in-place updates
+ so the benchmark covers the full semantic cost.
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_dtype,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+_VERIFY_TOL = {
+ # Align with Quack's RMSNorm unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
+
+try:
+ # Use the low-level mutating custom op to avoid per-iteration allocations
+ # (critical for fair comparisons on small/medium M).
+ from quack.rmsnorm import _rmsnorm_fwd as quack_rmsnorm_fwd_mut # type: ignore
+except Exception:
+ quack_rmsnorm_fwd_mut = None
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def bytes_io_model_fused_add_rmsnorm_inplace(M: int, N: int, dtype: torch.dtype) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ # Read x + read residual + write x + write residual + read weight
+ return int((4 * M * N + N) * elem)
+
+
+def _verify_parity(
+ *,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ w: torch.Tensor,
+ eps: float,
+) -> dict[str, object]:
+ tol = _VERIFY_TOL[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ z_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd_mut is not None
+ else None
+ )
+ z_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd_mut is not None
+ else None
+ )
+
+ x_o = x.clone()
+ r_o = residual.clone()
+ out_q = None
+ res_out_q = None
+ with torch.no_grad():
+ oink_rmsnorm.fused_add_rmsnorm_inplace_(x_o, r_o, w, eps=eps)
+
+ if quack_rmsnorm_fwd_mut is not None:
+ out_q = torch.empty_like(x)
+ res_out_q = torch.empty_like(residual)
+ quack_rmsnorm_fwd_mut(
+ x,
+ w,
+ out_q,
+ None, # bias
+ None, # rstd
+ None, # mean
+ residual,
+ res_out_q,
+ eps,
+ False, # is_layernorm
+ )
+
+ # Pure-PyTorch reference (float32 accumulation), chunked over rows.
+ M = int(x.shape[0])
+ w_f32 = w.float()
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ z = x[start:end] + residual[start:end]
+ zf = z.float()
+ rstd = torch.rsqrt(zf.square().mean(dim=-1, keepdim=True) + eps)
+ y_ref = ((zf * rstd) * w_f32).to(x.dtype)
+
+ torch.testing.assert_close(x_o[start:end], y_ref, **tol)
+ torch.testing.assert_close(r_o[start:end], z, **tol)
+ y_acc_ours.update(x_o[start:end], y_ref)
+ z_acc_ours.update(r_o[start:end], z)
+ if out_q is not None and res_out_q is not None:
+ torch.testing.assert_close(out_q[start:end], y_ref, **tol)
+ torch.testing.assert_close(res_out_q[start:end], z, **tol)
+ assert y_acc_quack is not None and z_acc_quack is not None
+ y_acc_quack.update(out_q[start:end], y_ref)
+ z_acc_quack.update(res_out_q[start:end], z)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize()))
+ if y_acc_quack is not None and z_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+ stats.update(
+ error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize())
+ )
+ return stats
+
+
+def bench_one(
+ *,
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+ quack_baseline: str,
+) -> Dict[str, Any]:
+ device = torch.device("cuda")
+ x = torch.randn((M, N), device=device, dtype=dtype)
+ residual = torch.randn_like(x)
+ w = torch.randn((N,), device=device, dtype=dtype)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x=x, residual=residual, w=w, eps=1e-6)
+
+ bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype)
+
+ def fn():
+ oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6)
+
+ ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms)
+
+ gbps = bytes_io / (ms * 1e-3) / 1e9
+ tbps = gbps / 1000.0
+ hbm_frac = gbps / detect_hbm_peak_gbps(device)
+
+ row: Dict[str, Any] = dict(
+ M=int(M),
+ N=int(N),
+ dtype="bf16"
+ if dtype is torch.bfloat16
+ else ("fp16" if dtype is torch.float16 else "fp32"),
+ ours_ms=float(ms),
+ ours_gbps=float(gbps),
+ ours_tbps=float(tbps),
+ ours_hbm_frac=float(hbm_frac),
+ )
+ row.update(stats)
+
+ if quack_rmsnorm_fwd_mut is not None:
+ x_q = x.clone()
+ residual_q = residual.clone()
+ out_q = torch.empty_like(x_q)
+ res_out_q = torch.empty_like(residual_q)
+
+ def fn_q_kernel():
+ quack_rmsnorm_fwd_mut(
+ x_q,
+ w,
+ out_q,
+ None, # bias
+ None, # rstd
+ None, # mean
+ residual_q,
+ res_out_q,
+ 1e-6,
+ False, # is_layernorm
+ )
+
+ if quack_baseline == "kernel":
+ fn_q = fn_q_kernel
+ elif quack_baseline == "kernel_inplace":
+
+ def fn_q():
+ fn_q_kernel()
+ # Apply the same in-place semantics as vLLM expects:
+ # - x is overwritten with y
+ # - residual is overwritten with z = x + residual
+ x_q.copy_(out_q)
+ residual_q.copy_(res_out_q)
+
+ else:
+ raise ValueError(f"Unknown quack_baseline: {quack_baseline}")
+
+ ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_q = bytes_io / (ms_q * 1e-3) / 1e9
+ row.update(
+ dict(
+ quack_ms=float(ms_q),
+ quack_gbps=float(gbps_q),
+ quack_tbps=float(gbps_q / 1000.0),
+ speedup_vs_quack=float(ms_q / ms),
+ )
+ )
+
+ return row
+
+
+def _dtype_label(dtype: torch.dtype) -> str:
+ if dtype is torch.bfloat16:
+ return "bf16"
+ if dtype is torch.float16:
+ return "fp16"
+ return "fp32"
+
+
+def _print_table(rows: List[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ has_quack = any("quack_ms" in r for r in rows)
+ if has_quack:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
+ )
+ p.add_argument("--M", type=int, default=65536)
+ p.add_argument("--N", type=int, default=4096)
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)"
+ )
+ p.add_argument(
+ "--quack-baseline",
+ type=str,
+ default="kernel_inplace",
+ choices=["kernel", "kernel_inplace"],
+ help=(
+ "How to time Quack for the in-place fused op.\n"
+ "- kernel: Quack fused kernel only (preallocated out/residual_out).\n"
+ "- kernel_inplace: Quack fused kernel + 2 explicit copies to apply "
+ "in-place semantics (integration-realistic)."
+ ),
+ )
+ p.add_argument("--skip-verify", action="store_true")
+ p.add_argument("--json", type=str, default=None)
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ meta = collect_device_meta(torch.device("cuda"))
+
+ cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))]
+ rows: List[Dict[str, Any]] = []
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...",
+ flush=True,
+ )
+ rows.append(
+ bench_one(
+ M=int(M),
+ N=int(N),
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not bool(args.skip_verify),
+ quack_baseline=str(args.quack_baseline),
+ )
+ )
+
+ _print_table(rows)
+
+ if args.json:
+ write_json(
+ args.json,
+ meta,
+ rows,
+ extra=dict(
+ io_model_bytes="(4*M*N + N)*elem_size",
+ warmup_ms=int(args.warmup_ms),
+ rep_ms=int(args.iters),
+ method="triton.testing.do_bench(mean)",
+ note=(
+ "Oink fused_add_rmsnorm_inplace_ vs Quack baseline "
+ f"({args.quack_baseline}) when available"
+ ),
+ ),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
new file mode 100644
index 0000000..22fb48d
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+HBM roofline microbenchmark for SM100 (GB200 / Blackwell).
+
+This script measures a STREAM-like bandwidth ceiling using a simple Triton kernel
+that performs a large contiguous copy (read + write) and/or triad (read + read + write)
+over a large buffer.
+
+Why this exists:
+- The benchmark harnesses for Oink ops report an "ours_tbps" derived from an IO model.
+- For roofline discussions, comparing against a *measured* device bandwidth ceiling
+ is often more meaningful than quoting a marketing/theoretical spec.
+
+Example:
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op copy --gb 2
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+from bench_utils import ( # noqa: E402
+ collect_device_meta,
+ do_bench_triton,
+ parse_dtype,
+ write_json,
+)
+
+
+@triton.jit
+def _copy_kernel(
+ x_ptr,
+ y_ptr,
+ n_elements,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offsets = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask, other=0)
+ tl.store(y_ptr + offsets, x, mask=mask)
+
+
+@triton.jit
+def _triad_kernel(
+ x_ptr,
+ y_ptr,
+ n_elements,
+ BLOCK: tl.constexpr,
+):
+ pid = tl.program_id(0)
+ offsets = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask, other=0)
+ y = tl.load(y_ptr + offsets, mask=mask, other=0)
+ tl.store(y_ptr + offsets, x + y, mask=mask)
+
+
+def _bytes_moved(n_elements: int, elem_size: int, *, op: str) -> int:
+ if op == "copy":
+ return int(2 * n_elements * elem_size) # read x + write y
+ if op == "triad":
+ return int(3 * n_elements * elem_size) # read x + read y + write y
+ raise ValueError(f"Unsupported op: {op}")
+
+
+def bench_one(
+ *,
+ n_elements: int,
+ dtype: torch.dtype,
+ op: str,
+ block: int,
+ num_warps: int,
+ warmup_ms: int,
+ iters_ms: int,
+) -> Tuple[float, float]:
+ device = torch.device("cuda")
+ x = torch.empty((n_elements,), device=device, dtype=dtype)
+ y = torch.empty_like(x)
+ # Avoid pathological compression-friendly patterns (e.g. all-zeros) that can
+ # artificially inflate apparent bandwidth on some GPUs. Random-ish data is
+ # a closer match to ML workloads.
+ x.uniform_(-1, 1)
+ y.uniform_(-1, 1)
+
+ grid = (triton.cdiv(n_elements, block),)
+
+ if op == "copy":
+
+ def launch():
+ _copy_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+
+ elif op == "triad":
+
+ def launch():
+ _triad_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+
+ else:
+ raise ValueError(f"Unsupported op: {op}")
+
+ # Force compilation out of the timed region.
+ launch()
+ torch.cuda.synchronize()
+
+ ms = do_bench_triton(launch, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ moved = _bytes_moved(n_elements, x.element_size(), op=op)
+ tbps = moved / (ms * 1e-3) / 1e12
+ return ms, tbps
+
+
+def _print_summary(rows: List[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ best = max(rows, key=lambda r: float(r["tbps"]))
+ print("\nSummary (STREAM-like):")
+ print(
+ f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})"
+ )
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
+ )
+ p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"])
+ p.add_argument(
+ "--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)"
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Write JSON results to this path"
+ )
+ p.add_argument(
+ "--no-sweep",
+ action="store_true",
+ help="Disable tuning sweep; run a single config",
+ )
+ p.add_argument(
+ "--block", type=int, default=2048, help="BLOCK size when --no-sweep is set"
+ )
+ p.add_argument(
+ "--warps", type=int, default=8, help="num_warps when --no-sweep is set"
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ cap = (int(props.major), int(props.minor))
+ if cap != (10, 0):
+ raise RuntimeError(f"Expected SM100 (10,0), got {cap} ({props.name})")
+
+ elem_size = torch.tensor(0, dtype=dtype).element_size()
+ bytes_per_tensor = int(args.gb * (1024**3))
+ n_elements = max(1, bytes_per_tensor // elem_size)
+
+ ops: List[str]
+ if args.op == "both":
+ ops = ["copy", "triad"]
+ else:
+ ops = [args.op]
+
+ if args.no_sweep:
+ sweep: List[Tuple[int, int]] = [(int(args.block), int(args.warps))]
+ else:
+ # A tiny hand-tuned sweep that keeps compile overhead reasonable.
+ sweep = [
+ (1024, 4),
+ (1024, 8),
+ (2048, 4),
+ (2048, 8),
+ (4096, 8),
+ ]
+
+ print(f"Running on {props.name} (SM{props.major}{props.minor})")
+ print(f"- dtype: {args.dtype} (elem={elem_size}B)")
+ print(
+ f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)"
+ )
+ print(f"- ops: {ops}")
+ print(f"- sweep: {sweep}")
+
+ meta = collect_device_meta(device)
+ rows: List[Dict[str, Any]] = []
+ for op in ops:
+ for block, warps in sweep:
+ ms, tbps = bench_one(
+ n_elements=n_elements,
+ dtype=dtype,
+ op=op,
+ block=block,
+ num_warps=warps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ )
+ rows.append(
+ dict(
+ op=op,
+ dtype=str(args.dtype),
+ n_elements=int(n_elements),
+ elem_size_B=int(elem_size),
+ block=int(block),
+ num_warps=int(warps),
+ warmup_ms=int(args.warmup_ms),
+ rep_ms=int(args.iters),
+ ms=float(ms),
+ tbps=float(tbps),
+ )
+ )
+ print(
+ f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)"
+ )
+
+ _print_summary(rows)
+
+ if args.json:
+ # Write meta + detailed rows for reproducibility.
+ extra = dict(
+ bytes_model="copy:2*N*elem, triad:3*N*elem",
+ bytes_per_tensor=int(bytes_per_tensor),
+ gb_per_tensor=float(args.gb),
+ )
+ write_json(args.json, meta, rows, extra=extra)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
new file mode 100644
index 0000000..20895b7
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -0,0 +1,451 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import layernorm as oink_ln # noqa: E402
+
+try:
+ # Quack exposes LayerNorm through the RMSNorm module (is_layernorm=True path).
+ from quack.rmsnorm import layernorm_fwd as quack_layernorm # type: ignore
+except Exception:
+ quack_layernorm = None
+
+_VERIFY_TOL_Y = {
+ # Match Quack's unit-test defaults (tests/test_layernorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-4),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+# Quack checks rstd/mean (fp32) with a tighter fixed tolerance.
+_VERIFY_TOL_STATS = dict(atol=6e-4, rtol=6e-4)
+
+
+def bytes_io_model_layernorm(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ has_bias: bool,
+ return_rstd: bool,
+ return_mean: bool,
+ weight_dtype: torch.dtype = torch.float32,
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ total = 0
+ # Read x + write y
+ total += 2 * M * N * elem
+ # Read weight (+ optional bias) along feature dim
+ total += N * w_elem
+ if has_bias:
+ total += N * w_elem
+ # Optional per-row stats (fp32)
+ if return_rstd:
+ total += M * 4
+ if return_mean:
+ total += M * 4
+ return int(total)
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ b: torch.Tensor | None,
+ *,
+ eps: float,
+ return_rstd: bool,
+ return_mean: bool,
+) -> dict[str, object]:
+ tol_y = _VERIFY_TOL_Y[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_layernorm is not None and b is None)
+ else None
+ )
+ with torch.no_grad():
+ ours = oink_ln.layernorm(
+ x,
+ w,
+ bias=b,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ quack = None
+ if quack_layernorm is not None and b is None:
+ quack = quack_layernorm(
+ x,
+ w,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ torch.cuda.synchronize()
+
+ def _unpack(out):
+ if return_rstd and return_mean:
+ y, rstd, mean = out
+ elif return_rstd and not return_mean:
+ y, rstd = out
+ mean = None
+ elif return_mean and not return_rstd:
+ y, mean = out
+ rstd = None
+ else:
+ y, rstd, mean = out, None, None
+ return y, rstd, mean
+
+ y_o, rstd_o, mean_o = _unpack(ours)
+ y_q, rstd_q, mean_q = _unpack(quack) if quack is not None else (None, None, None)
+
+ # Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests:
+ # - compute ref output via F.layer_norm on float32
+ # - compute mean/rstd from float32 input
+ rstd_ref_all = (
+ torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None
+ )
+ mean_ref_all = (
+ torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None
+ )
+
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ y_ref_f32 = torch.nn.functional.layer_norm(x_f32, w.shape, w, b, eps)
+ y_ref = y_ref_f32.to(x.dtype)
+ torch.testing.assert_close(y_o[start:end], y_ref, **tol_y)
+ y_acc_ours.update(y_o[start:end], y_ref)
+ if y_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref, **tol_y)
+ assert y_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref)
+
+ # Per-row stats in fp32, as in Quack's tests.
+ if return_rstd or return_mean:
+ mean_f32 = x_f32.mean(dim=-1)
+ if return_mean:
+ assert mean_ref_all is not None
+ mean_ref_all[start:end] = mean_f32
+ if return_rstd:
+ var_f32 = ((x_f32 - mean_f32.unsqueeze(1)) ** 2).mean(dim=-1)
+ rstd_ref = 1.0 / torch.sqrt(var_f32 + eps)
+ assert rstd_ref_all is not None
+ rstd_ref_all[start:end] = rstd_ref
+
+ assert rstd_o is not None
+ torch.testing.assert_close(
+ rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS
+ )
+ if rstd_q is not None:
+ torch.testing.assert_close(
+ rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS
+ )
+
+ if return_mean:
+ mean_ref = mean_f32
+ assert mean_o is not None
+ torch.testing.assert_close(
+ mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS
+ )
+ if mean_q is not None:
+ torch.testing.assert_close(
+ mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS
+ )
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ if y_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+
+ if return_rstd:
+ assert rstd_o is not None and rstd_ref_all is not None
+ rstd_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref_all.numel()),
+ p99_target_samples=int(rstd_ref_all.numel()),
+ )
+ rstd_acc_ours.update(rstd_o, rstd_ref_all)
+ stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
+ if rstd_q is not None:
+ rstd_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref_all.numel()),
+ p99_target_samples=int(rstd_ref_all.numel()),
+ )
+ rstd_acc_quack.update(rstd_q, rstd_ref_all)
+ stats.update(
+ error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())
+ )
+
+ if return_mean:
+ assert mean_o is not None and mean_ref_all is not None
+ mean_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(mean_ref_all.numel()),
+ p99_target_samples=int(mean_ref_all.numel()),
+ )
+ mean_acc_ours.update(mean_o, mean_ref_all)
+ stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize()))
+ if mean_q is not None:
+ mean_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(mean_ref_all.numel()),
+ p99_target_samples=int(mean_ref_all.numel()),
+ )
+ mean_acc_quack.update(mean_q, mean_ref_all)
+ stats.update(
+ error_stats_to_row("quack_err_mean", mean_acc_quack.finalize())
+ )
+
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ eps: float,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+ return_rstd: bool,
+ return_mean: bool,
+ has_bias: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=torch.float32)
+ b = torch.randn(N, device=device, dtype=torch.float32) if has_bias else None
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(
+ x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean
+ )
+
+ bytes_io = bytes_io_model_layernorm(
+ M,
+ N,
+ dtype,
+ has_bias=has_bias,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ weight_dtype=w.dtype,
+ )
+
+ def fn_oink():
+ return oink_ln.layernorm(
+ x,
+ w,
+ bias=b,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if quack_layernorm is None or has_bias:
+ return (ms_oink, gbps_oink), None, stats
+
+ def fn_quack():
+ return quack_layernorm(
+ x,
+ w,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument("--return-rstd", action="store_true")
+ p.add_argument("--return-mean", action="store_true")
+ p.add_argument(
+ "--with-bias",
+ action="store_true",
+ help="Benchmark bias path (Quack compare skipped)",
+ )
+ p.add_argument(
+ "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)."
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument(
+ "--quack-suite",
+ action="store_true",
+ help="Run Quack-style batch/seq grid (hidden=4096)",
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference; Quack compare skipped when bias is enabled)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for M, N in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not args.skip_verify,
+ return_rstd=bool(args.return_rstd),
+ return_mean=bool(args.return_mean),
+ has_bias=bool(args.with_bias),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "eps": eps,
+ "return_rstd": bool(args.return_rstd),
+ "return_mean": bool(args.return_mean),
+ "with_bias": bool(args.with_bias),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "see bytes_io_model_layernorm in script",
+ },
+ )
+
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ if quack_layernorm is not None and (not args.with_bias):
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
new file mode 100644
index 0000000..50ecb2e
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -0,0 +1,464 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import csv
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+from triton.testing import do_bench as triton_do_bench
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ ensure_oink_src_on_path,
+ error_stats_to_row,
+ iter_row_blocks,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+try:
+ from quack.rmsnorm import rmsnorm_bwd as quack_rmsnorm_bwd # type: ignore
+except Exception:
+ quack_rmsnorm_bwd = None
+
+_VERIFY_TOL_DX = {
+ # Match Quack's unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
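+
+
+# Example invocation (illustrative; the flags are defined in main() below and the
+# output path is just a placeholder):
+#   python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py \
+#       --dsv3 --dtype bf16 --json /tmp/rmsnorm_bwd_dsv3.json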
+
+
+def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
+ """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ if sm >= 100:
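+        # ~8 TB/s nominal HBM3e bandwidth for B200/GB200-class parts; used only as
+        # a rough roofline denominator, not a measured peak.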
+ return 8000.0
+ return 2000.0
+
+
+@dataclass
+class Result:
+ ms: float
+ gbps: float
+
+
+def do_bench_triton(fn, warmup_ms: int = 25, rep_ms: int = 100) -> float:
+ # Kernel-only timing consistent with the existing Oink forward harness.
+ return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean"))
+
+
+def bytes_io_model_bwd(
+ M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32
+) -> int:
+ """A simple IO model for RMSNorm backward.
+
+    This intentionally ignores partial-reduction scratch buffers (`dw_partial` /
+    `db_partial`), since those are highly implementation-specific and depend on
+    sm_count. Times and speedups are unaffected; the derived GB/s therefore
+    slightly understates the true traffic.
+ """
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ # Read x + dout + write dx
+ total = 3 * M * N * elem
+ # Read weight + write dw
+ total += 2 * N * w_elem
+ # Read rstd (fp32 per row)
+ total += M * 4
+ return int(total)
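+
+
+# Worked example of the IO model above: M=65536, N=8192 in bf16 (2 B/elem) with an
+# fp32 weight counts 3*65536*8192*2 + 2*8192*4 + 65536*4 ≈ 3.22e9 bytes per call.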
+
+
+def parse_dtype(s: str) -> torch.dtype:
+ s = s.lower()
+ if s == "bf16":
+ return torch.bfloat16
+ if s == "fp16":
+ return torch.float16
+ if s == "fp32":
+ return torch.float32
+ raise ValueError(f"Unsupported dtype: {s}")
+
+
+def parse_configs(s: str) -> List[Tuple[int, int]]:
+ out: List[Tuple[int, int]] = []
+ for part in s.split(","):
+ m, n = part.lower().split("x")
+ out.append((int(m), int(n)))
+ return out
+
+
+def quack_suite_configs() -> List[Tuple[int, int, int]]:
+ """Return (batch, seq, hidden) triples following Quack's grid (hidden=4096)."""
+ batch_sizes = [1, 4, 8, 16, 32]
+ seq_lengths = [8192, 16384, 32768, 65536, 131072]
+ hidden = 4096
+ cfgs: List[Tuple[int, int, int]] = []
+ for bs in batch_sizes:
+ for sl in seq_lengths:
+ M = bs * sl
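+            # Skip shapes with more than 2**31 elements per (M, hidden) tensor,
+            # which keeps allocation sizes and index ranges manageable.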
+ if M * hidden > (2**31):
+ continue
+ cfgs.append((bs, sl, hidden))
+ return cfgs
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ dout: torch.Tensor,
+ rstd: torch.Tensor,
+ *,
+ has_bias: bool,
+ has_residual: bool,
+) -> dict[str, object]:
+ tol_dx = _VERIFY_TOL_DX[x.dtype]
+ ref_block_rows = 1024
+ M, N = int(x.shape[0]), int(x.shape[1])
+
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_bwd is not None
+ else None
+ )
+
+ with torch.no_grad():
+ dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=has_bias,
+ has_residual=has_residual,
+ )
+
+ dx_quack = None
+ dw_quack = None
+ db_quack = None
+ dres_quack = None
+ if quack_rmsnorm_bwd is not None:
+ dx_quack, dw_quack, db_quack, dres_quack = quack_rmsnorm_bwd(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=has_bias,
+ has_residual=has_residual,
+ )
+ torch.cuda.synchronize()
+
+ # Pure-PyTorch reference, matching Quack's rmsnorm_bwd_ref (float32 math for x_hat).
+ # Chunk over rows to avoid materializing an (M, N) float32 tensor for large shapes.
+ dw_accum = torch.zeros((N,), device=x.device, dtype=torch.float32)
+ w_f32 = w.float()
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ rstd_blk = rstd[start:end]
+ x_hat = x_f32 * rstd_blk.unsqueeze(1)
+ # Match Quack/PyTorch reference behavior: gradient math uses float32
+ # intermediates even when (x, w, dout) are bf16/fp16.
+ dout_f32 = dout[start:end].float()
+ wdy = dout_f32 * w_f32
+ c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+ dx_ref = ((wdy - x_hat * c1) * rstd_blk.unsqueeze(1)).to(x.dtype)
+
+ torch.testing.assert_close(dx_oink[start:end], dx_ref, **tol_dx)
+ dx_acc_ours.update(dx_oink[start:end], dx_ref)
+ if dx_quack is not None:
+ torch.testing.assert_close(dx_quack[start:end], dx_ref, **tol_dx)
+ assert dx_acc_quack is not None
+ dx_acc_quack.update(dx_quack[start:end], dx_ref)
+
+ if dw_oink is not None:
+ dw_accum += (dout_f32 * x_hat).sum(dim=0)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ if dx_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+
+ if dw_oink is not None:
+ dw_ref = dw_accum.to(w.dtype)
+        if w.dtype == torch.float32:
+            # Weight grad is sensitive to reduction order; use a slightly larger
+            # absolute tolerance in the suite harness (Quack's unit tests use
+            # smaller M, where dw is typically tighter).
+            dw_tol = dict(atol=2e-3, rtol=1e-3)
+            torch.testing.assert_close(dw_oink, dw_ref, **dw_tol)
+            if dw_quack is not None:
+                torch.testing.assert_close(dw_quack, dw_ref, **dw_tol)
+        else:
+            # For fp16/bf16 weights, `dw` is low-precision and grows with M; use an
+            # ulp/magnitude-aware tolerance rather than a fixed epsilon.
+            dw_ref_f32 = dw_ref.to(torch.float32)
+            dw_oink_f32 = dw_oink.to(torch.float32)
+            scale = float(dw_ref_f32.abs().max().item())
+            dw_atol = max(2.0 * torch.finfo(w.dtype).eps * scale, 1e-3)
+            dw_tol = dict(atol=dw_atol, rtol=1e-3)
+            torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol)
+            if dw_quack is not None:
+                torch.testing.assert_close(
+                    dw_quack.to(torch.float32), dw_ref_f32, **dw_tol
+                )
+
+ # Record weight-grad error stats (small, so exact p99 over the full vector).
+ dw_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())
+ )
+ dw_acc_ours.update(dw_oink, dw_ref)
+ stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize()))
+ if dw_quack is not None:
+ dw_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())
+ )
+ dw_acc_quack.update(dw_quack, dw_ref)
+ stats.update(error_stats_to_row("quack_err_dw", dw_acc_quack.finalize()))
+
+ assert db_oink is None and db_quack is None
+ assert dres_oink is None and dres_quack is None
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ weight_dtype: torch.dtype,
+ iters_ms: int,
+ eps: float,
+ warmup_ms: int,
+ verify: bool,
+) -> Tuple[Result, Result | None, dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=weight_dtype)
+ dout = torch.randn(M, N, device=device, dtype=dtype)
+ # rstd is fp32 per row; compute once outside the timed region.
+ with torch.no_grad():
+ xf = x.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False)
+
+ def fn_oink():
+ return oink_rmsnorm.rmsnorm_backward(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+ ours = Result(ms=ms_oink, gbps=gbps_oink)
+
+ if quack_rmsnorm_bwd is None:
+ return ours, None, stats
+
+ def fn_quack():
+ return quack_rmsnorm_bwd(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return ours, Result(ms=ms_quack, gbps=gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--weight-dtype",
+ type=str,
+ default="fp32",
+ choices=["same", "fp16", "bf16", "fp32"],
+ help="RMSNorm weight dtype. `same` matches activation dtype.",
+ )
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument(
+ "--iters",
+ type=int,
+ default=100,
+ help="Triton do_bench rep_ms (kernel-only).",
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch RMSNorm backward reference)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ if args.weight_dtype == "same":
+ weight_dtype = dtype
+ else:
+ weight_dtype = parse_dtype(args.weight_dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+
+ rows_out: list[dict[str, object]] = []
+
+ for M, N in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ ours, quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ iters_ms=int(args.iters),
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ verify=not args.skip_verify,
+ )
+
+ row: dict[str, object] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "weight_dtype": args.weight_dtype,
+ "ours_ms": ours.ms,
+ "ours_gbps": ours.gbps,
+ "ours_tbps": ours.gbps / 1000.0,
+ "ours_hbm_frac": ours.gbps / hbm_peak,
+ }
+ if quack is not None:
+ row.update(
+ {
+ "quack_ms": quack.ms,
+ "quack_gbps": quack.gbps,
+ "quack_tbps": quack.gbps / 1000.0,
+ "speedup_vs_quack": quack.ms / ours.ms,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
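+        # Note: unlike the forward harnesses, CSV/JSON outputs are written inside
+        # the loop, so partial results are on disk after each config completes.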
+ if args.csv is not None:
+ file_exists = os.path.exists(args.csv)
+ with open(args.csv, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=sorted(row.keys()))
+ if not file_exists:
+ writer.writeheader()
+ writer.writerow(row)
+
+ if args.json is not None:
+ meta = collect_device_meta(device)
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "see bytes_io_model_bwd in script",
+ "weight_dtype": str(args.weight_dtype),
+ },
+ )
+
+ # Print a small summary table.
+ headers = ["M", "N", "dtype", "ours_ms", "ours_tbps", "ours_hbm_frac"]
+ if quack_rmsnorm_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: list[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
new file mode 100644
index 0000000..39e6cd7
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -0,0 +1,384 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+try:
+ from quack.rmsnorm import rmsnorm_fwd as quack_rmsnorm_fwd # type: ignore
+except Exception:
+ quack_rmsnorm_fwd = None
+
+_VERIFY_TOL_Y = {
+ # Match Quack's unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ # NOTE: bf16 ulp grows with magnitude; a slightly larger rtol is more robust
+ # for the large-M suite shapes (and fused paths that can see larger values).
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
+
+_VERIFY_TOL_RSTD = {
+ torch.float32: dict(atol=1e-5, rtol=1e-5),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-3, rtol=1e-3),
+}
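+
+
+# Example invocation (illustrative; the flags are defined in main() below, and the
+# output filename matches what plot_quack_style_svg.py expects for its RMSNorm panel):
+#   python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py \
+#       --quack-suite --dtype bf16 --weight-dtype fp32 \
+#       --json /tmp/kernelagent_oink_sm100_suite_bf16/rmsnorm_fwd_quack_suite_wfp32.json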
+
+
+def bytes_io_model_fwd(
+ M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ # Read x + write y
+ total = 2 * M * N * elem
+ # Read weight
+ total += N * w_elem
+ return int(total)
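+
+
+# Worked example of the forward IO model above: M=65536, N=8192 in bf16 (2 B/elem)
+# with an fp32 weight counts 2*65536*8192*2 + 8192*4 ≈ 2.15e9 bytes per call.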
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ # DSv3-ish hidden sizes used throughout the Oink/Quack SM100 suite tables.
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ *,
+ eps: float,
+ store_rstd: bool,
+) -> dict[str, object]:
+ tol_y = _VERIFY_TOL_Y[x.dtype]
+ tol_rstd = _VERIFY_TOL_RSTD[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd is not None
+ else None
+ )
+
+ with torch.no_grad():
+ y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward(
+ x,
+ weight=w,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ y_q = None
+ rstd_q = None
+ if quack_rmsnorm_fwd is not None:
+ # Quack returns (out, residual_out, rstd).
+ y_q, res_q, rstd_q = quack_rmsnorm_fwd(
+ x,
+ w,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
+ # Pure-PyTorch reference (float32 accumulation), chunked over rows to avoid
+ # materializing an (M, N) float32 tensor for large Quack-suite shapes.
+ w_f32 = w.float()
+ rstd_ref = torch.empty((M,), device=x.device, dtype=torch.float32)
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ rstd_blk = torch.rsqrt(x_f32.square().mean(dim=-1) + eps)
+ rstd_ref[start:end] = rstd_blk
+
+ y_ref_blk_f32 = (x_f32 * rstd_blk.unsqueeze(1)) * w_f32
+ y_ref_blk = y_ref_blk_f32.to(x.dtype)
+ torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol_y)
+ y_acc_ours.update(y_o[start:end], y_ref_blk)
+ if y_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol_y)
+ assert y_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref_blk)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ if y_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+
+ if store_rstd:
+ assert rstd_o is not None
+ torch.testing.assert_close(rstd_o, rstd_ref, **tol_rstd)
+ if y_q is not None:
+ assert rstd_q is not None
+ torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd)
+ # Stats for rstd are cheap (M elements); compute exact p99 over all rows.
+ rstd_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())
+ )
+ rstd_acc_ours.update(rstd_o, rstd_ref)
+ stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
+ if rstd_q is not None:
+ rstd_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref.numel()),
+ p99_target_samples=int(rstd_ref.numel()),
+ )
+ rstd_acc_quack.update(rstd_q, rstd_ref)
+ stats.update(
+ error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())
+ )
+ # Residual output semantics differ slightly across implementations:
+ # - Oink returns `None` when residual is None.
+ # - Quack returns `x` as a safe alias in that case.
+ #
+ # For parity we focus on `y` (and optional `rstd`) for the residual=None path.
+ assert res_o is None
+ if quack_rmsnorm_fwd is not None:
+ assert res_q is x
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ weight_dtype: torch.dtype,
+ eps: float,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+ store_rstd: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=weight_dtype)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x, w, eps=eps, store_rstd=store_rstd)
+
+ bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype)
+
+ def fn_oink():
+ return oink_rmsnorm.rmsnorm_forward(
+ x,
+ weight=w,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if quack_rmsnorm_fwd is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ def fn_quack():
+ return quack_rmsnorm_fwd(
+ x,
+ w,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--weight-dtype",
+ type=str,
+ default="fp32",
+ choices=["same", "fp16", "bf16", "fp32"],
+ help="RMSNorm weight dtype. `same` matches activation dtype (vLLM-style inference).",
+ )
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument(
+ "--store-rstd", action="store_true", help="Also write rstd (fp32 per row)"
+ )
+ p.add_argument(
+ "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)."
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ if args.weight_dtype == "same":
+ weight_dtype = dtype
+ else:
+ weight_dtype = parse_dtype(args.weight_dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for M, N in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not args.skip_verify,
+ store_rstd=bool(args.store_rstd),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "weight_dtype": args.weight_dtype,
+ "eps": eps,
+ "store_rstd": bool(args.store_rstd),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "(2*M*N)*elem_size + N*weight_elem_size",
+ "store_rstd": bool(args.store_rstd),
+ "weight_dtype": str(args.weight_dtype),
+ },
+ )
+
+ # Print a compact summary table.
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ if quack_rmsnorm_fwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
new file mode 100644
index 0000000..a5b2b3c
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
@@ -0,0 +1,341 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import softmax as oink_softmax # noqa: E402
+
+try:
+ from quack.softmax import softmax_bwd as quack_softmax_bwd # type: ignore
+ from quack.softmax import softmax_fwd as quack_softmax_fwd # type: ignore
+except Exception:
+ quack_softmax_fwd = None
+ quack_softmax_bwd = None
+
+_VERIFY_TOL = {
+ # Match Quack's unit-test defaults (tests/test_softmax.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-4),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+
+def bytes_io_model_softmax(M: int, N: int, dtype: torch.dtype, *, mode: str) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ if mode == "fwd":
+ return int(2 * M * N * elem) # read x + write y
+ if mode == "bwd":
+ return int(3 * M * N * elem) # read dy + read y + write dx
+ if mode == "fwd_bwd":
+ # Logical IO for dx given (x, dy): read x + read dy + write dx.
+ # (The intermediate y=softmax(x) is an implementation detail and is
+ # intentionally not counted here.)
+ return int(3 * M * N * elem)
+ raise ValueError(f"Unsupported mode: {mode}")
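+
+
+# Worked example: fwd_bwd mode with M=65536, N=8192 in bf16 (2 B/elem) counts
+# 3*65536*8192*2 ≈ 3.22e9 logical bytes (x + dy + dx) per call.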
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(x: torch.Tensor) -> dict[str, object]:
+ tol = _VERIFY_TOL[x.dtype]
+ ref_block_rows = 4096
+ dy = torch.randn_like(x) # upstream grad
+
+ with torch.no_grad():
+ y_o = oink_softmax.softmax_forward(x)
+ dx_o = oink_softmax.softmax_backward(dy, y_o)
+ dx_fused_o = oink_softmax.softmax_fwd_bwd(dy, x)
+
+ y_q = None
+ dx_q = None
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+ y_q = quack_softmax_fwd(x)
+ dx_q = quack_softmax_bwd(dy, y_q)
+
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_softmax_fwd is not None and quack_softmax_bwd is not None)
+ else None
+ )
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_softmax_fwd is not None and quack_softmax_bwd is not None)
+ else None
+ )
+
+ # Match Quack tests: compare to PyTorch softmax refs (fwd+bwd), chunked.
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_blk = x[start:end]
+ dy_blk = dy[start:end]
+ y_ref_blk = torch.softmax(x_blk, dim=-1)
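+        # Softmax backward identity: dx = y * (dy - <dy, y>), with the per-row dot
+        # product accumulated in float32.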
+ dot = torch.sum(dy_blk * y_ref_blk, dim=-1, keepdim=True, dtype=torch.float32)
+ dx_ref_blk = (dy_blk - dot.to(dy_blk.dtype)) * y_ref_blk
+
+ torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol)
+ torch.testing.assert_close(dx_o[start:end], dx_ref_blk, **tol)
+ torch.testing.assert_close(dx_fused_o[start:end], dx_ref_blk, **tol)
+ y_acc_ours.update(y_o[start:end], y_ref_blk)
+ dx_acc_ours.update(dx_o[start:end], dx_ref_blk)
+ dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref_blk)
+ if y_q is not None and dx_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol)
+ torch.testing.assert_close(dx_q[start:end], dx_ref_blk, **tol)
+ assert y_acc_quack is not None and dx_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref_blk)
+ dx_acc_quack.update(dx_q[start:end], dx_ref_blk)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize()))
+ if y_acc_quack is not None and dx_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ warmup_ms: int,
+ iters_ms: int,
+ mode: str,
+ verify: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ dy = torch.randn_like(x)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x)
+
+ bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode)
+
+ if mode == "fwd":
+
+ def fn_oink():
+ return oink_softmax.softmax_forward(x)
+
+ fn_quack = None
+ if quack_softmax_fwd is not None:
+
+ def fn_quack():
+ return quack_softmax_fwd(x)
+
+ elif mode == "bwd":
+ with torch.no_grad():
+ y_o = oink_softmax.softmax_forward(x)
+ y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else None
+
+ def fn_oink():
+ return oink_softmax.softmax_backward(dy, y_o)
+
+ fn_quack = None
+ if quack_softmax_bwd is not None and y_q is not None:
+
+ def fn_quack():
+ return quack_softmax_bwd(dy, y_q)
+
+ elif mode == "fwd_bwd":
+
+ def fn_oink():
+ return oink_softmax.softmax_fwd_bwd(dy, x)
+
+ fn_quack = None
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+
+ def fn_quack():
+ return quack_softmax_bwd(dy, quack_softmax_fwd(x))
+
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if fn_quack is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]
+ )
+ p.add_argument(
+ "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)."
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch softmax)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...",
+ flush=True,
+ )
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ mode=str(args.mode),
+ verify=not args.skip_verify,
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "mode": args.mode,
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "mode-dependent: fwd=2*M*N, bwd=3*M*N, fwd_bwd=3*M*N (all * elem_size; fwd_bwd counts logical x+dy+dx)",
+ },
+ )
+
+ headers = ["M", "N", "mode", "ours_ms", "ours_tbps"]
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
new file mode 100644
index 0000000..96b5b83
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
@@ -0,0 +1,2259 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
new file mode 100644
index 0000000..254623e
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
@@ -0,0 +1,2600 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..9db31a5
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2936 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..c392959
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1687 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..0d4c1ae
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2600 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..1780d62
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2580 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
new file mode 100644
index 0000000..e3bcd46
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
@@ -0,0 +1,2280 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
new file mode 100644
index 0000000..e5cecac
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
@@ -0,0 +1,2621 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..1575906
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2957 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..66a3075
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1708 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..d87b7b9
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2621 @@
+
+
+
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..5c849b5
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2601 @@
+
+
+
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
new file mode 100644
index 0000000..88eebdf
--- /dev/null
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -0,0 +1,471 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
+JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
+
+The intent is to match Quack's README visual style:
+ - 3 horizontal panels (suite-dependent):
+ - Quack-suite: RMSNorm / Softmax / CrossEntropy
+ - DSv3 (hidden-size): Fused Add+RMSNorm / Softmax / LayerNorm
+ - DSv3 (all ops, 4-panel): Fused Add+RMSNorm / Softmax / LayerNorm / CrossEntropy
+ - DSv3 CrossEntropy: CrossEntropy-only (single panel)
+ - y-axis: model memory bandwidth (GB/s) derived from an IO model
+ - x-axis: a small set of labeled (M, N) shape points
+ - thick lines + markers, dashed y-grid, compact legend
+ - optional horizontal roofline line (measured STREAM-like HBM peak)
+
+Example:
+ python oink/benchmarks/readme/plot_quack_style_svg.py \\
+ --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\
+ --suite quack_suite \\
+ --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\
+ --out oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
+
+For completeness, we can also include LayerNorm as an extra panel (Quack's
+own README plot does not include LayerNorm):
+ python oink/benchmarks/readme/plot_quack_style_svg.py \\
+ --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\
+ --suite quack_suite \\
+ --include-layernorm \\
+ --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\
+ --out oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
+
+Note on DSv3 suite:
+- The DSv3 plot intentionally covers only the hidden-size ops (fused Add+RMSNorm,
+ Softmax, LayerNorm) which share the same `(M, N)` sweep.
+- CrossEntropy in DSv3 uses a vocab-size-like `N` sweep and is plotted separately
+ via `--suite dsv3_cross_entropy` to avoid a mixed x-axis with gaps.
+- For README embedding convenience, `--suite dsv3_all` renders a 4-panel
+ single-row figure where the CrossEntropy panel uses its own x-axis.
+- The RMSNorm panel uses the real block primitive (fused residual-add + RMSNorm)
+ when available: `fused_add_rmsnorm_dsv3.json`.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import os
+from collections import defaultdict
+from statistics import median
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
+
+
+def _load_json(path: str) -> Dict[str, Any]:
+ with open(path) as f:
+ return json.load(f)
+
+
+def _fmt_k(v: int) -> str:
+ # Match Quack's x-axis labels: "32K" means 32768 (1024-based).
+ if v % 1024 == 0:
+ return f"{v // 1024}K"
+ return str(v)
+
+
+def _shape_label(m: int, n: int) -> str:
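+    # e.g. _shape_label(32768, 4096) -> "(32K, 4K)"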
+ return f"({_fmt_k(m)}, {_fmt_k(n)})"
+
+
+def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]:
+ # Prefer GB/s in the JSON if present; otherwise fall back to TB/s.
+ gbps_key = f"{prefix}_gbps"
+ tbps_key = f"{prefix}_tbps"
+ if gbps_key in row and row[gbps_key] is not None:
+ return float(row[gbps_key])
+ if tbps_key in row and row[tbps_key] is not None:
+ return float(row[tbps_key]) * 1000.0
+ return None
+
+
+def _aggregate_by_shape(
+ rows: Sequence[Mapping[str, Any]],
+) -> Dict[Tuple[int, int], Dict[str, float]]:
+ """Aggregate duplicate (M, N) rows using median (more robust than mean)."""
+ buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ for r in rows:
+ m = int(r["M"])
+ n = int(r["N"])
+ ours = _gbps_from_row("ours", r)
+ quack = _gbps_from_row("quack", r)
+ if ours is not None:
+ buckets[(m, n)]["ours"].append(ours)
+ if quack is not None:
+ buckets[(m, n)]["quack"].append(quack)
+
+ out: Dict[Tuple[int, int], Dict[str, float]] = {}
+ for k, vs in buckets.items():
+ if not vs["ours"] or not vs["quack"]:
+ continue
+ out[k] = dict(ours=float(median(vs["ours"])), quack=float(median(vs["quack"])))
+ return out
+
+
+def _sort_shapes(shapes: Iterable[Tuple[int, int]]) -> List[Tuple[int, int]]:
+ # Sort by N then M to keep the x-axis stable across panels.
+ return sorted(set(shapes), key=lambda x: (x[1], x[0]))
+
+
+def _read_roofline_gbps(path: str) -> float:
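+    # Expects a JSON payload of the form {"rows": [{"tbps": <float>}, ...]} and
+    # returns the best (max) measured bandwidth in GB/s.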
+ payload = _load_json(path)
+ rows = payload.get("rows", [])
+ best_tbps = max(float(r["tbps"]) for r in rows)
+ return best_tbps * 1000.0
+
+
+def _ensure_matplotlib():
+ try:
+ import matplotlib as mpl # noqa: F401
+ import matplotlib.pyplot as plt # noqa: F401
+ except Exception as e: # pragma: no cover
+ raise SystemExit(
+ "matplotlib is required to generate SVG plots.\n"
+ "Install with: `python -m pip install matplotlib`"
+ ) from e
+
+
+def _plot(
+ *,
+ panels: Sequence[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]],
+ roofline_gbps: Optional[float],
+ out_path: str,
+ title: str,
+ shape_policy: str,
+ per_panel_x: bool,
+) -> None:
+ _ensure_matplotlib()
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+
+ mpl.rcParams.update(
+ {
+ # Quack-style: embed glyphs as paths for consistent rendering.
+ "svg.fonttype": "path",
+ "font.family": "DejaVu Sans",
+ "axes.titlesize": 18,
+ "axes.labelsize": 16,
+ "xtick.labelsize": 10,
+ "ytick.labelsize": 12,
+ }
+ )
+
+ # Colors roughly matching Quack's SVG palette.
+ COLOR_OINK = "#5ba3f5"
+ COLOR_QUACK = "#ff4444"
+ COLOR_ROOF = "#4d4d4d"
+
+ fig, axes = plt.subplots(
+ nrows=1,
+ ncols=len(panels),
+ figsize=(6.0 * len(panels), 5.6),
+ constrained_layout=False,
+ sharey=True,
+ )
+ if len(panels) == 1:
+ axes = [axes]
+
+ max_y = 0.0
+ for ax, (panel_title, data) in zip(axes, panels):
+ if per_panel_x:
+ shapes = _sort_shapes(data.keys())
+ else:
+ # Quack-style plots use a single shared x-axis across panels. Prefer
+ # the intersection so every panel has a value at every x tick
+ # (cleaner than rendering gaps), and fall back to the union if the
+ # intersection is empty.
+ shape_sets = [set(d.keys()) for _n, d in panels]
+ if shape_policy in {"first", "primary"}:
+ shapes = _sort_shapes(shape_sets[0]) if shape_sets else []
+ elif shape_policy == "intersection" and shape_sets:
+ common = set.intersection(*shape_sets)
+ shapes = _sort_shapes(common) if common else []
+ elif shape_policy == "union":
+ shapes = _sort_shapes(s for _n, d in panels for s in d.keys())
+ else:
+ raise ValueError(f"Unsupported shape_policy: {shape_policy}")
+ if not shapes:
+ shapes = _sort_shapes(s for _n, d in panels for s in d.keys())
+
+ x = list(range(len(shapes)))
+ x_labels = [_shape_label(m, n) for (m, n) in shapes]
+
+ ours_y: List[float] = []
+ quack_y: List[float] = []
+ for s in shapes:
+ rec = data.get(s)
+ if rec is None: # only possible in shared-x mode with union
+ ours_y.append(math.nan)
+ quack_y.append(math.nan)
+ continue
+ ours_y.append(float(rec["ours"]))
+ quack_y.append(float(rec["quack"]))
+ max_y = max(
+ max_y,
+ *(v for v in ours_y if math.isfinite(v)),
+ *(v for v in quack_y if math.isfinite(v)),
+ )
+
+ ax.plot(
+ x,
+ ours_y,
+ marker="o",
+ linewidth=5,
+ markersize=7,
+ color=COLOR_OINK,
+ label="KernelAgent-Oink (ours)",
+ )
+ ax.plot(
+ x,
+ quack_y,
+ marker="o",
+ linewidth=5,
+ markersize=7,
+ color=COLOR_QUACK,
+ label="Quack",
+ )
+ if roofline_gbps is not None:
+ ax.axhline(
+ roofline_gbps,
+ color=COLOR_ROOF,
+ linewidth=3,
+ linestyle=(0, (4, 6)),
+ label="HBM peak (measured)" if ax is axes[0] else None,
+ )
+ max_y = max(max_y, float(roofline_gbps))
+
+ ax.set_title(panel_title)
+ ax.set_xticks(x)
+ ax.set_xticklabels(x_labels, rotation=-45, ha="left")
+ if per_panel_x:
+ # DSv3 "all ops" figure: each panel has its own x-axis. Make the
+ # semantics explicit so readers don't assume the same `N` meaning
+ # across panels (CrossEntropy uses a classes/vocab-shard-like axis).
+ if "cross" in panel_title.lower():
+ ax.set_xlabel("Shape (M, C classes)")
+ else:
+ ax.set_xlabel("Shape (M, N hidden)")
+
+ # Quack-like dashed y-grid.
+ ax.grid(axis="y", linestyle=(0, (4, 7.2)), linewidth=0.8, color="#b0b0b0")
+ ax.set_axisbelow(True)
+
+ # Light spines (Quack SVG uses a light gray frame).
+ for spine in ax.spines.values():
+ spine.set_color("#d3d3d3")
+ spine.set_linewidth(1.5)
+
+ axes[0].set_ylabel("Memory Bandwidth (GB/s)")
+
+ # A little headroom above the tallest curve/roofline.
+ ymax = max_y * 1.08 if max_y > 0 else 1.0
+ for ax in axes:
+ ax.set_ylim(0.0, ymax)
+
+ # Tight layout for the axes area, reserving headroom for the suptitle and a
+ # shared legend. In some matplotlib versions, figure-level legends can
+ # overlap the middle panel title unless we reserve a slightly taller header
+ # band.
+ fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.70))
+
+ # Single shared legend across the top (like Quack), but keep it inside the
+ # reserved header band so it doesn't overlap the middle panel title.
+ handles, labels = axes[0].get_legend_handles_labels()
+ # Quack's legend fits nicely in one row because their plots are 3-panel and
+ # therefore wide. For single-panel figures, a 3-column legend can overflow
+ # the canvas and get clipped in the SVG, so we stack it vertically.
+ legend_ncol = min(3, len(labels))
+ legend_fontsize = 13
+ if len(panels) == 1:
+ legend_ncol = 1
+ legend_fontsize = 12
+ fig.legend(
+ handles,
+ labels,
+ loc="upper center",
+ ncol=legend_ncol,
+ frameon=False,
+ bbox_to_anchor=(0.5, 0.91),
+ fontsize=legend_fontsize,
+ handlelength=2.5,
+ )
+ # Single-panel figures (e.g. DSv3 CrossEntropy) are much narrower than the
+ # Quack-style 3-panel plots; use a slightly smaller suptitle font to avoid
+ # clipping in the exported SVG.
+ suptitle_fs = 22 if len(panels) > 1 else 18
+ fig.suptitle(title, y=0.98, fontsize=suptitle_fs)
+
+ out_path = os.path.abspath(out_path)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ # Use a tight bounding box so rotated x tick labels and the figure-level
+ # legend don't get clipped in SVG exports (matplotlib can be fragile here
+ # across versions).
+ fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.02)
+ plt.close(fig)
+
+
+def _panel_files_for_suite(suite: str) -> List[Tuple[str, str]]:
+ if suite == "quack_suite":
+ return [
+ ("RMSNorm (fp32 weight)", "rmsnorm_fwd_quack_suite_wfp32.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_quack_suite.json"),
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_quack_suite.json"),
+ ]
+ if suite == "dsv3":
+ return [
+ ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"),
+ ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"),
+ ]
+ if suite == "dsv3_all":
+ return [
+ ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"),
+ ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"),
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"),
+ ]
+ if suite == "dsv3_cross_entropy":
+ return [
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"),
+ ]
+ raise ValueError(f"Unsupported suite: {suite}")
+
+
+def _layernorm_file_for_suite(suite: str) -> str:
+ if suite == "quack_suite":
+ return "layernorm_fwd_quack_suite.json"
+ raise ValueError(f"Unsupported suite: {suite}")
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(
+ description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs."
+ )
+ p.add_argument(
+ "--in-dir",
+ type=str,
+ required=True,
+ help="Directory containing suite JSON outputs",
+ )
+ p.add_argument(
+ "--suite",
+ type=str,
+ default="quack_suite",
+ choices=["quack_suite", "dsv3", "dsv3_all", "dsv3_cross_entropy"],
+ )
+ p.add_argument(
+ "--include-layernorm",
+ action="store_true",
+ help="Add a LayerNorm (fwd) panel (only meaningful for `--suite quack_suite`).",
+ )
+ p.add_argument(
+ "--shape-policy",
+ type=str,
+ default="intersection",
+ choices=["intersection", "union", "first"],
+ help=(
+ "How to pick x-axis shapes across panels. "
+ "`intersection` matches Quack-style (only shapes common to every panel). "
+ "`first` uses the first panel's shapes (keeps DSv3 N=7168 visible). "
+ "`union` includes every shape across panels (may create gaps)."
+ ),
+ )
+ p.add_argument(
+ "--roofline-json",
+ type=str,
+ default=None,
+ help="Optional /tmp/hbm_roofline_sm100_*.json path",
+ )
+ p.add_argument("--out", type=str, required=True, help="Output SVG path")
+ p.add_argument(
+ "--title", type=str, default=None, help="Optional figure title override"
+ )
+ args = p.parse_args()
+
+ in_dir = os.path.abspath(args.in_dir)
+ if not os.path.isdir(in_dir):
+ raise SystemExit(f"--in-dir is not a directory: {in_dir}")
+
+ roofline_gbps = (
+ _read_roofline_gbps(args.roofline_json) if args.roofline_json else None
+ )
+
+ panel_files = list(_panel_files_for_suite(str(args.suite)))
+ if args.include_layernorm:
+ if args.suite != "quack_suite":
+ raise SystemExit(
+ "--include-layernorm is only supported for `--suite quack_suite`."
+ )
+ panel_files.append(
+ ("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite)))
+ )
+
+ panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = []
+ for panel_title, filename in panel_files:
+ path = os.path.join(in_dir, filename)
+ if not os.path.exists(path):
+ raise SystemExit(f"Missing required JSON: {path}")
+ payload = _load_json(path)
+ rows = payload.get("rows", [])
+ if not isinstance(rows, list):
+ rows = []
+ panels.append((panel_title, _aggregate_by_shape(rows)))
+
+ if args.title is not None:
+ title = str(args.title)
+ else:
+ # Try to infer dtype from the first panel's JSON.
+ first_json = os.path.join(in_dir, panel_files[0][1])
+ payload = _load_json(first_json)
+ rows = payload.get("rows", [])
+ dtype = rows[0].get("dtype", "") if rows else ""
+ if args.suite == "quack_suite":
+ suite_name = "Quack-suite"
+ elif args.suite == "dsv3":
+ suite_name = "DSv3 (hidden-size ops)"
+ elif args.suite == "dsv3_all":
+ suite_name = "DSv3 (4 ops)"
+ elif args.suite == "dsv3_cross_entropy":
+ # Keep this short: this suite is rendered as a single panel, so the
+ # figure is much narrower than the 3-panel plots.
+ suite_name = "DSv3 CrossEntropy"
+ else:
+ suite_name = str(args.suite)
+ suffix = (
+ " (+LayerNorm)"
+ if (args.suite == "quack_suite" and args.include_layernorm)
+ else ""
+ )
+ if args.suite == "dsv3_cross_entropy":
+ title = f"SM100 {dtype.upper()} — {suite_name}{suffix}"
+ else:
+ title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}"
+
+ _plot(
+ panels=panels,
+ roofline_gbps=roofline_gbps,
+ out_path=str(args.out),
+ title=title,
+ shape_policy=str(args.shape_policy),
+ per_panel_x=(str(args.suite) == "dsv3_all"),
+ )
+ print(f"Wrote: {os.path.abspath(args.out)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
new file mode 100644
index 0000000..af33e38
--- /dev/null
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -0,0 +1,337 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
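+# Standalone driver for the SM100 suite: runs the RMSNorm, fused add+RMSNorm,
+# Softmax, CrossEntropy, and LayerNorm benchmarks and writes one JSON per run.
+# A typical invocation (the output directory below is only an example):
+#
+#   python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16 \
+#       --out-dir /tmp/oink_sm100_results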
+from __future__ import annotations
+
+import argparse
+import os
+import subprocess
+import sys
+from datetime import datetime
+from typing import List, Tuple
+
+
+def _ts() -> str:
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+def _run(cmd: List[str], *, dry_run: bool) -> None:
+ print("+", " ".join(cmd), flush=True)
+ if dry_run:
+ return
+ subprocess.run(cmd, check=True)
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--out-dir",
+ type=str,
+ default=None,
+ help="Directory to write JSON outputs (default: /tmp/kernelagent_oink_sm100_suite_)",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)",
+ )
+ p.add_argument(
+ "--dry-run", action="store_true", help="Print commands without executing them"
+ )
+ args = p.parse_args()
+
+ # Standardize env for standalone runs outside the vLLM plugin.
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+ os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+ out_dir = args.out_dir or f"/tmp/kernelagent_oink_sm100_suite_{_ts()}"
+ os.makedirs(out_dir, exist_ok=True)
+
+ here = os.path.dirname(os.path.abspath(__file__))
+ bench_dir = os.path.abspath(os.path.join(here, "..", "benchmark"))
+ py = sys.executable
+
+ def script(name: str) -> str:
+ return os.path.join(bench_dir, name)
+
+ common = ["--dtype", args.dtype]
+ if args.skip_verify:
+ common = [*common, "--skip-verify"]
+
+ runs: List[Tuple[str, List[str]]] = [
+ (
+ "rmsnorm_fwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"),
+ ],
+ ),
+ # vLLM inference-style RMSNorm (weight dtype == activation dtype).
+ (
+ "rmsnorm_fwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "fused_add_rmsnorm_dsv3",
+ [
+ py,
+ script("benchmark_fused_add_rmsnorm_sm100.py"),
+ *common,
+ "--dsv3",
+ "--quack-baseline",
+ "kernel_inplace",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_quack_suite",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_dsv3",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_dsv3.json"),
+ ],
+ ),
+ ]
+
+ print(f"Writing results to: {out_dir}", flush=True)
+ for name, cmd in runs:
+ print(f"\n== {name} ==", flush=True)
+ _run(cmd, dry_run=bool(args.dry_run))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py
new file mode 100644
index 0000000..684694d
--- /dev/null
+++ b/oink/benchmarks/readme/summarize_results.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
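+# Renders the JSONs produced by run_sm100_suite.py as Markdown tables, with a
+# geomean speedup vs Quack and worst-case error stats per file. Example
+# invocation (paths are illustrative):
+#
+#   python oink/benchmarks/readme/summarize_results.py \
+#       --in-dir /tmp/oink_sm100_results --out /tmp/oink_sm100_results/SUMMARY.md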
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import os
+from typing import Any, Dict, Iterable, List, Optional, Sequence
+
+
+def _load_json(path: str) -> Dict[str, Any]:
+ with open(path) as f:
+ return json.load(f)
+
+
+def _fmt_cell(v: object) -> str:
+ if v is None:
+ return ""
+ if isinstance(v, float):
+ if math.isfinite(v):
+ av = abs(v)
+ # Use scientific notation for very small values so we don't render
+ # meaningful error stats as "0.0000".
+ if av != 0.0 and av < 1e-3:
+ return f"{v:.2e}"
+ return f"{v:.4f}"
+ return str(v)
+ return str(v)
+
+
+def _md_table(rows: Sequence[Dict[str, Any]], columns: Sequence[str]) -> str:
+ header = "| " + " | ".join(columns) + " |"
+ sep = "|" + "|".join(["---"] * len(columns)) + "|"
+ lines = [header, sep]
+ for r in rows:
+ lines.append("| " + " | ".join(_fmt_cell(r.get(c)) for c in columns) + " |")
+ return "\n".join(lines)
+
+
+def _pick_columns(rows: Sequence[Dict[str, Any]]) -> List[str]:
+ preferred = [
+ "M",
+ "N",
+ "dtype",
+ "weight_dtype",
+ "mode",
+ "eps",
+ "store_rstd",
+ "return_rstd",
+ "return_mean",
+ "ignore_index",
+ "ours_ms",
+ "ours_tbps",
+ "ours_hbm_frac",
+ "quack_ms",
+ "quack_tbps",
+ "speedup_vs_quack",
+ ]
+ present = set().union(*(r.keys() for r in rows)) if rows else set()
+ cols = [c for c in preferred if c in present]
+ # Fall back to a stable sorted view if we missed everything (shouldn't happen).
+ return cols or sorted(present)
+
+
+def _geomean(values: Iterable[float]) -> Optional[float]:
+ logs: List[float] = []
+ for v in values:
+ if v <= 0 or not math.isfinite(v):
+ continue
+ logs.append(math.log(v))
+ if not logs:
+ return None
+ return math.exp(sum(logs) / len(logs))
+
+
+def _collect_error_prefixes(rows: Sequence[Dict[str, Any]]) -> List[str]:
+ """Infer error-stat prefixes like `ours_err_dx` from row keys."""
+ prefixes: set[str] = set()
+ for r in rows:
+ for k in r.keys():
+ if not isinstance(k, str):
+ continue
+ if not k.endswith("_max_abs"):
+ continue
+ if "err_" not in k:
+ continue
+ prefixes.add(k[: -len("_max_abs")])
+ return sorted(prefixes)
+
+
+def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str:
+ prefixes = _collect_error_prefixes(rows)
+ if not prefixes:
+ return ""
+
+ out_rows: List[Dict[str, Any]] = []
+ for pfx in prefixes:
+ # Per-prefix worst-case across rows.
+ max_abs_vals = [
+ float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r
+ ]
+ p99_abs_vals = [
+ float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r
+ ]
+ rel_l2_vals = [
+ float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r
+ ]
+ if not max_abs_vals and not p99_abs_vals and not rel_l2_vals:
+ continue
+ out_rows.append(
+ {
+ "metric": pfx,
+ "max_abs (max over shapes)": max(max_abs_vals)
+ if max_abs_vals
+ else None,
+ "p99_abs (max over shapes)": max(p99_abs_vals)
+ if p99_abs_vals
+ else None,
+ "rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None,
+ }
+ )
+
+ if not out_rows:
+ return ""
+
+ cols = [
+ "metric",
+ "max_abs (max over shapes)",
+ "p99_abs (max over shapes)",
+ "rel_l2 (max over shapes)",
+ ]
+ return "\n".join(
+ ["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""]
+ )
+
+
+def summarize_one(path: str) -> str:
+ payload = _load_json(path)
+ meta = payload.get("meta", {})
+ rows = payload.get("rows", [])
+ if not isinstance(rows, list):
+ rows = []
+
+ cols = _pick_columns(rows)
+ parts: List[str] = []
+
+ base = os.path.basename(path)
+ parts.append(f"## `{base}`")
+ if meta:
+ device = meta.get("device")
+ cap = meta.get("capability")
+ torch_ver = meta.get("torch")
+ cuda_ver = meta.get("cuda")
+ git_sha = meta.get("git_sha")
+ ts = meta.get("timestamp")
+ parts.append("")
+ parts.append(
+ f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`"
+ )
+ method = meta.get("method")
+ if method is not None:
+ parts.append(f"- method: `{method}`")
+ if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None:
+ parts.append(
+ f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`"
+ )
+
+ if rows:
+ parts.append("")
+ parts.append(_md_table(rows, cols))
+
+ speeds = [float(r["speedup_vs_quack"]) for r in rows if "speedup_vs_quack" in r]
+ gm = _geomean(speeds)
+ if gm is not None:
+ parts.append("")
+ parts.append(
+ f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)"
+ )
+
+ err_block = _summarize_error_stats(rows)
+ if err_block:
+ parts.append(err_block.rstrip())
+ else:
+ parts.append("")
+ parts.append("_No rows found in JSON._")
+
+ parts.append("")
+ return "\n".join(parts)
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(
+ description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables."
+ )
+ p.add_argument(
+ "--in-dir",
+ type=str,
+ required=True,
+ help="Directory containing benchmark JSON files",
+ )
+ p.add_argument(
+ "--out",
+ type=str,
+ default=None,
+ help="Optional output markdown path (default: stdout)",
+ )
+ args = p.parse_args()
+
+ in_dir = os.path.abspath(args.in_dir)
+ if not os.path.isdir(in_dir):
+ raise SystemExit(f"--in-dir is not a directory: {in_dir}")
+
+ json_paths = sorted(
+ os.path.join(in_dir, name)
+ for name in os.listdir(in_dir)
+ if name.endswith(".json")
+ )
+ if not json_paths:
+ raise SystemExit(f"No .json files found under: {in_dir}")
+
+ out_parts: List[str] = []
+ out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary")
+ out_parts.append("")
+ out_parts.append(f"Input directory: `{in_dir}`")
+ out_parts.append("")
+ for path in json_paths:
+ out_parts.append(summarize_one(path))
+
+ text = "\n".join(out_parts).rstrip() + "\n"
+ if args.out is None:
+ print(text, end="")
+ return
+
+ out_path = os.path.abspath(args.out)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ with open(out_path, "w") as f:
+ f.write(text)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/pyproject.toml b/oink/pyproject.toml
new file mode 100644
index 0000000..0d19d6e
--- /dev/null
+++ b/oink/pyproject.toml
@@ -0,0 +1,51 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "kernelagent-oink"
+version = "0.1.0"
+description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "Apache-2.0"}
+authors = [{name = "PyTorch Labs"}]
+keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"]
+classifiers = [
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+[project.urls]
+Repository = "https://github.com/meta-pytorch/KernelAgent"
+Documentation = "https://github.com/meta-pytorch/KernelAgent/tree/main/oink"
+Issues = "https://github.com/meta-pytorch/KernelAgent/issues"
+
+# Keep dependencies minimal, but include the CuTeDSL stack required by the
+# Blackwell (SM100) kernel implementations.
+#
+# We intentionally do NOT depend on `torch` here because vLLM already pins and
+# provides a compatible PyTorch build.
+dependencies = [
+ "nvidia-cutlass-dsl",
+ "cuda-python",
+]
+
+[project.optional-dependencies]
+# Optional extras for running the in-repo benchmark suite (not needed for vLLM integration).
+bench = [
+ "matplotlib",
+ "triton",
+]
+
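+# vLLM imports `vllm.general_plugins` entrypoints at startup and calls the
+# referenced callable; `kernelagent_oink.register` still no-ops unless
+# VLLM_USE_OINK_RMSNORM is set (see src/kernelagent_oink/__init__.py).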
+[project.entry-points."vllm.general_plugins"]
+oink = "kernelagent_oink:register"
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["kernelagent_oink*"]
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
new file mode 100644
index 0000000..f5c36a6
--- /dev/null
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin.
+
+This package can be loaded as a vLLM "general plugin" (entrypoint group
+`vllm.general_plugins`). In that mode it registers Oink custom ops only when
+explicitly enabled via an environment variable (so installing the package does
+not change behavior by default).
+
+For standalone usage (outside vLLM), call `kernelagent_oink.register(force=True)`
+to register the custom ops explicitly.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+_OPS_REGISTERED = False
+
+
+def _env_truthy(name: str) -> bool:
+ val = os.environ.get(name)
+ if val is None:
+ return False
+ return val.strip().lower() in ("1", "true", "yes", "on")
+
+
+def _infer_cuda_device_index() -> int:
+ local_rank = os.environ.get("LOCAL_RANK")
+ if local_rank is not None:
+ try:
+ return int(local_rank)
+ except ValueError:
+ pass
+ return 0
+
+
+def _compute_cutedsl_arch(major: int, minor: int) -> str:
+ # CuTeDSL uses an "a" suffix for >= Hopper.
+ suffix = "a" if major >= 9 else ""
+ # Match cutlass/base_dsl/env_manager.py: map sm_110 -> sm_101.
+ if major == 11 and minor == 0:
+ major, minor = 10, 1
+ return f"sm_{major}{minor}{suffix}"
+
+
+def register(*, force: bool = False) -> None:
+ """Register Oink torch custom ops.
+
+ - vLLM plugin mode (default): no-op unless `VLLM_USE_OINK_RMSNORM` is truthy.
+ - Standalone mode: pass `force=True` to register explicitly.
+
+ This function must be safe to call multiple times and must not raise. vLLM
+ executes it in multiple processes (engine + workers).
+ """
+ global _OPS_REGISTERED
+
+ if _OPS_REGISTERED:
+ return
+
+ # Gate on the vLLM integration flag so installing the package does not
+ # change behavior unless explicitly enabled. For standalone usage (outside
+ # vLLM), callers can pass force=True to register the ops explicitly.
+ if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"):
+ return
+
+ try:
+ import torch
+ except Exception as e: # pragma: no cover
+ logger.debug("Oink plugin: torch import failed: %s", e)
+ return
+
+ try:
+ if not torch.cuda.is_available():
+ return
+ device_index = _infer_cuda_device_index()
+ major, minor = torch.cuda.get_device_capability(device_index)
+ sm = 10 * int(major) + int(minor)
+ if sm < 100:
+ return
+
+ # Ensure required deps are importable before registering ops so that vLLM
+ # doesn't detect ops that would later fail at first use.
+ try:
+ import cutlass # noqa: F401
+ import cuda.bindings.driver as _cuda # noqa: F401
+ except Exception as e:
+ logger.warning(
+ "Oink plugin: CuTeDSL deps missing; skipping op registration. "
+ "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s",
+ e,
+ )
+ return
+
+ # Ensure CuTeDSL sees a target arch early. If the user has already set it,
+ # respect their choice.
+ os.environ.setdefault(
+ "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))
+ )
+
+ # Import registers the ops via torch.library.custom_op decorators.
+ from .blackwell import oink_custom_ops # noqa: F401
+ except Exception as e: # pragma: no cover
+ # Do not raise: vLLM plugin loader does not guard plugin execution.
+ logger.exception("Oink plugin: failed to register ops: %s", e)
+ return
+
+ _OPS_REGISTERED = True
+
+
+__all__ = ["register"]
diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py
new file mode 100644
index 0000000..a92109a
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+__all__ = []
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
new file mode 100644
index 0000000..d8b37ea
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -0,0 +1,2257 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cross-entropy forward + backward kernels for SM100 (Blackwell) in CuteDSL.
+
+This module implements numerically stable cross-entropy over the last
+dimension of 2D logits tensors `(M, N)` together with its backward pass,
+targeting SM100 with Quack-style tiling, cp.async pipelines, and (for the
+forward pass) optional cluster-wide online softmax reductions, but without
+depending on the external `quack` package at runtime.
+
+Public APIs:
+
+- ``cross_entropy_forward(logits, target, ignore_index=-100, reduction="none")``
+ returns ``(loss, lse)`` where ``loss`` follows the requested reduction and
+ ``lse`` is always per-example log-sum-exp (shape ``(M,)``).
+- ``cross_entropy_backward(dloss, logits, target, lse, ignore_index=-100)``
+ returns per-logit gradients ``dlogits`` matching PyTorch /
+ ``quack.cross_entropy_bwd`` semantics for ``reduction="none"``.
+- ``cross_entropy(logits, target, ignore_index=-100, reduction="mean"|"sum"|"none")``
+ is a convenience wrapper that mirrors ``torch.nn.functional.cross_entropy``
+ reductions using the SM100 CuteDSL kernels for the forward pass.
+
+The kernels are self-contained and use only local helpers in
+`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import math
+import os
+import re
+from typing import Literal, Optional, Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ
+# across `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes
+# noisy warnings (and disables cache reuse), so below we namespace the cache
+# directory by the installed DSL version unless the user already set one.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.cross_entropy requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Boolean, Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell.fast_launch import (
+ StableI32Arg,
+ disable_fast_launch,
+ fast_launch_enabled,
+ set_runtime_ptr,
+ tls_cache as _tls_fast_launch_cache,
+)
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase,
+ fill_oob,
+ online_softmax_reduce,
+ predicate_k,
+)
+
+_FWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {}
+_BWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {}
+_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+class _PtrCrossEntropyFastLaunch:
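+    """Reusable pointer-based launcher for the cross-entropy kernels.
+
+    Repeated launches only rewrite the device pointers and scalar arguments
+    (M, ld, ignore_index) that actually changed and then invoke the compiled
+    kernel's C-API entrypoint directly; if the CuTeDSL internals this relies
+    on are unavailable, it falls back to the regular compiled-call path.
+    """
+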
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_logits: object,
+ ptr_target: object,
+ ptr_aux_a: object,
+ ptr_aux_b: object,
+ ptr_aux_c: object | None,
+ arg_m: StableI32Arg,
+ arg_ld: StableI32Arg,
+ arg_ignore_index: StableI32Arg,
+ stream: cuda.CUstream,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ logits_align: int,
+ target_align: int,
+ aux_a_align: int,
+ aux_b_align: int,
+ aux_c_align: int | None,
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_logits = ptr_logits
+ self._ptr_target = ptr_target
+ self._ptr_aux_a = ptr_aux_a
+ self._ptr_aux_b = ptr_aux_b
+ self._ptr_aux_c = ptr_aux_c
+ self._arg_m = arg_m
+ self._arg_ld = arg_ld
+ self._arg_ignore_index = arg_ignore_index
+ self._stream = stream
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+ self._logits_align = int(logits_align)
+ self._target_align = int(target_align)
+ self._aux_a_align = int(aux_a_align)
+ self._aux_b_align = int(aux_b_align)
+ self._aux_c_align = int(aux_c_align) if aux_c_align is not None else None
+
+ self._use_fast_launch = True
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_logits_ptr = -1
+ self._last_target_ptr = -1
+ self._last_aux_a_ptr = -1
+ self._last_aux_b_ptr = -1
+ self._last_aux_c_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+ self._last_ignore_index = None
+
+ def launch(
+ self,
+ *,
+ logits_ptr: int,
+ target_ptr: int,
+ aux_a_ptr: int,
+ aux_b_ptr: int,
+ aux_c_ptr: int | None,
+ M: int,
+ ld: int,
+ ignore_index: int,
+ stream_handle: int,
+ dtype_logits: type[cutlass.Numeric],
+ aux_a_dtype: type[cutlass.Numeric],
+ aux_b_dtype: type[cutlass.Numeric],
+ aux_c_dtype: type[cutlass.Numeric] | None,
+ ) -> None:
+ if not fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if logits_ptr != self._last_logits_ptr:
+ try:
+ set_runtime_ptr(self._ptr_logits, logits_ptr)
+ self._last_logits_ptr = logits_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if target_ptr != self._last_target_ptr:
+ try:
+ set_runtime_ptr(self._ptr_target, target_ptr)
+ self._last_target_ptr = target_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if aux_a_ptr != self._last_aux_a_ptr:
+ try:
+ set_runtime_ptr(self._ptr_aux_a, aux_a_ptr)
+ self._last_aux_a_ptr = aux_a_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if aux_b_ptr != self._last_aux_b_ptr:
+ try:
+ set_runtime_ptr(self._ptr_aux_b, aux_b_ptr)
+ self._last_aux_b_ptr = aux_b_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if self._ptr_aux_c is not None and aux_c_ptr is not None:
+ if aux_c_ptr != self._last_aux_c_ptr:
+ try:
+ set_runtime_ptr(self._ptr_aux_c, aux_c_ptr)
+ self._last_aux_c_ptr = aux_c_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ logits_ptr=logits_ptr,
+ target_ptr=target_ptr,
+ aux_a_ptr=aux_a_ptr,
+ aux_b_ptr=aux_b_ptr,
+ aux_c_ptr=aux_c_ptr,
+ M=M,
+ ld=ld,
+ ignore_index=ignore_index,
+ stream_handle=stream_handle,
+ dtype_logits=dtype_logits,
+ aux_a_dtype=aux_a_dtype,
+ aux_b_dtype=aux_b_dtype,
+ aux_c_dtype=aux_c_dtype,
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+ if ignore_index != self._last_ignore_index:
+ self._arg_ignore_index.set(ignore_index)
+ self._last_ignore_index = int(ignore_index)
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ self._use_fast_launch = False
+ disable_fast_launch()
+
+ def _fallback_launch(
+ self,
+ *,
+ logits_ptr: int,
+ target_ptr: int,
+ aux_a_ptr: int,
+ aux_b_ptr: int,
+ aux_c_ptr: int | None,
+ M: int,
+ ld: int,
+ ignore_index: int,
+ stream_handle: int,
+ dtype_logits: type[cutlass.Numeric],
+ aux_a_dtype: type[cutlass.Numeric],
+ aux_b_dtype: type[cutlass.Numeric],
+ aux_c_dtype: type[cutlass.Numeric] | None,
+ ) -> None:
+ stream = cuda.CUstream(int(stream_handle))
+ ptr_logits = rt.make_ptr(
+ dtype_logits,
+ int(logits_ptr),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._logits_align,
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ int(target_ptr),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._target_align,
+ )
+ ptr_aux_a = rt.make_ptr(
+ aux_a_dtype,
+ int(aux_a_ptr),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._aux_a_align,
+ )
+ ptr_aux_b = rt.make_ptr(
+ aux_b_dtype,
+ int(aux_b_ptr),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._aux_b_align,
+ )
+ if (
+ self._ptr_aux_c is not None
+ and aux_c_ptr is not None
+ and aux_c_dtype is not None
+ ):
+ ptr_aux_c = rt.make_ptr(
+ aux_c_dtype,
+ int(aux_c_ptr),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=int(self._aux_c_align or 0),
+ )
+ self._compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_aux_a,
+ ptr_aux_b,
+ ptr_aux_c,
+ Int32(int(M)),
+ Int32(int(ld)),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ else:
+ self._compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_aux_a,
+ ptr_aux_b,
+ Int32(int(M)),
+ Int32(int(ld)),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def _get_fast_ptr_cross_entropy_launcher(
+ *,
+ compiled: object,
+ dtype_logits: type[cutlass.Numeric],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ mode: Literal["fwd", "bwd", "fwd_bwd"],
+) -> _PtrCrossEntropyFastLaunch | None:
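+    """Build (and thread-locally cache) a pointer-based fast launcher.
+
+    The aux pointers mean (loss, lse) for ``mode="fwd"``, (dloss, dx, lse) for
+    ``mode="bwd"``, and (dloss, dx) for ``mode="fwd_bwd"``. Returns None when
+    fast launch is disabled or the required CuTeDSL internals are unavailable.
+    """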
+ if not fast_launch_enabled():
+ return None
+ key = (
+ f"ptr_fast_{mode}",
+ id(compiled),
+ int(N),
+ dtype_logits,
+ int(device_index),
+ int(stream_handle),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ ptr_logits = rt.make_ptr(
+ dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64, 0, mem_space=rt.AddressSpace.gmem, assumed_align=8
+ )
+ if mode == "fwd":
+ ptr_aux_a = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4
+ ) # loss
+ ptr_aux_b = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4
+ ) # lse
+ ptr_aux_c = None
+ aux_align_b = 4
+ aux_align_c = None
+ elif mode == "bwd":
+ ptr_aux_a = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4
+ ) # dloss
+ ptr_aux_b = rt.make_ptr(
+ dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16
+ ) # dx
+ ptr_aux_c = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4
+ ) # lse
+ aux_align_b = 16
+ aux_align_c = 4
+ elif mode == "fwd_bwd":
+ ptr_aux_a = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4
+ ) # dloss
+ ptr_aux_b = rt.make_ptr(
+ dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16
+ ) # dx
+ ptr_aux_c = None
+ aux_align_b = 16
+ aux_align_c = None
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ arg_m = StableI32Arg(0)
+ arg_ld = StableI32Arg(N)
+ arg_ignore_index = StableI32Arg(-100)
+ stream = cuda.CUstream(int(stream_handle))
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ try:
+ if ptr_aux_c is not None:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_logits,
+ ptr_target,
+ ptr_aux_a,
+ ptr_aux_b,
+ ptr_aux_c,
+ arg_m,
+ arg_ld,
+ arg_ignore_index,
+ stream,
+ )
+ else:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_logits,
+ ptr_target,
+ ptr_aux_a,
+ ptr_aux_b,
+ arg_m,
+ arg_ld,
+ arg_ignore_index,
+ stream,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ disable_fast_launch()
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_logits,
+ ptr_target,
+ ptr_aux_a,
+ ptr_aux_b,
+ ptr_aux_c,
+ arg_m,
+ arg_ld,
+ arg_ignore_index,
+ stream,
+ *adapted_args,
+ )
+ launcher = _PtrCrossEntropyFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_logits=ptr_logits,
+ ptr_target=ptr_target,
+ ptr_aux_a=ptr_aux_a,
+ ptr_aux_b=ptr_aux_b,
+ ptr_aux_c=ptr_aux_c,
+ arg_m=arg_m,
+ arg_ld=arg_ld,
+ arg_ignore_index=arg_ignore_index,
+ stream=stream,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ logits_align=16,
+ target_align=8,
+ aux_a_align=4,
+ aux_b_align=aux_align_b,
+ aux_c_align=aux_align_c,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+def _convert_logits_2d(x: Tensor) -> cute.Tensor:
+ """Convert a 2D logits tensor (M, N) into a CuTe tensor.
+
+ We assume 16-byte alignment and mark the layout compact and row-major
+ in the last dimension, matching the conventions used in the SM100
+ softmax and RMSNorm kernels.
+ """
+ assert x.dim() == 2, "Input logits must be 2D (M, N)"
+ return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+
+
+def _convert_1d(t: Tensor, assumed_align: int) -> cute.Tensor:
+ """Convert a 1D tensor with a fully dynamic layout."""
+ assert t.dim() == 1, "Expected a 1D tensor"
+ return from_dlpack(t.detach(), assumed_align=assumed_align).mark_layout_dynamic()
+
+
+class CrossEntropyFwdSM100(ReductionBase):
+ """SM100-tuned cross-entropy forward kernel.
+
+ This mirrors the structure of ``quack.cross_entropy.CrossEntropy`` but
+ is simplified to always use the single-pass online softmax reduction and
+ never computes gradients inside the forward kernel.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # Use one stage with an Int64 reduction buffer packing (max, sum_exp)
+ # pairs via lite_quack.online_softmax_reduce.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _calculate_threads_per_row(self) -> int:
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ # Match Quack's cluster_n growth policy while keeping it explicit so
+ # we can tune SM100-specific shapes later if needed.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ self._set_cluster_n()
+ # If N is not divisible by the full 128-bit vector width, step down
+ # to the largest compatible vector size as in Quack.
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_logits: cute.Pointer,
+ ptr_target: cute.Pointer,
+ ptr_loss: cute.Pointer,
+ ptr_lse: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ mX = cute.make_tensor(ptr_logits, layout_mn)
+ mTarget = cute.make_tensor(ptr_target, layout_m)
+ mLoss = cute.make_tensor(ptr_loss, layout_m)
+ mLSE = cute.make_tensor(ptr_lse, layout_m)
+ self.__call__(mX, mTarget, mLoss, mLSE, ignore_index, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32, # Index to ignore in loss computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape: cute.Shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Quack-style CTA tiling: let CuTe compute the CTA offsets directly.
+ # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on
+ # the common inference/benchmark sizes.)
+ gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
+
+ # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed.
+ num_copy_elems_X = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ gX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems_X))
+ thr_copy_X = cute.make_tiled_copy_tv(
+ copy_atom_load_X, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXrX = cute.make_fragment_like(tXgX)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ row = tXcX[0][0]
+ target = Int32.zero
+ if row < shape[0]:
+ target = Int32(mTarget[row])
+
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ # Fill out-of-bounds values with -inf so they are ignored in max/sum.
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ should_ignore = Boolean(target == ignore_index)
+
+ # Load the target logit if this row is not ignored.
+ target_logit = Float32.zero
+ if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
+ target_logit = Float32(mX[row, target])
+
+ max_x, denom, _ = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=False,
+ )
+
+ # Write loss and lse to gmem. Only one CTA in the cluster writes to
+ # avoid duplicate stores.
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
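+            # Numerically stable NLL per row:
+            #   loss_i = lse_i - x_i[target_i], lse_i = max_i + log(sum_j exp(x_ij - max_i));
+            # rows with target == ignore_index get loss 0 (lse is still written).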
+ lse = max_x + cute.math.log(denom, fastmath=True)
+ loss_val = (lse - target_logit) if not should_ignore else Float32.zero
+ mLoss[row] = mLoss.element_type(loss_val)
+ if const_expr(mLSE is not None):
+ mLSE[row] = lse
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ ) -> None:
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+class CrossEntropyFwdBwdSM100(ReductionBase):
+ """Fused cross-entropy forward+backward producing dx from (logits, target, dloss).
+
+ This avoids materializing the intermediate `lse` (and loss) in global memory
+ when the only desired output is `dx` for `reduction="none"` semantics.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _calculate_threads_per_row(self) -> int:
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ self._set_cluster_n()
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ ignore_index,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_logits: cute.Pointer,
+ ptr_target: cute.Pointer,
+ ptr_dloss: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ mX = cute.make_tensor(ptr_logits, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ mTarget = cute.make_tensor(ptr_target, layout_m)
+ mDLoss = cute.make_tensor(ptr_dloss, layout_m)
+ self.__call__(mX, mTarget, mDLoss, mdX, ignore_index, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ ignore_index: Int32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ cluster_y = (
+ const_expr(0)
+ if const_expr(self.cluster_n == 1)
+ else cute.arch.block_idx()[1]
+ )
+
+ shape: cute.Shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ gX, gdX, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mdX, idX)
+ ]
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
+
+ num_copy_elems_X = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ gX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ copy_atom_store_dX = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems_X))
+ thr_copy_X = cute.make_tiled_copy_tv(
+ copy_atom_load_X, thr_layout, val_layout
+ ).get_slice(tidx)
+ thr_copy_dX = cute.make_tiled_copy_tv(
+ copy_atom_store_dX, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXcFull = thr_copy_X.partition_S(cX)
+ tXgdX = thr_copy_dX.partition_D(gdX)
+
+ tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ row = tXcX[0][0]
+ target = Int32.zero
+ dloss = Float32.zero
+ if row < shape[0]:
+ target = Int32(mTarget[row])
+ should_ignore = Boolean(target == ignore_index)
+ dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
+
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ _max_x, denom, exp_x = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=True,
+ )
+ assert exp_x is not None
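+        # dCE/dx for reduction="none":
+        #   dx[i, j] = dloss[i] * (p[i, j] - 1{j == target[i]}), p = softmax(x_i).
+        # dloss was zeroed above for ignored rows, so their gradients are 0.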
+ probs = exp_x * cute.arch.rcp_approx(denom)
+ prob_shifted = probs - 1.0
+
+ mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
+ mask[i] = tXcFull[i][1] == target
+ grad = cute.where(mask.load(), prob_shifted, probs)
+ grad = grad * dloss
+
+ tXrdX.store(grad.to(tXrdX.element_type))
+
+ tXpdX = (
+ predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ ignore_index: Int32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ ignore_index: Int32,
+ ) -> None:
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+class CrossEntropyBackwardSM100:
+ """SM100-tuned cross-entropy backward kernel.
+
+ This is a direct port of ``quack.cross_entropy.CrossEntropyBackward`` to
+ the local lite_quack helpers, using cp.async tiling over the (M, N)
+    logits and broadcasting the per-row ``dloss`` / ``lse`` values along the
+    class (N) dimension.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ self.dtype = dtype
+ self.N = N
+
+ def _get_num_threads(self) -> int:
+ # Keep in sync with _get_tv_layout() (we tile N in 16k blocks).
+ N = min(self.N, 16384)
+ return 128 if N <= 16384 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ N = min(self.N, 16384) # We split by blocks of 16k in N.
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
+ )
+ )
+
+ def _get_tv_layout(
+ self, num_copy_bits: int = 128
+ ) -> tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ assert self.N % vecsize == 0, (
+ f"Input N {self.N} is not divisible by vector size {vecsize}"
+ )
+ N = min(self.N, 16384)
+ num_threads = 128 if N <= 16384 else 256
+ threads_per_row = self._calculate_threads_per_row()
+ cols_per_block = num_threads // threads_per_row
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+ tv_layout = cute.make_layout(
+ ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * threads_per_row),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mTarget: cute.Tensor,
+ mDLoss: cute.Tensor,
+ mdX: cute.Tensor,
+ mLSE: cute.Tensor,
+ ignore_index: Int32, # Index to ignore in gradient computation
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ # Broadcast (M,) tensors along the N dimension with stride 0.
+ mDLoss, mTarget, mLSE = [
+ cute.make_tensor(
+ X.iterator,
+ cute.append(X.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+ for X in (mDLoss, mTarget, mLSE)
+ ]
+ smem_size = cute.size_in_bytes(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ kernel = (
+ self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ )
+ )
+ kernel.launch(
+ grid=[
+ cute.ceil_div(mX.shape[0], tiler_mn[0]),
+ cute.ceil_div(mX.shape[1], tiler_mn[1]),
+ 1,
+ ],
+ block=[num_threads, 1, 1],
+ smem=smem_size,
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_logits: cute.Pointer,
+ ptr_target: cute.Pointer,
+ ptr_dloss: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ ptr_lse: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ mX = cute.make_tensor(ptr_logits, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ mTarget = cute.make_tensor(ptr_target, layout_m)
+ mDLoss = cute.make_tensor(ptr_dloss, layout_m)
+ mLSE = cute.make_tensor(ptr_lse, layout_m)
+ self.__call__(mX, mTarget, mDLoss, mdX, mLSE, ignore_index, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, bidy, _ = cute.arch.block_idx()
+ shape = mX.shape
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+
+ idX = cute.make_identity_tensor(shape)
+ # Quack-style CTA tiling: avoid extra 64-bit address arithmetic by
+ # letting CuTe compute the CTA offsets directly.
+ gX, gdX, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)
+ ]
+
+ num_copy_elems_X = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ gX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ copy_atom_store_dX = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ thr_copy_X = cute.make_tiled_copy(
+ copy_atom_load_X, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_dX = cute.make_tiled_copy(
+ copy_atom_store_dX, tv_layout, tiler_mn
+ ).get_slice(tidx)
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXcFull = thr_copy_X.partition_S(cX)
+ tXgdX = thr_copy_dX.partition_D(gdX)
+
+ tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
+
+ is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
+ row = tXcX[0][0]
+        tXpX = (
+            predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if const_expr(not is_even_N)
+            else None
+        )
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ target = Int32.zero
+ dloss = Float32.zero
+ lse = Float32.zero
+ if row < shape[0]:
+ target = Int32(mTarget[row])
+ should_ignore = Boolean(target == ignore_index)
+ dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
+ lse = Float32(mLSE[row])
+
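+        # softmax(x)[j] = exp(x[j] - lse), computed via exp2 for speed; the
+        # logits gradient is dloss * (softmax(x) - onehot(target)), and dloss
+        # was forced to zero above for ignored rows.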
+ log2_e = math.log2(math.e)
+ probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
+ prob_shifted = probs - 1.0
+ mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
+ mask[i] = tXcFull[i][1] == target
+ grad = cute.where(mask.load(), prob_shifted, probs)
+ grad = grad * dloss
+
+ tXrdX.store(grad.to(tXrdX.element_type))
+        tXpdX = (
+            predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
+            if const_expr(not is_even_N)
+            else None
+        )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ ) -> None:
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+def cross_entropy_forward(
+ logits: Tensor,
+ target: Tensor,
+ ignore_index: int = -100,
+ reduction: Literal["none", "mean", "sum"] = "none",
+) -> tuple[Tensor, Tensor]:
+ """SM100 CuteDSL cross-entropy forward pass.
+
+ Args:
+ logits: Tensor of shape ``(M, N)`` on CUDA.
+ target: Tensor of shape ``(M,)`` with integer class indices.
+ ignore_index: Target value to ignore when computing the loss.
+ reduction: One of ``"none"``, ``"mean"``, or ``"sum"`` following
+ ``torch.nn.functional.cross_entropy`` semantics.
+
+ Returns:
+ A tuple ``(loss, lse)`` where:
+ - ``loss`` has shape ``(M,)`` if ``reduction="none"`` or is a scalar
+ otherwise.
+ - ``lse`` is the per-example log-sum-exp with shape ``(M,)``.
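+
+    Example (illustrative; assumes an SM100 GPU)::
+
+        import torch
+
+        logits = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+        target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)
+        loss, lse = cross_entropy_forward(logits, target, reduction="none")
+        # loss and lse are float32 tensors of shape (8,)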
+ """
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert target.dim() == 1, "target must be 1D (M,)"
+ assert logits.shape[0] == target.shape[0], "Batch dimensions must match"
+ assert logits.is_cuda and target.is_cuda, "logits and target must be on CUDA device"
+ assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype"
+ assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64"
+
+ M, N = logits.shape
+ device = logits.device
+ dtype_cute = TORCH2CUTE_DTYPE[logits.dtype]
+
+ loss = torch.empty(M, device=device, dtype=torch.float32)
+ lse = torch.empty(M, device=device, dtype=torch.float32)
+
+ if _can_use_ptr_path_logits(logits) and _can_use_ptr_path_target(target):
+ _cross_entropy_forward_ptr_into(
+ logits=logits,
+ target=target,
+ loss=loss,
+ lse=lse,
+ ignore_index=int(ignore_index),
+ )
+ if reduction == "none":
+ return loss, lse
+ with torch.no_grad():
+ mask = target != ignore_index
+ if reduction == "sum":
+ reduced = loss.sum()
+ elif reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ reduced = loss[mask].sum() / valid.to(loss.dtype)
+ else:
+ reduced = loss.sum() * 0.0
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'."
+ )
+ return reduced, lse
+
+ mX = _convert_logits_2d(logits)
+ mTarget = _convert_1d(target.to(torch.int64), assumed_align=8)
+ mLoss = _convert_1d(loss, assumed_align=4)
+ mLSE = _convert_1d(lse, assumed_align=4)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype_cute, N)
+ kernel = _FWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = CrossEntropyFwdSM100(dtype_cute, N)
+ kernel = cute.compile(
+ op,
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ Int32(ignore_index),
+ current_stream,
+ )
+ _FWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(mX, mTarget, mLoss, mLSE, Int32(ignore_index), current_stream)
+
+ if reduction == "none":
+ return loss, lse
+
+ with torch.no_grad():
+ mask = target != ignore_index
+ if reduction == "sum":
+ reduced = loss.sum()
+ elif reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ reduced = loss[mask].sum() / valid.to(loss.dtype)
+ else:
+ reduced = loss.sum() * 0.0
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'."
+ )
+ return reduced, lse
+
+
+def _cross_entropy_backward_sm100(
+ logits: Tensor,
+ target: Tensor,
+ dloss: Tensor,
+ lse: Tensor,
+ dx: Tensor,
+ ignore_index: int = -100,
+) -> None:
+ """Internal SM100 cross-entropy backward dispatch using CuteDSL."""
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert target.dim() == 1, "target must be 1D (M,)"
+ assert dloss.dim() == 1, "dloss must be 1D (M,)"
+ assert lse.dim() == 1, "lse must be 1D (M,)"
+ assert logits.shape[0] == target.shape[0] == dloss.shape[0] == lse.shape[0], (
+ "Batch dimensions must match"
+ )
+ assert logits.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+ "All tensors must be on CUDA device"
+ )
+ assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype"
+ assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64"
+
+ M, N = logits.shape
+ dtype_cute = TORCH2CUTE_DTYPE[logits.dtype]
+
+ if (
+ _can_use_ptr_path_logits(logits)
+ and _can_use_ptr_path_logits(dx)
+ and _can_use_ptr_path_target(target)
+ and _can_use_ptr_path_f32_1d(dloss)
+ and _can_use_ptr_path_f32_1d(lse)
+ and logits.stride() == dx.stride()
+ ):
+ _cross_entropy_backward_ptr_into(
+ logits=logits,
+ target=target,
+ dloss=dloss,
+ lse=lse,
+ dx=dx,
+ ignore_index=int(ignore_index),
+ )
+ return
+
+ mX = _convert_logits_2d(logits)
+ mdX = _convert_logits_2d(dx)
+ mTarget = _convert_1d(target.to(torch.int64), assumed_align=8)
+ mDLoss = _convert_1d(dloss.to(torch.float32), assumed_align=4)
+ mLSE = _convert_1d(lse.to(torch.float32), assumed_align=4)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype_cute, N)
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = CrossEntropyBackwardSM100(dtype_cute, N)
+ kernel = cute.compile(
+ op,
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ Int32(ignore_index),
+ current_stream,
+ )
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(mX, mTarget, mDLoss, mdX, mLSE, Int32(ignore_index), current_stream)
+
+
+def _can_use_ptr_path_logits(x: Tensor) -> bool:
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.dtype not in TORCH2CUTE_DTYPE:
+ return False
+ if x.stride(1) != 1:
+ return False
+ if (x.data_ptr() % 16) != 0:
+ return False
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_target(t: Tensor) -> bool:
+ if not t.is_cuda or t.dim() != 1:
+ return False
+ if t.dtype is not torch.int64:
+ return False
+ if not t.is_contiguous():
+ return False
+ if t.stride(0) != 1:
+ return False
+ if (t.data_ptr() % 8) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_f32_1d(t: Tensor) -> bool:
+ if not t.is_cuda or t.dim() != 1:
+ return False
+ if t.dtype is not torch.float32:
+ return False
+ if not t.is_contiguous():
+ return False
+ if t.stride(0) != 1:
+ return False
+ if (t.data_ptr() % 4) != 0:
+ return False
+ return True
+
+
+def _cross_entropy_forward_ptr_into(
+ *,
+ logits: Tensor,
+ target: Tensor,
+ loss: Tensor,
+ lse: Tensor,
+ ignore_index: int,
+) -> None:
+ assert logits.is_cuda and logits.dim() == 2
+ assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
+ assert target.dtype is torch.int64
+ assert (
+ loss.is_cuda
+ and loss.shape == (logits.shape[0],)
+ and loss.dtype is torch.float32
+ )
+ assert (
+ lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ )
+
+ M, N = logits.shape
+ device_index = logits.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[logits.dtype]
+ key = ("ptr_fwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = CrossEntropyFwdSM100(dtype_x, int(N))
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_loss = rt.make_ptr(
+ cutlass.Float32,
+ loss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_logits,
+ ptr_target,
+ ptr_loss,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ _PTR_FWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_cross_entropy_launcher(
+ compiled=compiled,
+ dtype_logits=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ mode="fwd",
+ )
+ if launcher is not None:
+ launcher.launch(
+ logits_ptr=int(logits.data_ptr()),
+ target_ptr=int(target.data_ptr()),
+ aux_a_ptr=int(loss.data_ptr()),
+ aux_b_ptr=int(lse.data_ptr()),
+ aux_c_ptr=None,
+ M=int(M),
+ ld=int(logits.stride(0)),
+ ignore_index=int(ignore_index),
+ stream_handle=stream_handle,
+ dtype_logits=dtype_x,
+ aux_a_dtype=cutlass.Float32,
+ aux_b_dtype=cutlass.Float32,
+ aux_c_dtype=None,
+ )
+ return
+
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_loss = rt.make_ptr(
+ cutlass.Float32,
+ loss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_loss,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def _cross_entropy_backward_ptr_into(
+ *,
+ logits: Tensor,
+ target: Tensor,
+ dloss: Tensor,
+ lse: Tensor,
+ dx: Tensor,
+ ignore_index: int,
+) -> None:
+ assert logits.is_cuda and logits.dim() == 2
+ assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
+ assert target.dtype is torch.int64
+ assert (
+ dloss.is_cuda
+ and dloss.shape == (logits.shape[0],)
+ and dloss.dtype is torch.float32
+ )
+ assert (
+ lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ )
+ assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype
+ assert dx.stride() == logits.stride(), (
+ "Pointer path expects dx to match logits strides"
+ )
+
+ M, N = logits.shape
+ device_index = logits.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[logits.dtype]
+ key = ("ptr_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_BWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = CrossEntropyBackwardSM100(dtype_x, int(N))
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ _PTR_BWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_cross_entropy_launcher(
+ compiled=compiled,
+ dtype_logits=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ mode="bwd",
+ )
+ if launcher is not None:
+ launcher.launch(
+ logits_ptr=int(logits.data_ptr()),
+ target_ptr=int(target.data_ptr()),
+ aux_a_ptr=int(dloss.data_ptr()),
+ aux_b_ptr=int(dx.data_ptr()),
+ aux_c_ptr=int(lse.data_ptr()),
+ M=int(M),
+ ld=int(logits.stride(0)),
+ ignore_index=int(ignore_index),
+ stream_handle=stream_handle,
+ dtype_logits=dtype_x,
+ aux_a_dtype=cutlass.Float32,
+ aux_b_dtype=dtype_x,
+ aux_c_dtype=cutlass.Float32,
+ )
+ return
+
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def _cross_entropy_fwd_bwd_ptr_into(
+ *,
+ logits: Tensor,
+ target: Tensor,
+ dloss: Tensor,
+ dx: Tensor,
+ ignore_index: int,
+) -> None:
+ """Launch the fused pointer-based cross-entropy fwd+bwd kernel into preallocated `dx`."""
+ assert logits.is_cuda and logits.dim() == 2
+ assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
+ assert target.dtype is torch.int64
+ assert (
+ dloss.is_cuda
+ and dloss.shape == (logits.shape[0],)
+ and dloss.dtype is torch.float32
+ )
+ assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype
+ assert dx.stride() == logits.stride(), (
+ "Pointer path expects dx to match logits strides"
+ )
+
+ M, N = logits.shape
+ device_index = logits.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[logits.dtype]
+ key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = CrossEntropyFwdBwdSM100(dtype_x, int(N))
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ _PTR_FWDBWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_cross_entropy_launcher(
+ compiled=compiled,
+ dtype_logits=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ mode="fwd_bwd",
+ )
+ if launcher is not None:
+ launcher.launch(
+ logits_ptr=int(logits.data_ptr()),
+ target_ptr=int(target.data_ptr()),
+ aux_a_ptr=int(dloss.data_ptr()),
+ aux_b_ptr=int(dx.data_ptr()),
+ aux_c_ptr=None,
+ M=int(M),
+ ld=int(logits.stride(0)),
+ ignore_index=int(ignore_index),
+ stream_handle=stream_handle,
+ dtype_logits=dtype_x,
+ aux_a_dtype=cutlass.Float32,
+ aux_b_dtype=dtype_x,
+ aux_c_dtype=None,
+ )
+ return
+
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def cross_entropy_backward(
+ dloss: Tensor,
+ logits: Tensor,
+ target: Tensor,
+ lse: Tensor,
+ ignore_index: int = -100,
+) -> Tensor:
+ """SM100 CuteDSL cross-entropy backward pass.
+
+ Args:
+ dloss: Upstream gradient of shape ``(M,)`` corresponding to
+ ``reduction="none"``.
+ logits: Input logits tensor of shape ``(M, N)``.
+ target: Integer class indices of shape ``(M,)``.
+ lse: Per-example log-sum-exp tensor of shape ``(M,)`` as returned
+ by :func:`cross_entropy_forward`.
+ ignore_index: Target value to ignore in gradient computation.
+
+ Returns:
+ ``dlogits`` of shape ``(M, N)`` with the same dtype as ``logits``.
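+
+    Example (illustrative; continues the :func:`cross_entropy_forward` example)::
+
+        loss, lse = cross_entropy_forward(logits, target, reduction="none")
+        dloss = torch.ones_like(loss)
+        dlogits = cross_entropy_backward(dloss, logits, target, lse)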
+ """
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert dloss.dim() == 1, "dloss must be 1D (M,)"
+ assert logits.size(0) == dloss.size(0), "Batch dimensions must match"
+ assert logits.is_cuda and dloss.is_cuda, "logits and dloss must be on CUDA device"
+
+ dx = torch.empty_like(logits)
+ _cross_entropy_backward_sm100(
+ logits,
+ target,
+ dloss,
+ lse,
+ dx,
+ ignore_index=ignore_index,
+ )
+ return dx
+
+
+def cross_entropy_fwd_bwd(
+ dloss: Tensor,
+ logits: Tensor,
+ target: Tensor,
+ ignore_index: int = -100,
+) -> Tensor:
+ """Fused cross-entropy forward+backward producing ``dx`` for ``reduction='none'``.
+
+ Computes per-logit gradients ``dx`` given:
+ - ``logits``: (M, N)
+ - ``target``: (M,)
+ - ``dloss``: (M,) upstream gradients (float32 recommended)
+
+ The fast path avoids materializing intermediate ``lse`` in global memory.
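+
+    Example (illustrative; equivalent to a ``reduction="none"`` forward pass
+    followed by :func:`cross_entropy_backward`, but without writing ``lse``
+    to global memory on the fast path)::
+
+        import torch
+
+        logits = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+        target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)
+        dloss = torch.ones(8, device="cuda", dtype=torch.float32)
+        dx = cross_entropy_fwd_bwd(dloss, logits, target)  # same shape/dtype as logits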
+ """
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert target.dim() == 1, "target must be 1D (M,)"
+ assert dloss.dim() == 1, "dloss must be 1D (M,)"
+ assert logits.shape[0] == target.shape[0] == dloss.shape[0], (
+ "Batch dimensions must match"
+ )
+ assert logits.is_cuda and target.is_cuda and dloss.is_cuda, (
+ "All tensors must be on CUDA device"
+ )
+ assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype"
+
+ dx = torch.empty_like(logits)
+
+ if (
+ _can_use_ptr_path_logits(logits)
+ and _can_use_ptr_path_logits(dx)
+ and _can_use_ptr_path_target(target)
+ and _can_use_ptr_path_f32_1d(dloss)
+ and logits.stride() == dx.stride()
+ ):
+ _cross_entropy_fwd_bwd_ptr_into(
+ logits=logits,
+ target=target,
+ dloss=dloss,
+ dx=dx,
+ ignore_index=int(ignore_index),
+ )
+ return dx
+
+ # Fallback: reuse the existing forward+backward kernels (DLPack path handles
+ # any necessary dtype conversions).
+ with torch.no_grad():
+ _loss, lse = cross_entropy_forward(
+ logits,
+ target,
+ ignore_index=int(ignore_index),
+ reduction="none",
+ )
+ _cross_entropy_backward_sm100(
+ logits,
+ target,
+ dloss,
+ lse,
+ dx,
+ ignore_index=int(ignore_index),
+ )
+ return dx
+
+
+def cross_entropy(
+ logits: Tensor,
+ target: Tensor,
+ ignore_index: int = -100,
+ reduction: Literal["none", "mean", "sum"] = "mean",
+) -> Tensor:
+ """Convenience wrapper mirroring ``torch.nn.functional.cross_entropy`` reductions.
+
+ This uses :func:`cross_entropy_forward` under the hood but returns only
+ the reduced loss tensor.
+ """
+ loss, _lse = cross_entropy_forward(
+ logits,
+ target,
+ ignore_index=ignore_index,
+ reduction="none",
+ )
+ if reduction == "none":
+ return loss
+ mask = target != ignore_index
+ if reduction == "sum":
+ return loss.sum()
+ if reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ return loss[mask].sum() / valid.to(loss.dtype)
+ return loss.sum() * 0.0
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'."
+ )
+
+
+def verify_cross_entropy_parity(
+ M: int,
+ N: int,
+ dtype: torch.dtype = torch.bfloat16,
+ ignore_index: int = -100,
+) -> None:
+ """Compare SM100 CuteDSL cross-entropy against PyTorch for a single shape."""
+ device = torch.device("cuda")
+ torch.manual_seed(0)
+
+ logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype)
+ logits.requires_grad_(True)
+ target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
+
+    # When a non-default ignore_index is requested, mark roughly 10% of the
+    # rows as ignored so that the ignore path is exercised.
+ if ignore_index != -100:
+ mask = torch.rand(M, device=device) < 0.1
+ target[mask] = ignore_index
+
+ loss, lse = cross_entropy_forward(
+ logits, target, ignore_index=ignore_index, reduction="none"
+ )
+
+ logits_ref = logits.detach().clone().requires_grad_()
+ target_ref = target.detach().clone()
+ loss_ref = torch.nn.functional.cross_entropy(
+ logits_ref.float(),
+ target_ref,
+ ignore_index=ignore_index,
+ reduction="none",
+ )
+
+ # Forward parity
+ if dtype in (torch.float16, torch.bfloat16):
+ atol = 5e-2
+ rtol = 5e-2
+ else:
+ atol = 1e-4
+ rtol = 1e-4
+ torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol)
+
+ # Backward parity
+ dloss = torch.randn_like(loss_ref)
+ (dx_ref,) = torch.autograd.grad(loss_ref, logits_ref, grad_outputs=dloss)
+ dx = cross_entropy_backward(dloss, logits, target, lse, ignore_index=ignore_index)
+ torch.testing.assert_close(dx, dx_ref.to(logits.dtype), atol=atol, rtol=rtol)
+
+
+if __name__ == "__main__":
+ # Minimal functional check when executed directly. For performance
+ # comparisons and detailed tuning, use the dedicated benchmark harness.
+ if not torch.cuda.is_available():
+ print("CUDA not available; cross-entropy parity check skipped.")
+ raise SystemExit(0)
+
+ M, N = 1024, 8192
+ dtype = torch.bfloat16
+ verify_cross_entropy_parity(M, N, dtype=dtype, ignore_index=-100)
+ print("SM100 cross-entropy CuteDSL parity check passed.")
diff --git a/oink/src/kernelagent_oink/blackwell/fast_launch.py b/oink/src/kernelagent_oink/blackwell/fast_launch.py
new file mode 100644
index 0000000..9b288f2
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/fast_launch.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Host-side fast-launch helpers for CuTeDSL pointer entrypoints.
+
+CuTeDSL's Python runtime typically marshals each kernel call by allocating
+`Int32` / `Float32` wrappers and runtime `Pointer` descriptors per invocation.
+For latency-sensitive cases (small/medium M), this overhead can dominate.
+
+These helpers provide:
+- Stable scalar argument wrappers (`StableI32Arg`, `StableF32Arg`) that avoid
+ per-call ctypes allocations.
+- In-place mutation of runtime pointer descriptors (`set_runtime_ptr`) so a
+ compiled kernel can be launched repeatedly with different raw device pointers
+ without rebuilding argument objects.
+- A small thread-local cache to store packed args objects (when supported by the
+ installed CuTeDSL version).
+
+All of this relies on a few private-ish CuTeDSL internals. Callers must treat
+fast-launch as an optional optimization and fall back to the normal launch
+path if those internals are unavailable.
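+
+Rough usage sketch (``ptr_x``, ``x`` and ``M`` are placeholders; the real call
+sites are the layernorm / rmsnorm wrappers in this package)::
+
+    arg_m = StableI32Arg(0)                   # allocated once, reused across launches
+    arg_m.set(M)                              # cheap per-call scalar update
+    try:
+        set_runtime_ptr(ptr_x, x.data_ptr())  # retarget a runtime Pointer in place
+    except AttributeError:
+        disable_fast_launch()                 # internals changed; use the normal path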
+"""
+
+from __future__ import annotations
+
+import ctypes
+import os
+import threading
+from typing import Any
+
+_FAST_LAUNCH_TLS = threading.local()
+
+
+def _env_flag(name: str, default: bool) -> bool:
+ val = os.environ.get(name)
+ if val is None:
+ return default
+ return val.strip().lower() not in {"0", "false", "no", "off", ""}
+
+
+# Fast-launch uses internal CuTeDSL plumbing (packed args + pointer descriptors).
+# Keep it enabled by default in our pinned environment, but allow disabling it
+# via env var and auto-disable it if CuTeDSL internals change.
+_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True)
+_FAST_LAUNCH_SUPPORTED = True
+
+
+def fast_launch_enabled() -> bool:
+ return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED
+
+
+def disable_fast_launch() -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+
+
+def tls_cache() -> dict[tuple[Any, ...], Any]:
+ cache = getattr(_FAST_LAUNCH_TLS, "cache", None)
+ if cache is None:
+ cache = {}
+ _FAST_LAUNCH_TLS.cache = cache
+ return cache
+
+
+class StableI32Arg:
+ """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: int):
+ self._c_value = ctypes.c_int32(int(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: int) -> None:
+ self._c_value.value = int(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
+
+
+class StableF32Arg:
+ """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: float):
+ self._c_value = ctypes.c_float(float(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: float) -> None:
+ self._c_value.value = float(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
+
+
+def set_runtime_ptr(ptr: Any, device_ptr: int) -> None:
+ """Update a CuTeDSL runtime Pointer descriptor in-place.
+
+ This relies on internal runtime pointer fields (`_desc`, `_pointer`, etc.).
+ If these internals change in a future CuTeDSL upgrade, this function may
+ raise AttributeError; callers should catch it and fall back.
+ """
+ device_ptr = int(device_ptr)
+ ptr._pointer = device_ptr # type: ignore[attr-defined]
+ if getattr(ptr, "_c_pointer", None) is None:
+ ptr.__c_pointers__() # type: ignore[attr-defined]
+ ptr._desc.value = device_ptr # type: ignore[attr-defined]
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
new file mode 100644
index 0000000..ada51ec
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -0,0 +1,1981 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+LayerNorm kernel for SM100 (Blackwell) in CuteDSL.
+
+This implementation:
+- Mirrors Quack's LayerNorm tiling / cluster policy / cp.async pipeline
+ but uses only local helpers so that it does not depend on the external
+ `quack` package at runtime.
+- Supports fp16 / bf16 / fp32 inputs with fp32 accumulation.
+- Optionally writes out per-row `rstd` and `mean` buffers for reuse in
+ backward or fused kernels.
+
+Backward is implemented with dedicated CuteDSL kernels for input and
+parameter gradients (dx, dweight, dbias), avoiding PyTorch autograd
+while matching `torch.nn.functional.layer_norm`'s gradients numerically.
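+
+Typical forward usage (illustrative sketch; assumes an SM100 GPU and that with
+the default ``return_rstd`` / ``return_mean`` flags only the normalized output
+is returned)::
+
+    import torch
+
+    x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+    w = torch.randn(4096, device="cuda", dtype=torch.float32)
+    b = torch.randn(4096, device="cuda", dtype=torch.float32)
+    out = layernorm(x, w, bias=b, eps=1e-6)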
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import math
+import os
+import re
+import operator
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ
+# across `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes
+# noisy warnings (and disables cache reuse), so default to a cache directory
+# keyed by the installed DSL version unless the user already set one.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.layernorm requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase as _ReductionBase,
+ convert_from_dlpack as convert_from_dlpack_cute,
+ get_sm_count,
+ predicate_k,
+ row_reduce,
+ warp_reduce,
+)
+from kernelagent_oink.blackwell.fast_launch import (
+ StableF32Arg,
+ StableI32Arg,
+ disable_fast_launch,
+ fast_launch_enabled,
+ set_runtime_ptr,
+ tls_cache as _tls_fast_launch_cache,
+)
+
+# Simple compile cache for the forward kernel
+_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {}
+_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {}
+
+# Backward compile caches: one for dx, one for parameter gradients.
+_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {}
+_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {}
+
+
+class _PtrLayernormFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: object,
+ ptr_b: Optional[object],
+ ptr_out: object,
+ ptr_rstd: Optional[object],
+ ptr_mean: Optional[object],
+ arg_m: StableI32Arg,
+ arg_ld: StableI32Arg,
+ arg_eps: StableF32Arg,
+ stream: cuda.CUstream,
+ assumed_align_xo: int,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_b = ptr_b
+ self._ptr_out = ptr_out
+ self._ptr_rstd = ptr_rstd
+ self._ptr_mean = ptr_mean
+ self._arg_m = arg_m
+ self._arg_ld = arg_ld
+ self._arg_eps = arg_eps
+ self._stream = stream
+ self._assumed_align_xo = int(assumed_align_xo)
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_b_ptr = -1
+ self._last_out_ptr = -1
+ self._last_rstd_ptr = -1
+ self._last_mean_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+ self._last_eps = float("nan")
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ out: Tensor,
+ rstd: Optional[Tensor],
+ mean: Optional[Tensor],
+ M: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ if not fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ w_ptr = weight.data_ptr()
+ if w_ptr != self._last_w_ptr:
+ try:
+ set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ if self._ptr_b is not None and bias is not None:
+ b_ptr = bias.data_ptr()
+ if b_ptr != self._last_b_ptr:
+ try:
+ set_runtime_ptr(self._ptr_b, b_ptr)
+ self._last_b_ptr = b_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ out_ptr = out.data_ptr()
+ if out_ptr != self._last_out_ptr:
+ try:
+ set_runtime_ptr(self._ptr_out, out_ptr)
+ self._last_out_ptr = out_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ if self._ptr_rstd is not None and rstd is not None:
+ rstd_ptr = rstd.data_ptr()
+ if rstd_ptr != self._last_rstd_ptr:
+ try:
+ set_runtime_ptr(self._ptr_rstd, rstd_ptr)
+ self._last_rstd_ptr = rstd_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ if self._ptr_mean is not None and mean is not None:
+ mean_ptr = mean.data_ptr()
+ if mean_ptr != self._last_mean_ptr:
+ try:
+ set_runtime_ptr(self._ptr_mean, mean_ptr)
+ self._last_mean_ptr = mean_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=M,
+ ld=ld,
+ eps=eps,
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+ if eps != self._last_eps:
+ self._arg_eps.set(eps)
+ self._last_eps = eps
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ self._use_fast_launch = False
+ disable_fast_launch()
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ out: Tensor,
+ rstd: Optional[Tensor],
+ mean: Optional[Tensor],
+ M: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+ ptr_x = rt.make_ptr(
+ dtype_x,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_xo,
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_xo,
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(
+ cutlass.Float32,
+ mean.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if mean is not None
+ else None
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ Int32(int(M)),
+ Int32(int(ld)),
+ stream,
+ Float32(float(eps)),
+ )
+
+
+def _get_fast_ptr_layernorm_launcher(
+ *,
+ compiled: object,
+ N: int,
+ dtype_x: type[cutlass.Numeric],
+ has_bias: bool,
+ has_rstd: bool,
+ has_mean: bool,
+ device_index: int,
+ stream_handle: int,
+ assumed_align_xo: int,
+ eps: float,
+) -> Optional[_PtrLayernormFastLaunch]:
+ if not fast_launch_enabled():
+ return None
+ key = (
+ "ptr_fast",
+ id(compiled),
+ int(N),
+ dtype_x,
+ bool(has_bias),
+ bool(has_rstd),
+ bool(has_mean),
+ int(device_index),
+ int(stream_handle),
+ int(assumed_align_xo),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ ptr_x = rt.make_ptr(
+ dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo)
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo)
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ if has_bias
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4)
+ if has_rstd
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4)
+ if has_mean
+ else None
+ )
+
+ arg_m = StableI32Arg(0)
+ arg_ld = StableI32Arg(N)
+ arg_eps = StableF32Arg(eps)
+ stream = cuda.CUstream(int(stream_handle))
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ arg_m,
+ arg_ld,
+ stream,
+ arg_eps,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ disable_fast_launch()
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ arg_m,
+ arg_ld,
+ arg_eps,
+ stream,
+ *adapted_args,
+ )
+ launcher = _PtrLayernormFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_b=ptr_b,
+ ptr_out=ptr_out,
+ ptr_rstd=ptr_rstd,
+ ptr_mean=ptr_mean,
+ arg_m=arg_m,
+ arg_ld=arg_ld,
+ arg_eps=arg_eps,
+ stream=stream,
+ assumed_align_xo=int(assumed_align_xo),
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+def _convert_row_major(t: Tensor) -> cute.Tensor:
+ """
+ Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact,
+ dynamic layout on the leading dimension.
+ """
+ return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0,
+ stride_order=(0, 1),
+ )
+
+
+class LayerNormSM100(_ReductionBase):
+ """
+ SM100 LayerNorm forward kernel.
+
+ This mirrors `quack.layernorm.LayerNorm`'s schedule:
+ - Stage=2 pipeline: first pass computes mean, second pass computes
+ variance / rstd and normalization.
+ - Threads-per-row and cluster_n policy follow Quack's LayerNorm
+ heuristics to keep tensor-core friendly tiles across N.
+ - Optional `reload_from` hint enables reloading X from SMEM for large-N
+ shapes to shorten register lifetimes.
+
+ Differences vs Quack:
+ - Bias is optional and supported directly in the kernel.
+ - Dtype mapping and reduction helpers come from `lite_quack`.
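+
+    For reference, the per-row math (with fp32 accumulation) is the standard
+    LayerNorm, matching ``torch.nn.functional.layer_norm``::
+
+        mean = x.mean(dim=-1, keepdim=True)
+        rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
+        y = (x - mean) * rstd * w + b  # bias term only when provided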
+ """
+
+ def __init__(
+ self,
+ dtype: type[cutlass.Numeric],
+ N: int,
+ *,
+ copy_bits_x: Optional[int] = None,
+ direct_gmem: bool = False,
+ ):
+ super().__init__(dtype, N, stage=2) # 2 stages for mean and var
+ # Default reload policy mirrors Quack: use SMEM reload only for
+ # very large hidden sizes. We keep this conservative for LayerNorm
+ # and tune primarily via threads-per-block / cluster_n.
+ self.reload_from: Optional[str] = None if N <= 16384 else "smem"
+ # SM100 tuning: for DSv3 hidden sizes where we fuse mean+var stats,
+ # delay loading fp32 weights/bias until after the reductions to lower
+ # register pressure.
+ self.delay_w_load: bool = bool(N in (4096, 6144, 7168, 8192))
+ self.copy_bits_x: Optional[int] = (
+ int(copy_bits_x) if copy_bits_x is not None else None
+ )
+ self.direct_gmem: bool = bool(direct_gmem)
+
+ def _get_num_threads(self) -> int:
+ nt = getattr(self, "_nt_override", None)
+ if nt is not None:
+ return int(nt)
+ return super()._get_num_threads()
+
+ def _calculate_threads_per_row(self) -> int:
+ tpr = getattr(self, "_tpr_override", None)
+ if tpr is not None:
+ return int(tpr)
+ # Match Quack's LayerNorm threads-per-row buckets.
+ N = self.N
+ if N in (4096, 6144):
+ return 128
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ # Tiling and cluster policy (mirrors Quack LayerNorm).
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ t.element_type.width
+ for t in (mX, mW, mB, mO, mRstd, mMean)
+ if t is not None
+ )
+ )
+ # Match Quack's unified RMSNorm/LayerNorm kernel: pick vecsize based on
+ # the widest dtype participating in the op (e.g. fp32 weights => fp16
+ # X uses 64b vectorization).
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+ default_copy_bits_x = vecsize * self.dtype.width
+ num_copy_bits_x = (
+ int(self.copy_bits_x)
+ if self.copy_bits_x is not None
+ else default_copy_bits_x
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+
+ # Expand weight / bias to match tiler_mn[0] rows per CTA.
+ mW = cute.make_tensor(
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+ if const_expr(mMean is not None):
+ mMean = cute.make_tensor(
+ mMean.iterator,
+ cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(self.reload_from),
+ const_expr(self.delay_w_load),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[
+ 1,
+ self.cluster_n,
+ 1,
+ ]
+ if const_expr(self.cluster_n > 1)
+ else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: cute.Pointer,
+ ptr_b: Optional[cute.Pointer],
+ ptr_out: cute.Pointer,
+ ptr_rstd: Optional[cute.Pointer],
+ ptr_mean: Optional[cute.Pointer],
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions.
+
+ This reconstructs cute.Tensor views from raw device pointers + explicit
+ layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule.
+ """
+ # Mirror Quack-style divisibility contracts so the compiler can prove
+ # alignment for vectorized loads/stores (and cp.async when enabled).
+ divby = (
+ int(self.copy_bits_x) // self.dtype.width
+ if const_expr(self.copy_bits_x is not None)
+ else (128 // self.dtype.width)
+ )
+ ld_assumed = cute.assume(ld, divby=divby)
+ # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static.
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_n = cute.make_layout((self.N,), stride=(1,))
+ layout_m = cute.make_layout((M,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+ mW = cute.make_tensor(ptr_w, layout_n)
+ mB = (
+ cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
+ )
+ mRstd = (
+ cute.make_tensor(ptr_rstd, layout_m)
+ if const_expr(ptr_rstd is not None)
+ else None
+ )
+ mMean = (
+ cute.make_tensor(ptr_mean, layout_m)
+ if const_expr(ptr_mean is not None)
+ else None
+ )
+
+ self.__call__(mX, mW, mB, mO, mRstd, mMean, stream, eps)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr,
+ delay_w_load: cutlass.Constexpr,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Quack-style CTA tiling: let CuTe compute the CTA offsets directly.
+ # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on
+ # the common inference/benchmark sizes.)
+ gX, gO = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+ gB = (
+ cute.local_tile(mB, tiler_mn, (0, cluster_y))
+ if const_expr(mB is not None)
+ else None
+ )
+ gRstd = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if const_expr(mRstd is not None)
+ else None
+ )
+ gMean = (
+ cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+ if const_expr(mMean is not None)
+ else None
+ )
+
+ # Copy atoms for X / W / B / O (mirror Quack's vector-size contract).
+ num_copy_elems_x = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ num_copy_bits_x = mX.element_type.width * num_copy_elems_x
+ num_copy_bits_x_async = const_expr(min(128, num_copy_bits_x))
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mX.element_type,
+ num_bits_per_copy=num_copy_bits_x,
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mX.element_type,
+ num_bits_per_copy=num_copy_bits_x_async,
+ )
+ num_copy_bits_wb = const_expr(
+ min(128, mW.element_type.width * num_copy_elems_x)
+ )
+ copy_atom_load_WB = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mW.element_type,
+ num_bits_per_copy=num_copy_bits_wb,
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mO.element_type,
+ num_bits_per_copy=num_copy_bits_x,
+ )
+
+ # Quack-style partitioning: use `make_tiled_copy_tv` (2D thread/value
+ # layout) and let partitioning over the CTA tile handle the N loop.
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems_x))
+ thr_copy = cute.make_tiled_copy_tv(
+ copy_atom_load_X, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tXgX = thr_copy.partition_S(gX)
+ tXsX = thr_copy.partition_D(sX)
+ tXgO = thr_copy.partition_D(gO)
+ tXgW = thr_copy.partition_S(gW)
+ tXgB = thr_copy.partition_S(gB) if const_expr(gB is not None) else None
+ tXrRstd = thr_copy.partition_D(gRstd) if const_expr(mRstd is not None) else None
+ tXrMean = thr_copy.partition_D(gMean) if const_expr(mMean is not None) else None
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
+
+ # Fragments for gmem->rmem.
+ tXrW = cute.make_fragment_like(tXgW)
+ tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=False)
+
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ None if is_even_N else predicate_k(thr_copy.partition_S(cX), limit=shape[1])
+ )
+ row = tXcX[0][0]
+ if const_expr(not self.direct_gmem):
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+
+ if const_expr(not delay_w_load):
+ cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX)
+ if const_expr(mB is not None):
+ cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX)
+
+ if const_expr(not self.direct_gmem):
+ cute.arch.cp_async_wait_group(0)
+ cute.autovec_copy(tXsX, tXrX)
+ else:
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(Float32)
+ if const_expr(self.cluster_n == 1 and self.N in (4096, 6144, 7168, 8192)):
+ # SM100 tuning for DSv3 hidden sizes:
+ # Compute (sum_x, sum_x2) together so we can derive mean + variance
+ # without a second reduction pass (and without re-materializing
+ # x-mean for the variance reduction).
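+            # (Identity: var = E[x^2] - E[x]^2; the fmax(var, 0) below guards
+            # against small negative values from floating-point cancellation.)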
+ sum_x = x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0)
+ sum_x2 = (x * x).reduce(
+ cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0
+ )
+ sum_x = warp_reduce(
+ sum_x,
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ sum_x2 = warp_reduce(
+ sum_x2,
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
+ if const_expr(warps_per_row > 1 or cluster_n > 1):
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if lane_idx == 0:
+ reduction_buffer[row_idx, col_idx, 0] = sum_x
+ reduction_buffer[row_idx, col_idx, 1] = sum_x2
+ cute.arch.barrier()
+ block_sum_x = 0.0
+ block_sum_x2 = 0.0
+ if lane_idx < warps_per_row:
+ block_sum_x = reduction_buffer[row_idx, lane_idx, 0]
+ block_sum_x2 = reduction_buffer[row_idx, lane_idx, 1]
+ sum_x = warp_reduce(block_sum_x, operator.add)
+ sum_x2 = warp_reduce(block_sum_x2, operator.add)
+ mean = sum_x / shape[1]
+ var = sum_x2 / shape[1] - mean * mean
+ var = cute.arch.fmax(var, 0.0)
+ rstd = cute.math.rsqrt(var + eps, fastmath=True)
+ else:
+ sum_x = row_reduce(
+ x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=(
+ cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None
+ ),
+ )
+ mean = sum_x / shape[1]
+
+ if const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ elif const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(Float32)
+
+ sum_sq_x_sub_mean = row_reduce(
+ (x - mean) * (x - mean),
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 1],
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrRstd[0] = rstd
+
+ if const_expr(mMean is not None):
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrMean[0] = mean
+
+ if const_expr(delay_w_load):
+ cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX)
+ if const_expr(mB is not None):
+ cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX)
+
+ if const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ elif const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(Float32)
+
+ x_hat = (x - mean) * rstd
+ w = tXrW.load().to(Float32)
+ y = x_hat * w
+ if const_expr(mB is not None):
+ b = tXrB.load().to(Float32)
+ y = y + b
+
+ tXrO.store(y.to(tXrO.element_type))
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr,
+ delay_w_load: cutlass.Constexpr,
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ reload_from,
+ delay_w_load,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mW.element_type.width,
+ mB.element_type.width if const_expr(mB is not None) else 0,
+ mO.element_type.width,
+ mRstd.element_type.width if const_expr(mRstd is not None) else 0,
+ mMean.element_type.width if const_expr(mMean is not None) else 0,
+ )
+ )
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+ default_copy_bits_x = vecsize * mX.element_type.width
+ num_copy_bits_x = (
+ int(self.copy_bits_x)
+ if const_expr(self.copy_bits_x is not None)
+ else default_copy_bits_x
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x)
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(self.reload_from),
+ const_expr(self.delay_w_load),
+ )
+
+
+# -----------------------------------------------------------------------------
+# Public Python API
+# -----------------------------------------------------------------------------
+
+
+def layernorm(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+):
+ """
+ LayerNorm forward pass using the SM100 CuteDSL kernel.
+
+ Args:
+ x: Input tensor of shape (M, N).
+ weight: Scale parameter of shape (N,), typically fp32.
+ bias: Optional bias parameter of shape (N,).
+ eps: Small value for numerical stability.
+ return_rstd: Whether to return per-row reciprocal std (shape (M,)).
+ return_mean: Whether to return per-row mean (shape (M,)).
+ """
+ assert x.is_cuda and weight.is_cuda, "x and weight must be CUDA tensors"
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ assert weight.dim() == 1, "weight must be 1D"
+ assert x.shape[1] == weight.shape[0], "Last dim of x must match weight.size(0)"
+ if bias is not None:
+ assert bias.is_cuda, "bias must be on CUDA"
+ assert bias.dim() == 1 and bias.shape[0] == weight.shape[0], (
+ "bias must be 1D and match weight"
+ )
+
+ M, N = x.shape
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=x.device, dtype=torch.float32) if return_mean else None
+
+ # Fast path: bypass DLPack conversions when the inputs are in the common
+ # contiguous row-major layout and weights/bias are fp32 (Quack-style).
+ if _can_use_ptr_path(x, weight, bias):
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ _layernorm_forward_ptr_into(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ eps=eps,
+ )
+ if return_mean and return_rstd:
+ return out, rstd, mean
+ if return_rstd and not return_mean:
+ return out, rstd
+ if return_mean and not return_rstd:
+ return out, mean
+ return out
+
+ out = torch.empty_like(x)
+ mX = _convert_row_major(x)
+ mO = _convert_row_major(out)
+
+ # Weight/bias live in feature dimension (N).
+ mW = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ mB = (
+ convert_from_dlpack_cute(
+ bias.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ if bias is not None
+ else None
+ )
+
+ mRstd = (
+ from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if rstd is not None
+ else None
+ )
+ mMean = (
+ from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if mean is not None
+ else None
+ )
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (N, dtype, mB is not None, mRstd is not None, mMean is not None)
+ compiled = _COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = LayerNormSM100(dtype, N)
+ compiled = cute.compile(
+ op,
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ stream,
+ Float32(eps),
+ )
+ _COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ stream,
+ Float32(eps),
+ )
+
+ if return_mean and return_rstd:
+ return out, rstd, mean
+ if return_rstd and not return_mean:
+ return out, rstd
+ if return_mean and not return_rstd:
+ return out, mean
+ return out
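+
+
+# Usage sketch (illustrative shapes only; assumes an SM100 GPU and the imports above):
+#
+#   x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)
+#   w = torch.randn(7168, device="cuda", dtype=torch.float32)
+#   b = torch.randn(7168, device="cuda", dtype=torch.float32)
+#   y, rstd, mean = layernorm(x, w, b, eps=1e-6, return_rstd=True, return_mean=True)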
+
+
+def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool:
+ """Return True if we can safely use the pointer-based fast path.
+
+ This is intentionally conservative: we target the common inference-like
+ layout (2D row-major with stride(1)==1) and Quack-style fp32 weights.
+ """
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.stride(1) != 1:
+ return False
+ if not weight.is_cuda or weight.dim() != 1:
+ return False
+ if weight.dtype != torch.float32:
+ return False
+ if not weight.is_contiguous():
+ return False
+ if bias is not None:
+ if not bias.is_cuda or bias.dim() != 1:
+ return False
+ if bias.dtype != torch.float32:
+ return False
+ if not bias.is_contiguous():
+ return False
+ # Require 16B alignment for 128-bit vector copies (matches Quack's assumed_align=16).
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if (weight.data_ptr() % 16) != 0:
+ return False
+ if bias is not None and (bias.data_ptr() % 16) != 0:
+ return False
+ # The kernel uses 128-bit vectorized loads; require the leading dimension
+ # to preserve 16B alignment for every row start.
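+    # (e.g. bf16: divby = 128 // 16 = 8, so x.stride(0) must be a multiple of 8 elements).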
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _layernorm_forward_ptr_into(
+ *,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ out: Tensor,
+ rstd: Optional[Tensor],
+ mean: Optional[Tensor],
+ eps: float,
+) -> None:
+ """Launch the pointer-based LayerNorm kernel into preallocated outputs."""
+ assert x.is_cuda and x.dim() == 2
+ M, N = x.shape
+ assert weight.is_cuda and weight.dim() == 1 and weight.shape[0] == N
+ if bias is not None:
+ assert bias.is_cuda and bias.dim() == 1 and bias.shape[0] == N
+ assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype
+ assert out.stride() == x.stride(), "Pointer path expects out to match x strides"
+ if rstd is not None:
+ assert rstd.is_cuda and rstd.shape == (M,) and rstd.dtype == torch.float32
+ if mean is not None:
+ assert mean.is_cuda and mean.shape == (M,) and mean.dtype == torch.float32
+
+ device_index = x.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ # Keep the pointer path aligned with Quack's LayerNorm schedule:
+ # - <=128b vectorization (cp.async-compatible)
+ # - shared-memory staging for X (gmem->smem->rmem) to amortize global latency
+ direct_gmem = False
+ copy_bits_x: Optional[int] = None
+ assumed_align_xo = 16
+
+ # DSv3 hidden sizes are often latency-bound on small M. For these N buckets,
+ # a direct-GMEM schedule (skip gmem->smem cp.async) can reduce overhead.
+ #
+ # Keep the Quack-like staged path for large M where cp.async overlap tends to win.
+ if dtype_x.width == 16:
+        # DSv3's default hidden size (7168) is a common inference hot shape and
+        # benefits from the lower-overhead direct-GMEM path on this SM100 machine.
+ if N == 7168 and M <= 65536:
+ direct_gmem = True
+ elif N == 8192 and M <= 16384:
+ direct_gmem = True
+
+ # DSv3 smallest point (M=4096, N=7168) is latency-sensitive. Increasing
+ # per-row parallelism improves the reduction path and consistently beats
+ # Quack on this machine.
+ tpr_override: Optional[int] = None
+ nt_override: Optional[int] = None
+ if dtype_x.width == 16 and N == 7168 and M <= 4096:
+ tpr_override = 224
+ nt_override = 224
+
+ # NOTE: We previously experimented with a direct-GMEM + 256b vectorized
+ # schedule for N=4096, but it was consistently slower on this GB200.
+ # Keep the pointer path on the Quack-like staged (cp.async) schedule.
+ key = (
+ "ptr",
+ int(N),
+ dtype_x,
+ bias is not None,
+ rstd is not None,
+ mean is not None,
+ bool(direct_gmem),
+ int(copy_bits_x) if copy_bits_x is not None else None,
+ tpr_override,
+ nt_override,
+ int(assumed_align_xo),
+ int(device_index),
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = LayerNormSM100(
+ dtype_x,
+ int(N),
+ copy_bits_x=copy_bits_x,
+ direct_gmem=direct_gmem,
+ )
+ if tpr_override is not None:
+ op._tpr_override = tpr_override # type: ignore[attr-defined]
+ if nt_override is not None:
+ op._nt_override = nt_override # type: ignore[attr-defined]
+ ptr_x = rt.make_ptr(
+ dtype_x,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_xo,
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_xo,
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(
+ cutlass.Float32,
+ mean.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if mean is not None
+ else None
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ Int32(int(M)),
+ ld,
+ stream,
+ Float32(float(eps)),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_layernorm_launcher(
+ compiled=compiled,
+ N=int(N),
+ dtype_x=dtype_x,
+ has_bias=bias is not None,
+ has_rstd=rstd is not None,
+ has_mean=mean is not None,
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ assumed_align_xo=int(assumed_align_xo),
+ eps=float(eps),
+ )
+ ld_val = int(x.stride(0))
+ if launcher is not None:
+ launcher.launch(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ M=int(M),
+ ld=ld_val,
+ eps=float(eps),
+ )
+ return
+
+ ptr_x = rt.make_ptr(
+ dtype_x,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_xo,
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_xo,
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(
+ cutlass.Float32,
+ mean.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if mean is not None
+ else None
+ )
+ ld = Int32(ld_val)
+ compiled(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ Int32(int(M)),
+ ld,
+ stream,
+ Float32(float(eps)),
+ )
+
+
+def layernorm_ref(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+) -> Tensor:
+ """
+ Reference LayerNorm implemented via torch.nn.functional.layer_norm.
+ """
+ x_f32 = x.float()
+ w = weight.float()
+ b = bias.float() if bias is not None else None
+ y = torch.nn.functional.layer_norm(x_f32, (x.shape[-1],), w, b, eps)
+ return y.to(x.dtype)
+
+
+def _as_2d(x: Tensor) -> Tuple[Tensor, Tuple[int, ...]]:
+ if x.dim() == 2:
+ return x, x.shape
+ original_shape = x.shape
+    M = math.prod(original_shape[:-1])
+ N = original_shape[-1]
+ return x.reshape(M, N), original_shape
+
+
+def _restore_shape(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
+ return x.reshape(shape)
+
+
+@cute.kernel
+def _layernorm_backward_dx_kernel(
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdX: cute.Tensor,
+):
+ """
+ Simple CTA-per-row LayerNorm backward kernel for dx only.
+
+ Each block processes one row of shape (N,), using block_threads threads.
+ It performs two passes over the row:
+ 1) Compute mean_wdy and mean_xhat_wdy in fp32.
+ 2) Compute dx using the standard LayerNorm backward formula:
+ dx = rstd * (wdy - mean_wdy - x_hat * mean_xhat_wdy),
+ where wdy = dy * gamma and x_hat = (x - mean) * rstd.
+ """
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+
+ block_threads = const_expr(256)
+ shape = mX.shape
+ M = shape[0]
+ N = shape[1]
+
+ row = bidx
+ if row < M:
+ # Shared buffers for warp-level reductions across the block.
+ smem = cutlass.utils.SmemAllocator()
+ num_warps = const_expr(block_threads // cute.arch.WARP_SIZE)
+ warp_sums_layout = cute.make_layout((num_warps,), stride=(1,))
+ warp_sums_wdy = smem.allocate_tensor(
+ Float32, warp_sums_layout, byte_alignment=4
+ )
+ warp_sums_xhatwdy = smem.allocate_tensor(
+ Float32, warp_sums_layout, byte_alignment=4
+ )
+
+ lane = cute.arch.lane_idx()
+ warp_idx = cute.arch.warp_idx()
+
+ rstd_val = mRstd[row].to(Float32)
+ mean_val = mMean[row].to(Float32)
+
+ # Pass 1: compute local partial sums of wdy and x_hat*wdy.
+ local_wdy = Float32(0.0)
+ local_xhatwdy = Float32(0.0)
+ for col in cutlass.range(tidx, N, block_threads):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ gamma = mW[col].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ wdy = dy_val * gamma
+ local_wdy += wdy
+ local_xhatwdy += x_hat * wdy
+
+ # Warp-level reduction, then block-level reduction via shared memory.
+ red_op = operator.add # type: ignore[assignment]
+ local_wdy = warp_reduce(local_wdy, red_op)
+ local_xhatwdy = warp_reduce(local_xhatwdy, red_op)
+
+ if lane == 0:
+ warp_sums_wdy[warp_idx] = local_wdy
+ warp_sums_xhatwdy[warp_idx] = local_xhatwdy
+
+ cute.arch.barrier()
+
+ total_wdy = Float32(0.0)
+ total_xhatwdy = Float32(0.0)
+ if warp_idx == 0 and lane == 0:
+ for wi in cutlass.range_constexpr(num_warps):
+ total_wdy += warp_sums_wdy[wi]
+ total_xhatwdy += warp_sums_xhatwdy[wi]
+ # Store totals back into first slots for broadcast.
+ warp_sums_wdy[0] = total_wdy
+ warp_sums_xhatwdy[0] = total_xhatwdy
+
+ cute.arch.barrier()
+
+ total_wdy = warp_sums_wdy[0]
+ total_xhatwdy = warp_sums_xhatwdy[0]
+ inv_N = Float32(1.0 / float(N))
+ mean_wdy = total_wdy * inv_N
+ mean_xhatwdy = total_xhatwdy * inv_N
+
+ # Pass 2: compute dx and write back.
+ for col in cutlass.range(tidx, N, block_threads):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ gamma = mW[col].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ wdy = dy_val * gamma
+ dx_val = (wdy - mean_wdy - x_hat * mean_xhatwdy) * rstd_val
+ mdX[row, col] = dx_val.to(mdX.element_type)
+
+
+@cute.jit
+def _layernorm_backward_dx(
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdX: cute.Tensor,
+ stream: cuda.CUstream,
+) -> None:
+ """
+ JIT wrapper that launches the dx-only LayerNorm backward kernel.
+ One CTA processes one row of length N with 256 threads.
+ """
+ M = mX.shape[0]
+ _layernorm_backward_dx_kernel(
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ ).launch(
+ grid=[M, 1, 1],
+ block=[256, 1, 1],
+ stream=stream,
+ )
+
+
+@cute.kernel
+def _layernorm_backward_param_kernel(
+ mX: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdW_partial: Optional[cute.Tensor],
+ mdB_partial: Optional[cute.Tensor],
+ num_blocks: Int32,
+) -> None:
+ """
+ Parameter-gradient kernel for LayerNorm.
+
+ Each CTA accumulates partial dweight/dbias over a stripe of rows:
+ - Grid dim X: num_blocks (sm_count-style persistent CTAs).
+ - Threads in a CTA partition the N dimension.
+ - For each assigned column, a thread streams over rows
+ row = blockIdx.x, blockIdx.x + num_blocks, ...
+
+ This mirrors the persistent-CTA pattern used by RMSNorm backward,
+ but uses a simpler per-thread accumulation since columns are
+ independent.
+ """
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+
+ block_threads = const_expr(256)
+ M = mX.shape[0]
+ N = mX.shape[1]
+
+ if bidx < num_blocks:
+ for col in cutlass.range(tidx, N, block_threads):
+ dw_local = Float32(0.0)
+ db_local = Float32(0.0)
+ for row in cutlass.range(bidx, M, num_blocks):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ rstd_val = mRstd[row].to(Float32)
+ mean_val = mMean[row].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ dw_local += dy_val * x_hat
+ db_local += dy_val
+
+ if const_expr(mdW_partial is not None):
+ mdW_partial[bidx, col] = dw_local
+ if const_expr(mdB_partial is not None):
+ mdB_partial[bidx, col] = db_local
+
+
+@cute.jit
+def _layernorm_backward_param(
+ mX: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdW_partial: Optional[cute.Tensor],
+ mdB_partial: Optional[cute.Tensor],
+ num_blocks: Int32,
+ stream: cuda.CUstream,
+) -> None:
+ """
+ JIT wrapper that launches the parameter-gradient kernel.
+ """
+ _layernorm_backward_param_kernel(
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ num_blocks,
+ ).launch(
+ grid=[num_blocks, 1, 1],
+ block=[256, 1, 1],
+ stream=stream,
+ )
+
+
+def _layernorm_backward_dx_sm100(
+ dout_2d: Tensor,
+ x_2d: Tensor,
+ weight: Tensor,
+ rstd_1d: Tensor,
+ mean_1d: Tensor,
+ dx_2d: Tensor,
+) -> None:
+ """
+ Host-side helper to run the dx-only LayerNorm backward kernel.
+ """
+ M, N = x_2d.shape
+ assert dout_2d.shape == (M, N)
+ assert rstd_1d.numel() == M
+ assert mean_1d.numel() == M
+
+ dtype = TORCH2CUTE_DTYPE[x_2d.dtype]
+
+ mX = _convert_row_major(x_2d)
+ mdO = _convert_row_major(dout_2d)
+ mdX = _convert_row_major(dx_2d)
+
+ mW = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (N, dtype)
+ compiled = _BWD_DX_COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(
+ _layernorm_backward_dx,
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ stream,
+ )
+ _BWD_DX_COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ stream,
+ )
+
+
+def _layernorm_backward_params_sm100(
+ dout_2d: Tensor,
+ x_2d: Tensor,
+ rstd_1d: Tensor,
+ mean_1d: Tensor,
+ dw_partial: Optional[Tensor],
+ db_partial: Optional[Tensor],
+ sm_count: int,
+) -> None:
+ """
+ Host-side helper to run the parameter-gradient kernel that populates
+ dw_partial / db_partial of shape (sm_count, N).
+ """
+ M, N = x_2d.shape
+ assert dout_2d.shape == (M, N)
+ assert rstd_1d.numel() == M
+ assert mean_1d.numel() == M
+ if dw_partial is None and db_partial is None:
+ return
+
+ dtype = TORCH2CUTE_DTYPE[x_2d.dtype]
+
+ mX = _convert_row_major(x_2d)
+ mdO = _convert_row_major(dout_2d)
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+
+ mdW_partial = (
+ from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if dw_partial is not None
+ else None
+ )
+ mdB_partial = (
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if db_partial is not None
+ else None
+ )
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ has_bias = db_partial is not None
+ key = (N, dtype, has_bias)
+ compiled = _BWD_PARAM_COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(
+ _layernorm_backward_param,
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ Int32(sm_count),
+ stream,
+ )
+ _BWD_PARAM_COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ Int32(sm_count),
+ stream,
+ )
+
+
+def layernorm_backward(
+ dout: Tensor,
+ x: Tensor,
+ weight: Tensor,
+ rstd: Tensor,
+ mean: Tensor,
+ bias: Optional[Tensor] = None,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """
+ LayerNorm backward implemented in CuteDSL / CUTLASS.
+
+ Computes gradients w.r.t. input, weight, and optional bias using
+ two kernels:
+ - A dx kernel (CTA-per-row) that streams over N.
+ - A parameter-gradient kernel that accumulates dw/db over a
+ persistent grid of CTAs across the M dimension.
+ """
+ assert x.shape == dout.shape, "x and dout must have the same shape"
+ assert x.is_cuda and dout.is_cuda, "x and dout must be CUDA tensors"
+ assert weight.dim() == 1, "weight must be 1D"
+ if bias is not None:
+ assert bias.dim() == 1, "bias must be 1D"
+
+ x_2d, orig_shape = _as_2d(x)
+ dout_2d, _ = _as_2d(dout)
+ M, N = x_2d.shape
+
+    # Flatten the per-row stats to shape (M,) for the kernels.
+ mean_flat = mean.view(M)
+ rstd_flat = rstd.view(M)
+
+ dx_2d = torch.empty_like(x_2d)
+ _layernorm_backward_dx_sm100(
+ dout_2d,
+ x_2d,
+ weight,
+ rstd_flat,
+ mean_flat,
+ dx_2d,
+ )
+
+ device = x.device
+ sm_count = get_sm_count(N, device, M=M, dtype=x.dtype)
+
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ db_partial = (
+ torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ if bias is not None
+ else None
+ )
+
+ _layernorm_backward_params_sm100(
+ dout_2d,
+ x_2d,
+ rstd_flat,
+ mean_flat,
+ dw_partial,
+ db_partial,
+ sm_count,
+ )
+
+ dweight = dw_partial.sum(dim=0).to(weight.dtype)
+ dbias = db_partial.sum(dim=0).to(bias.dtype) if bias is not None else None
+
+ dx = _restore_shape(dx_2d, orig_shape)
+ return dx, dweight, dbias
+
+
+if __name__ == "__main__":
+ # Allow direct execution for a quick functional check.
+ if not torch.cuda.is_available():
+ print("CUDA not available; LayerNormSM100 test skipped.")
+ raise SystemExit(0)
+
+ device = "cuda"
+ M, N = 2048, 4096
+ dtype = torch.bfloat16
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=torch.float32)
+ b = torch.randn(N, device=device, dtype=torch.float32)
+
+ y_ref = layernorm_ref(x, w, b)
+ y, rstd, mean = layernorm(x, w, b, return_rstd=True, return_mean=True)
+ torch.testing.assert_close(
+ y,
+ y_ref,
+ atol=5e-2 if dtype != torch.float32 else 1e-5,
+ rtol=5e-2 if dtype != torch.float32 else 1e-5,
+ )
+
+ print("LayerNormSM100 forward correctness check passed.")
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
new file mode 100644
index 0000000..c7402d8
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -0,0 +1,1444 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Lightweight local clone of the small subset of Quack helpers that the SM100
+CuTeDSL kernels in this package depend on.
+
+This module intentionally avoids importing the `quack` package so that the
+KernelAgent-Oink SM100 kernels can run without Quack installed, while keeping
+numerical behaviour and performance identical to the reference kernels.
+"""
+
+from __future__ import annotations
+
+import math
+import operator
+import importlib.metadata
+import re
+from functools import partial
+from typing import Callable, Optional, Tuple, Type
+
+import cuda.bindings.driver as cuda # type: ignore
+import torch
+from torch import Tensor
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute.runtime import from_dlpack
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm, nvvm, vector
+
+
+def _parse_version_tuple(version: str) -> tuple[int, int, int]:
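+    # e.g. "4.3.4.dev0" -> (4, 3, 4); "4.2" -> (4, 2, 0).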
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
+
+
+def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when
+# we detect 4.3.4+ (or when version detection is unavailable).
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
+
+# Cache device properties lookups (notably `multi_processor_count`) since some
+# dispatch paths call `get_sm_count` inside tight benchmark loops.
+_DEVICE_NUM_SMS_CACHE: dict[int, int] = {}
+
+
+def get_num_sms(device: torch.device) -> int:
+ """Return the number of SMs for a CUDA device (cached)."""
+ device_index = device.index
+ if device_index is None:
+ device_index = torch.cuda.current_device()
+ device_index = int(device_index)
+ cached = _DEVICE_NUM_SMS_CACHE.get(device_index)
+ if cached is not None:
+ return cached
+ num_sms = int(torch.cuda.get_device_properties(device_index).multi_processor_count)
+ _DEVICE_NUM_SMS_CACHE[device_index] = num_sms
+ return num_sms
+
+
+# -------------------------
+# Dtype mapping (from quack.cute_dsl_utils)
+# -------------------------
+
+TORCH2CUTE_DTYPE = {
+ torch.float16: cutlass.Float16,
+ torch.bfloat16: cutlass.BFloat16,
+ torch.float32: cutlass.Float32,
+}
+
+
+# -------------------------
+# Tensor conversion helpers (from quack.utils)
+# -------------------------
+
+
+def convert_from_dlpack(
+ x: Tensor,
+ leading_dim: int,
+ alignment: int = 16,
+ divisibility: int = 1,
+) -> cute.Tensor:
+ return (
+ from_dlpack(x, assumed_align=alignment)
+ .mark_layout_dynamic(leading_dim=leading_dim)
+ .mark_compact_shape_dynamic(
+ mode=leading_dim,
+ stride_order=x.dim_order(),
+ divisibility=divisibility,
+ )
+ )
+
+
+# -------------------------
+# SM90/SM100 cluster helpers (from quack.utils)
+# -------------------------
+
+
+@dsl_user_op
+def elem_pointer(
+ x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None
+) -> cute.Pointer:
+ return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
+
+
+@dsl_user_op
+def set_block_rank(
+ smem_ptr: cute.Pointer,
+ peer_cta_rank_in_cluster: cute.Int32,
+ *,
+ loc=None,
+ ip=None,
+) -> cutlass.Int32:
+ """Map the given smem pointer to the address at another CTA rank in the cluster."""
+ smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
+ return cutlass.Int32(
+ llvm.inline_asm(
+ T.i32(),
+ [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
+ "mapa.shared::cluster.u32 $0, $1, $2;",
+ "=r,r,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ )
+ )
+
+
+@dsl_user_op
+def store_shared_remote(
+ val: float | Float32 | Int32 | cutlass.Int64,
+ smem_ptr: cute.Pointer,
+ mbar_ptr: cute.Pointer,
+ peer_cta_rank_in_cluster: cute.typing.Int,
+ *,
+ loc=None,
+ ip=None,
+) -> None:
+ remote_smem_ptr_i32 = set_block_rank(
+ smem_ptr,
+ peer_cta_rank_in_cluster,
+ loc=loc,
+ ip=ip,
+ ).ir_value()
+ remote_mbar_ptr_i32 = set_block_rank(
+ mbar_ptr,
+ peer_cta_rank_in_cluster,
+ loc=loc,
+ ip=ip,
+ ).ir_value()
+ if const_expr(isinstance(val, float)):
+ val = Float32(val)
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), (
+ "val must be Float32, Int32, or Int64"
+ )
+ suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+ constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
+ llvm.inline_asm(
+ None,
+ [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
+ f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
+ f"r,{constraint},r",
+ has_side_effects=True,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ )
+
+
+@dsl_user_op
+def atomic_add_f32(
+ a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
+) -> Float32:
+ """Atomic add into global memory (float32)."""
+ return nvvm.atomicrmw(
+ res=T.f32(),
+ op=nvvm.AtomicOpKind.FADD,
+ ptr=gmem_ptr.llvm_ptr,
+ a=Float32(a).ir_value(loc=loc, ip=ip),
+ loc=loc,
+ ip=ip,
+ )
+
+
+@cute.jit
+def atomic_add_tensor_f32(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+) -> None:
+ """Atomic-add a register fragment into a GMEM tile (float32)."""
+ if const_expr(pred is None):
+ for i in cutlass.range_constexpr(cute.size(src.shape)):
+ coord = cute.idx2crd(i, src.shape)
+ atomic_add_f32(src[i], elem_pointer(dst, coord))
+ else:
+ for i in cutlass.range_constexpr(cute.size(src.shape)):
+ # CuTeDSL 4.3.4+ disallows introducing new tuple-typed values inside
+ # a dynamic `if`. Compute `coord` unconditionally, then predicate the
+ # atomic update.
+ coord = cute.idx2crd(i, src.shape)
+ if pred[i]:
+ atomic_add_f32(src[i], elem_pointer(dst, coord))
+
+
+@cute.jit
+def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+ # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if".
+ tApA = cute.make_fragment(
+ cute.make_layout(
+ (
+ cute.size(tAcA, mode=[0, 1]),
+ cute.size(tAcA, mode=[1]),
+ cute.size(tAcA, mode=[2]),
+ ),
+ stride=(cute.size(tAcA, mode=[2]), 0, 1),
+ ),
+ cutlass.Boolean,
+ )
+ for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+ for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+ tApA[rest_v, 0, rest_k] = cute.elem_less(
+ tAcA[(0, rest_v), 0, rest_k][1], limit
+ )
+ return tApA
+
+
+@dsl_user_op
+def domain_offset_i64(
+ coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
+) -> cute.Tensor:
+ flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+ flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(flat_stride), (
+ "Coordinate and stride must have the same length"
+ )
+ offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+ assert isinstance(tensor.iterator, cute.Pointer)
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@dsl_user_op
+def coord_offset_i64(
+ idx: cute.typing.Int,
+ tensor: cute.Tensor,
+ dim: int,
+ *,
+ loc=None,
+ ip=None,
+) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@cute.jit
+def fill_oob(
+ tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric
+) -> None:
+ """Fill out-of-bounds values in shared memory tensor."""
+ tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
+ tXrX_fill.fill(fill_value)
+ for rest_v in cutlass.range_constexpr(const_expr(tXsX.shape[0][1])):
+ for rest_k in cutlass.range_constexpr(const_expr(tXsX.shape[2])):
+ if const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+
+
+@dsl_user_op
+def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
+ """Pack two f32 values into a single i64.
+
+ This mirrors quack.utils.f32x2_to_i64 and is used by online_softmax_reduce
+ to store (max, sum_exp) pairs in an Int64 reduction buffer.
+ """
+ vec_f32x2 = vector.from_elements(
+ T.vector(2, T.f32()),
+ (a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)),
+ loc=loc,
+ ip=ip,
+ )
+ vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip)
+ res = cutlass.Int64(
+ vector.extract(
+ vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip
+ )
+ )
+ return res
+
+
+@dsl_user_op
+def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+ """Unpack a single i64 into two f32 values, inverse of f32x2_to_i64."""
+ vec_i64x1 = vector.from_elements(
+ T.vector(1, T.i64()),
+ (c.ir_value(loc=loc, ip=ip),),
+ loc=loc,
+ ip=ip,
+ )
+ vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip)
+ res0 = Float32(
+ vector.extract(
+ vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
+ )
+ )
+ res1 = Float32(
+ vector.extract(
+ vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip
+ )
+ )
+ return res0, res1
+
+
+# -------------------------
+# Reduction helpers (from quack.reduce)
+# -------------------------
+
+
+@cute.jit
+def warp_reduce(
+ val: cute.TensorSSA | cute.Numeric,
+ op: Callable,
+ width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+) -> cute.TensorSSA | cute.Numeric:
+ if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
+ res = cute.make_fragment(val.shape, val.dtype)
+ res.store(val)
+ for i in cutlass.range_constexpr(cute.size(val.shape)):
+ res[i] = warp_reduce(res[i], op, width)
+ return res.load()
+ return cute.arch.warp_reduction(val, op, threads_in_group=width)
+
+
+@cute.jit
+def block_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+ """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)."""
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ warps_per_row = cute.size(reduction_buffer.shape[1])
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if lane_idx == 0:
+ reduction_buffer[row_idx, col_idx] = val
+ cute.arch.barrier()
+ block_reduce_val = init_val
+ if lane_idx < warps_per_row:
+ block_reduce_val = reduction_buffer[row_idx, lane_idx]
+ return warp_reduce(block_reduce_val, op)
+
+
+@cute.jit
+def cluster_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ mbar_ptr: cute.Pointer,
+ init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
+) -> cute.Numeric:
+ """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))."""
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
+ if lane_idx < cluster_n:
+ store_shared_remote(
+ val,
+ elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+ mbar_ptr,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+ block_reduce_val = init_val
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+ for i in cutlass.range_constexpr(num_iter):
+ idx = lane_idx + i * cute.arch.WARP_SIZE
+ if idx < cute.size(reduction_buffer, mode=[1]):
+ block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
+ return warp_reduce(block_reduce_val, op)
+
+
+@cute.jit
+def block_or_cluster_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+ """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
+ if cutlass.const_expr(mbar_ptr is None):
+ return block_reduce(val, op, reduction_buffer, init_val=init_val)
+ return cluster_reduce(
+ val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase
+ )
+
+
+@cute.jit
+def row_reduce(
+ x: cute.TensorSSA | cute.Numeric,
+ op: cute.ReductionOp,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+ hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))."""
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+ val = x.reduce(op, init_val=init_val, reduction_profile=0)
+ else:
+ val = x
+ warp_op = {
+ cute.ReductionOp.ADD: operator.add,
+ cute.ReductionOp.MAX: cute.arch.fmax
+ if cutlass.const_expr(x.dtype == Float32)
+ else max,
+ cute.ReductionOp.MIN: min,
+ cute.ReductionOp.MUL: operator.mul,
+ }[op]
+ val = warp_reduce(
+ val,
+ warp_op,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ val = block_or_cluster_reduce(
+ val,
+ warp_op,
+ reduction_buffer,
+ mbar_ptr,
+ phase=phase,
+ init_val=init_val,
+ )
+ return val
+
+
+@cute.jit
+def row_reduce_add(
+ x: cute.TensorSSA | cute.Numeric,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+ hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+ """Specialized row_reduce for ADD reductions.
+
+ This mirrors row_reduce but hardcodes the ADD operation so we avoid
+ dynamic dispatch on the reduction op. It is used by bandwidth-bound
+ kernels like RMSNorm backward where the reduction is always ADD in
+ Float32.
+ """
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+ val = x.reduce(cute.ReductionOp.ADD, init_val=init_val, reduction_profile=0)
+ else:
+ val = x
+ val = warp_reduce(
+ val,
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ val = block_or_cluster_reduce(
+ val,
+ operator.add,
+ reduction_buffer,
+ mbar_ptr,
+ phase=phase,
+ init_val=init_val,
+ )
+ return val
+
+
+@cute.jit
+def online_softmax_reduce(
+ x: cute.TensorSSA,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
+ return_exp_x: bool = False,
+) -> tuple[Float32, Float32, Optional[cute.TensorSSA]]:
+ """Online softmax reduction over a row.
+
+ This mirrors quack.reduce.online_softmax_reduce and computes:
+ - max_x: row-wise maximum of x
+ - sum_exp_x: row-wise sum of exp(x - max_x)
+ - exp_x (optional): per-element exp(x - max_x_final) if return_exp_x is True
+ """
+ assert x.dtype == Float32, "x must be of type Float32"
+    # reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
+    # and dtype Int64: each slot packs one (max, sum_exp) pair via f32x2_to_i64.
+ max_x = warp_reduce(
+ x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+ cute.arch.fmax,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ log2_e = math.log2(math.e)
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
+ sum_exp_x = warp_reduce(
+ exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ assert reduction_buffer.element_type == cutlass.Int64, (
+ "reduction_buffer must be of type Int64"
+ )
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if cutlass.const_expr(mbar_ptr is None):
+ if lane_idx == 0:
+ reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
+ cute.arch.barrier()
+ max_x_single_warp = -Float32.inf
+ sum_exp_x = 0.0
+ if lane_idx < warps_per_row:
+ max_x_single_warp, sum_exp_x = i64_to_f32x2(
+ reduction_buffer[row_idx, lane_idx]
+ )
+ max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+ sum_exp_x *= cute.math.exp(
+ max_x_single_warp - max_x_final, fastmath=True
+ )
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+ if cutlass.const_expr(return_exp_x):
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+ max_x = max_x_final
+ else:
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps
+ * cluster_n
+ * reduction_buffer.element_type.width
+ // 8,
+ )
+ if lane_idx < cluster_n:
+ store_shared_remote(
+ f32x2_to_i64(max_x, sum_exp_x),
+ elem_pointer(
+ reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
+ ),
+ mbar_ptr,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+ cute.arch.mbarrier_wait(
+ mbar_ptr, phase=phase if phase is not None else 0
+ )
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+ max_x_single_warp = cute.make_fragment(num_iter, Float32)
+ max_x_single_warp.fill(-Float32.inf)
+ sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+ sum_exp_x_single_warp.fill(0.0)
+ for i in cutlass.range_constexpr(num_iter):
+ idx = lane_idx + i * cute.arch.WARP_SIZE
+ if idx < cute.size(reduction_buffer, mode=[1]):
+ max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
+ reduction_buffer[row_idx, idx]
+ )
+ max_x_final = max_x_single_warp.load().reduce(
+ cute.ReductionOp.MAX,
+ init_val=-Float32.inf,
+ reduction_profile=0,
+ )
+ max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+ sum_exp_x = 0.0
+ for i in cutlass.range_constexpr(num_iter):
+ sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
+ max_x_single_warp[i] - max_x_final,
+ fastmath=True,
+ )
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+ if cutlass.const_expr(return_exp_x):
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+ max_x = max_x_final
+ return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+
+
+# -------------------------
+# Copy helpers (minimal subset of quack.copy_utils)
+# -------------------------
+
+
+@dsl_user_op
+def get_copy_atom(
+ dtype: Type[cutlass.Numeric],
+ num_copy_elems: int,
+ is_async: bool = False,
+ *,
+ loc=None,
+ ip=None,
+) -> cute.CopyAtom:
+ from cutlass.cute.nvgpu import cpasync
+
+ # cp.async is limited to 128b per op; synchronous vectorized copies can go wider.
+ max_bits = const_expr(128 if is_async else 256)
+ num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
+ # Match Quack's default cp.async cache policy (leave cache_mode unspecified).
+ copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+ return cute.make_copy_atom(
+ copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip
+ )
+
+
+@dsl_user_op
+def copy(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+ loc=None,
+ ip=None,
+ **kwargs,
+) -> None:
+ copy_atom = get_copy_atom(
+ src.element_type, num_copy_elems, is_async, loc=loc, ip=ip
+ )
+ cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+# -------------------------
+# Reduction base (from quack.reduction_base)
+# -------------------------
+
+
+class ReductionBase:
+ def __init__(
+ self,
+ dtype: Type[cutlass.Numeric],
+ N: int,
+ stage: int,
+ reduction_dtype: Type[cutlass.Numeric] = cutlass.Float32,
+ ):
+ self.dtype = dtype
+ self.N = N
+ self.stage = stage
+ self.reduction_dtype = reduction_dtype
+
+ def _calculate_threads_per_row(self) -> int:
+ raise NotImplementedError()
+
+ def _set_cluster_n(self) -> None:
+ self.cluster_n = 1
+
+ def _get_num_threads(self) -> int:
+ return 128 if self.N <= 16384 else 256
+
+ def _get_tv_layout(
+ self, num_copy_bits: int = 128
+ ) -> Tuple[cute.Shape, cute.Layout]:
+ """Return (tiler_mn, tv_layout) for SM100 reduction kernels.
+
+ This intentionally mirrors Quack's `ReductionBase._get_tiled_copy(...)`:
+ - `tiler_mn` spans the full N range for the CTA, including any "K-loop"
+ repeats (`num_blocks_N`).
+ - `tv_layout` is the *tiled* thread/value layout used by CuTe's copy
+ partitioning (does **not** bake in `num_blocks_N`), matching
+ `quack.copy_utils.tiled_copy_2d(...).layout_tv_tiled`.
+ """
+ if num_copy_bits > 128:
+ raise ValueError(
+ f"num_copy_bits={num_copy_bits} exceeds 128b; Quack-style SM100 reduction "
+ "tiling assumes <=128b vectorization (cp.async and common CopyAtoms)."
+ )
+ vecsize = num_copy_bits // self.dtype.width
+ assert self.N % vecsize == 0, (
+ f"Input N {self.N} is not divisible by vector size {vecsize}"
+ )
+ num_threads = self._get_num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+
+ threads_per_row = self._calculate_threads_per_row()
+ self._set_cluster_n()
+ num_blocks_N = cute.ceil_div(
+ self.N // vecsize, threads_per_row * self.cluster_n
+ )
+ cols_per_block = num_threads // threads_per_row
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+
+ # Construct the same tv layout that Quack gets from `tiled_copy_2d(...).layout_tv_tiled`.
+ copy_atom = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ self.dtype,
+ num_bits_per_copy=num_copy_bits,
+ )
+ thr_layout = cute.make_ordered_layout(
+ (cols_per_block, threads_per_row),
+ order=(1, 0),
+ )
+ val_layout = cute.make_layout((1, vecsize))
+ tv_layout = cute.make_tiled_copy_tv(
+ copy_atom, thr_layout, val_layout
+ ).layout_tv_tiled
+ return tiler_mn, tv_layout
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
+ # Mirror the allocation order used by the SM100 reduction kernels:
+ # 1) sX (byte_alignment=16)
+ # 2) reduction_buffer (byte_alignment=8)
+ # 3) mbar_ptr (Int64, 8B)
+ #
+ # CuTeDSL's SmemAllocator may insert padding between allocations to satisfy
+ # alignment. Be conservative and round up offsets accordingly so we never
+ # under-allocate dynamic shared memory.
+
+ def _align_up(x: int, align: int) -> int:
+ return ((x + align - 1) // align) * align
+
+ sx_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)))
+ reduction_bytes = int(
+ self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ )
+ mbar_bytes = int(self.stage * (cutlass.Int64.width // 8))
+
+ offset = _align_up(sx_bytes, 16)
+ offset = _align_up(offset, 8) + reduction_bytes
+ offset = _align_up(offset, 8) + mbar_bytes
+ return int(offset)
+
+ def _get_reduction_buffer_layout(
+ self, tv_layout: cute.Layout, cluster_n: int
+ ) -> cute.Layout:
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ warps_per_row = (
+ num_warps
+ if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+ )
+ return cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+
+ def _allocate_reduction_buffer_and_mbar(
+ self,
+ smem: cutlass.utils.SmemAllocator,
+ tv_layout: cute.Layout,
+ is_persistent: bool = False,
+ ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype,
+ self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+ byte_alignment=8,
+ )
+ if cutlass.const_expr(self.cluster_n > 1):
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64,
+ num_elems=self.stage if not is_persistent else self.stage * 2,
+ )
+ else:
+ mbar_ptr = None
+ return reduction_buffer, mbar_ptr
+
+ @cute.jit
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: Optional[cute.Pointer],
+ num_warps: int,
+ is_persistent: bool = False,
+ ) -> None:
+ if cutlass.const_expr(self.cluster_n > 1 and mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent):
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx,
+ num_warps * self.cluster_n,
+ )
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
+# -------------------------
+# RMSNorm backward base (from quack.rmsnorm.RMSNormBackward)
+# -------------------------
+
+
+class RMSNormBackward(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
+ # Optional optimization: atomically accumulate mdW into a single (N,)
+ # buffer instead of writing an (sm_count, N) partial buffer + torch.sum.
+ self.atomic_dw = False
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ raise ValueError(
+ "RMSNormBackward does not support N > 128k with dtype >= 32 bits"
+ )
+
+ def _get_num_threads(self) -> int:
+ return 128 if self.N <= 4096 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (
+ 32
+ if N <= 256
+ else (64 if N <= 512 else (128 if N <= 4096 else 256))
+ )
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ N = self.N
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (
+ 2
+ if N <= 16 * 1024
+ else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int:
+ if do_dtype is None:
+ do_dtype = self.dtype
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+ + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
+ + self.stage
+ * num_warps
+ * self.cluster_n
+ * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ sm_count: Int32,
+ stream: cuda.CUstream,
+ ):
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mdO, mdResO, mdX, mdRes = [
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mdO, mdResO, mdX, mdRes)
+ ]
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mW.element_type.width if mW is not None else 0,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ # Quack-style policy: cap the *largest* dtype to 128b, then scale the
+ # activation copy width down proportionally (e.g. fp16 + fp32-weight
+ # => 64b activation vectors so the fp32 path stays at 128b).
+ num_copy_bits = const_expr(128 // largest_dtype_width * mX.element_type.width)
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=int(num_copy_bits))
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout,
+ cute.make_layout((tiler_mn[0],), stride=(0,)),
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ num_blocks = sm_count
+ kernel = (
+ self.kernel(
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
+ )
+ kernel.launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(
+ tiler_mn, num_warps, do_dtype=mdO.element_type
+ ),
+ stream=stream,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx_start, _, _ = cute.arch.block_idx()
+ gdim, _, _ = cute.arch.grid_dim()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mX.shape
+ M = shape[0]
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+
+ idX = cute.make_identity_tensor(shape)
+
+ smem = cutlass.utils.SmemAllocator()
+ smem_layout = cute.make_ordered_layout(
+ (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)
+ )
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem,
+ tv_layout,
+ is_persistent=True,
+ )
+ if const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None
+
+ num_copy_elems_X = (
+ tv_layout.shape[1]
+ if cutlass.const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ copy_atom_load_X = get_copy_atom(
+ mX.element_type, num_copy_elems_X, is_async=False
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems_X))
+ thr_copy_X = cute.make_tiled_copy_tv(
+ copy_atom_load_X, thr_layout, val_layout
+ ).get_slice(tidx)
+ copy_fn = partial(copy, num_copy_elems=num_copy_elems_X)
+
+ gX, gdO, gdResO, gdX, gdRes, cX = [
+ cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
+ for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
+ ]
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
+ gdW, gdB = [
+ cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ if const_expr(mT is not None)
+ else None
+ for mT in (mdW, mdB)
+ ]
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdO = thr_copy_X.partition_S(gdO)
+ tXsdO = thr_copy_X.partition_D(sdO)
+ tXgdX = thr_copy_X.partition_D(gdX)
+ if const_expr(mdResO is not None):
+ tXgdResO = thr_copy_X.partition_S(gdResO)
+ if const_expr(mdRes is not None):
+ tXgdRes = thr_copy_X.partition_D(gdRes)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+
+ tXrX, tXrdO, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0])
+ for thr in (tXgX, tXgdO, tXgdX)
+ ]
+ tXrdResO = None
+ if const_expr(mdResO is not None):
+ tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+ tXrdRes = None
+ if const_expr(mdRes is not None):
+ tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
+
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+
+ tXgdW, tXrdW = None, None
+ tXgdB, tXrdB = None, None
+ if const_expr(mdW is not None):
+ tXgdW = thr_copy_X.partition_S(gdW)
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)
+ if const_expr(mdB is not None):
+ tXgdB = thr_copy_X.partition_S(gdB)
+ tXrdB = cute.make_fragment_like(tXgdB, Float32)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
+
+ tXrW = None
+ if const_expr(mW is not None):
+ tXgW = thr_copy_X.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if not is_even_N:
+ tXrW.fill(0.0)
+ copy_fn(tXgW, tXrW, pred=tXpX)
+
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdO_cur = coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
+ copy_fn(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+ copy_fn(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
+ elif tiler_mn[0] > 1:
+ fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
+ cute.arch.cp_async_commit_group()
+
+ if const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()
+
+ if const_expr(mdW is not None):
+ tXrdW.fill(0.0)
+ if const_expr(mdB is not None):
+ tXrdB.fill(0.0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
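+        # Double-buffered pipeline: `stage` selects which SMEM tile (and which
+        # reduction buffer / mbarrier pair) this iteration consumes while the
+        # next tile is prefetched into the other slot; the phase bits flip each
+        # time `stage` wraps back to 0 so the mbarrier waits below stay in sync
+        # with the arrives issued during the previous pass over the two stages.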
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ if row + gdim * tiler_mn[0] < M:
+ tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[
+ None, None, None, 0
+ ]
+ tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[
+ None, None, None, 0
+ ]
+ copy_fn(
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ is_async=True,
+ )
+ copy_fn(
+ tXgdO_cur,
+ tXsdO[None, None, None, stage ^ 1],
+ pred=tXpX,
+ is_async=True,
+ )
+ elif tiler_mn[0] > 1:
+ fill_oob(
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
+ )
+ fill_oob(
+ tXsdO[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdO.element_type.zero,
+ )
+ cute.arch.cp_async_commit_group()
+ rstd_val = cutlass.Float.zero
+ if row < M or tiler_mn[0] == 1:
+ rstd_val = mRstd[row]
+ if const_expr(mdResO is not None):
+ tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[
+ None, None, None, 0
+ ]
+ if row < M or tiler_mn[0] == 1:
+ copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX)
+ elif tiler_mn[0] > 1:
+ tXrdResO.fill(0.0)
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
+ x = tXrX.load().to(cute.Float32)
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+ dout = tXrdO.load().to(cute.Float32)
+ x_hat = x * rstd_val
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)
+ if const_expr(self.cluster_n > 1):
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+ mean_xhat_wdy = (
+ row_reduce_add(
+ x_hat * wdy,
+ threads_per_row,
+ reduction_buffer[None, None, stage],
+ (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
+ phase=consumer_phase,
+ init_val=0.0,
+ )
+ / shape[1]
+ )
+
+ if const_expr(self.cluster_n > 1):
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared,
+ space=cute.arch.SharedSpace.shared_cta,
+ )
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+
+ if const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+ dout = tXrdO.load().to(cute.Float32)
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)
+
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd_val
+ if const_expr(mdResO is not None):
+ dx += tXrdResO.load().to(cute.Float32)
+ tXrdX.store(dx.to(tXrdX.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ copy_fn(tXrdX, tXgdX_cur, pred=tXpX)
+ if const_expr(mdRes is not None):
+ tXrdRes.store(dx.to(tXrdRes.element_type))
+ tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[
+ None, None, None, 0
+ ]
+ if row < M or tiler_mn[0] == 1:
+ copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX)
+ if const_expr(mdW is not None):
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+ if const_expr(mdB is not None):
+ tXrdB.store(tXrdB.load() + dout)
+
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if const_expr(tiler_mn[0] > 1):
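+            # Cross-row reduction within the CTA tile: threads whose tile row is
+            # > 0 stage their fp32 dW/dB partials in SMEM (reusing the sX storage
+            # reinterpreted as fp32), then the row-0 threads accumulate all
+            # partials and write the combined partial to its gmem slot
+            # (atomically for dW when self.atomic_dw is set).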
+ if const_expr(mdW is not None):
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
+ cute.arch.barrier()
+ row0 = tXcX[None, None, None, 0][0][0]
+ if row0 > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row0 == 0:
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(
+ tXsdW.iterator + i * sdW.stride[0],
+ tXsdW.layout,
+ )
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ if const_expr(self.atomic_dw):
+ atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX)
+ else:
+ copy_fn(tXrdW, tXgdW, pred=tXpX)
+ cute.arch.barrier()
+ if const_expr(mdB is not None):
+ sdB = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdB = thr_copy_X.partition_D(sdB)
+ cute.arch.barrier()
+ row0 = tXcX[None, None, None, 0][0][0]
+ if row0 > 0:
+ cute.autovec_copy(tXrdB, tXsdB)
+ cute.arch.barrier()
+ if row0 == 0:
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+ tXrdB_other = cute.make_fragment_like(tXrdB)
+ tXsdB_other = cute.make_tensor(
+ tXsdB.iterator + i * sdB.stride[0],
+ tXsdB.layout,
+ )
+ cute.autovec_copy(tXsdB_other, tXrdB_other)
+ tXrdB.store(tXrdB.load() + tXrdB_other.load())
+ copy_fn(tXrdB, tXgdB, pred=tXpX)
+ else:
+ if const_expr(mdW is not None):
+ if const_expr(self.atomic_dw):
+ atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX)
+ else:
+ copy_fn(tXrdW, tXgdW, pred=tXpX)
+ if const_expr(mdB is not None):
+ copy_fn(tXrdB, tXgdB, pred=tXpX)
+
+ if const_expr(self.cluster_n > 1):
+ stage ^= 1
+ if stage == 0:
+ producer_phase ^= 1
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mdO,
+ mdResO,
+ mRstd,
+ mdX,
+ mdW,
+ mdB,
+ mdRes,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ ):
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ )
+ self._kernel_impl(
+ mX,
+ mW,
+ mdO,
+ mdResO,
+ mRstd,
+ mdX,
+ mdW,
+ mdB,
+ mdRes,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+# -------------------------
+# SM count helper (from quack.rmsnorm._get_sm_count)
+# -------------------------
+
+
+def get_sm_count(
+ N: int,
+ device: torch.device,
+ M: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+) -> int:
+ """
+ SM count heuristic for reduction-style kernels.
+
+ This starts from Quack's _get_sm_count policy and layers on SM100 /
+ DSv3-specific tuning so that:
+ - For DSv3-style shapes (large-M, N in {6144, 8192}, fp16/bf16),
+ sm_count is reduced for very large M to cut down the number of
+ dw_partial/db_partial rows that ever hit HBM.
+ - For Quack-suite hidden=4096, small-M shapes, sm_count is modestly
+ increased to improve SM occupancy, matching the existing SM100
+ tuning used by both RMSNorm and LayerNorm.
+ """
+ num_sms = get_num_sms(device)
+
+ sm_count_multiple = (
+ 16
+ if N <= 256
+ else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = num_sms
+ if N <= 8192:
+ sm_count = sm_count * sm_count_multiple
+ elif N <= 16384:
+ sm_count = sm_count // 2
+ else:
+ sm_count = sm_count * 2
+
+ # Quack-suite tuning: for small-M, hidden=4096 shapes (M<=8192) and
+ # 16-bit dtypes, increase sm_count to improve occupancy. This mirrors
+ # the existing SM100 RMSNorm/LayerNorm heuristics.
+ if (
+ dtype in (torch.float16, torch.bfloat16)
+ and M is not None
+ and M <= 8192
+ and N == 4096
+ ):
+ sm_count = min(sm_count * 2, num_sms * 4)
+
+ return sm_count
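+
+
+# Worked example (illustrative): for N=4096, M=4096, bf16, the base policy gives
+# sm_count = 2 * num_sms (multiple=2, N <= 8192); the small-M hidden=4096 branch
+# then doubles that and clamps at 4 * num_sms. For N=7168 (DSv3) the multiple is
+# 1, so the persistent grid stays at num_sms blocks.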
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
new file mode 100644
index 0000000..a96a4c7
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Torch custom ops wrapping Oink's Blackwell RMSNorm kernels.
+
+These ops are designed to be:
+- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back
+ to a safe reference elsewhere).
+- Layout-preserving for 2D row-major inputs, including padded MLA-style
+ layouts where stride(0) > N and stride(1) == 1.
+- torch.compile-friendly via proper fake implementations that mirror
+ runtime shapes and strides.
+
+Public ops (Python signatures):
+
+ torch.ops.oink.rmsnorm(x: Tensor, weight: Tensor, eps: float) -> Tensor
+ Functional RMSNorm. Returns a new tensor with the same shape and
+ stride as x when using the fast CuTeDSL path.
+
+ torch.ops.oink.fused_add_rms_norm(
+ x: Tensor, residual: Tensor, weight: Tensor, eps: float
+ ) -> None
+ In-place fused residual-add + RMSNorm matching vLLM semantics:
+ residual = x + residual (stored into `residual`)
+        x = RMSNorm(residual, weight) (stored into `x`)
+ Mutates `x` and `residual` in-place and returns None.
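+
+Example (sketch; shapes and eps are illustrative, and this assumes the ops have
+already been registered, e.g. by the vLLM plugin):
+
+    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+    residual = torch.randn_like(x)
+    weight = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+    torch.ops.oink.fused_add_rms_norm(x, residual, weight, 1e-6)
+    # `residual` now holds the pre-norm sum; `x` holds its RMSNorm.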
+"""
+
+from __future__ import annotations
+
+import importlib
+import threading
+
+import torch
+from torch.library import custom_op
+
+_RMSNORM_MOD: object | None = None
+_RMSNORM_MOD_LOCK = threading.Lock()
+
+
+def _get_rmsnorm_mod():
+ """Lazy import to keep plugin registration lightweight.
+
+ Importing the CuTeDSL kernel stack can be expensive and may require a CUDA
+ context. We defer it until the first actual execution of the custom op.
+ """
+ global _RMSNORM_MOD
+
+ cached = _RMSNORM_MOD
+ if cached is not None:
+ return cached
+
+ with _RMSNORM_MOD_LOCK:
+ if _RMSNORM_MOD is None:
+ _RMSNORM_MOD = importlib.import_module("kernelagent_oink.blackwell.rmsnorm")
+ return _RMSNORM_MOD
+
+
+def _get_sm(device: torch.device | None = None) -> int:
+ """Return SM version as an int (e.g., 100 for SM100 / Blackwell)."""
+ if device is None:
+ device = torch.device("cuda")
+ major, minor = torch.cuda.get_device_capability(device)
+ return 10 * major + minor
+
+
+#
+# RMSNorm (functional)
+#
+
+
+@custom_op("oink::rmsnorm", mutates_args=())
+def oink_rmsnorm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> torch.Tensor:
+ """
+ Functional RMSNorm entrypoint.
+
+ This op is model-agnostic. It expects a 2D [M, N] view of the input
+ where the last dimension is contiguous (stride(1) == 1). The leading
+ dimension stride(0) may be larger than N (padded-row layouts), and
+ will be preserved on the fast CuTeDSL path.
+
+ On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell
+ RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the
+ best internal schedule (including DSv3-specific stage-2 kernels where
+ applicable) and preserves the input's 2D stride when using the
+ pointer-based path.
+
+ On older architectures it falls back to a safe PyTorch reference
+ implementation for correctness.
+ """
+ assert x.is_cuda, "oink::rmsnorm requires CUDA tensors"
+ assert x.dim() == 2, "oink::rmsnorm expects a 2D [M, N] tensor view"
+ assert weight.dim() == 1, "weight must be 1D [N]"
+
+ sm = _get_sm(x.device)
+ if sm >= 100:
+ # Use the tuned CuTeDSL SM100 kernel. The public API already
+ # contains all necessary gating and layout checks internally.
+ _rms = _get_rmsnorm_mod()
+ y, _rstd, _res = _rms.rmsnorm_forward(
+ x,
+ weight=weight,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=False,
+ )
+ return y
+
+ # Fallback: reference implementation (correctness-first).
+ _rms = _get_rmsnorm_mod()
+ return _rms.rmsnorm_ref(
+ x,
+ w=weight,
+ b=None,
+ residual=None,
+ eps=eps,
+ )
+
+
+@oink_rmsnorm.register_fake
+def oink_rmsnorm_fake(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> torch.Tensor:
+ """
+ Fake (meta) implementation for oink::rmsnorm.
+
+ We must preserve x's logical layout (shape + stride) so that Inductor's
+ CUDA graph capture sees the same stride contract as the real kernel.
+ """
+ # x is a FakeTensor here; x.shape/x.stride()/x.device/x.dtype are defined.
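+    # Illustrative padded-row case: for a bf16 view created as
+    #     torch.empty(8, 4608, device="cuda", dtype=torch.bfloat16)[:, :4096]
+    # x.stride() == (4608, 1), and this fake returns an (8, 4096) tensor with
+    # that same (4608, 1) stride rather than a contiguous one, matching the
+    # stride-preserving fast path of the real op.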
+ return torch.empty_strided(
+ x.shape,
+ x.stride(),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+
+#
+# Fused residual-add + RMSNorm (in-place, vLLM semantics)
+#
+
+
+@custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual"))
+def oink_fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> None:
+ """
+ In-place fused residual-add + RMSNorm:
+
+ residual <- x + residual
+ x <- RMSNorm(residual, weight, eps)
+
+ Returns:
+ None (mutates `x` and `residual` in-place).
+ """
+ assert x.is_cuda and residual.is_cuda, (
+ "oink::fused_add_rms_norm requires CUDA tensors"
+ )
+ assert x.shape == residual.shape, "x and residual must have the same shape"
+ assert x.dtype == residual.dtype, "x and residual must have the same dtype"
+ assert weight.dim() == 1, "weight must be 1D [N]"
+
+ sm = _get_sm(x.device)
+ if sm >= 100:
+ _rms = _get_rmsnorm_mod()
+ # Prefer the lowest-overhead in-place entrypoint (returns None).
+ if hasattr(_rms, "fused_add_rmsnorm_inplace_"):
+ _rms.fused_add_rmsnorm_inplace_( # type: ignore[misc]
+ x,
+ residual,
+ weight,
+ eps=eps,
+ )
+ return None
+ # Backward-compatible wrapper (returns (x, residual)).
+ if hasattr(_rms, "fused_add_rmsnorm_forward_inplace"):
+ _rms.fused_add_rmsnorm_forward_inplace( # type: ignore[misc]
+ x,
+ residual,
+ weight,
+ eps=eps,
+ )
+ return None
+
+ # Extremely defensive fallback if the Oink module doesn't provide
+ # the in-place entrypoint.
+ y, z = _rms.fused_add_rmsnorm_forward(x, residual, weight, eps=eps)
+ x.copy_(y)
+ residual.copy_(z)
+ return None
+
+ # Non-SM100 fallback: keep semantics in-place (correctness-first).
+ residual.add_(x)
+ _rms = _get_rmsnorm_mod()
+ y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps)
+ x.copy_(y)
+ return None
+
+
+@oink_fused_add_rms_norm.register_fake
+def oink_fused_add_rms_norm_fake(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> None:
+ """
+ Fake (meta) implementation for oink::fused_add_rms_norm.
+
+    This op mutates `x` and `residual` in-place and returns None, so the fake
+    only needs to mirror that contract; no output tensors are materialized.
+ """
+ return None
+
+
+__all__ = [
+ "oink_rmsnorm",
+ "oink_fused_add_rms_norm",
+]
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
new file mode 100644
index 0000000..252df6a
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -0,0 +1,4418 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+RMSNorm kernel for SM100 (Blackwell) in CuTeDSL.
+
+This implementation targets Blackwell with:
+- A stride-preserving pointer path for padded-row layouts (e.g. MLA with stride(0) > N).
+- A one-pass fused-add RMSNorm schedule for bf16/fp16 (DSv3 N=7168) that keeps
+ `x + residual` in registers (avoids re-reading gmem) and uses FP32 accumulation.
+- Optional experimental schedule knobs (env vars) to explore copy widths and
+ stage-2 cp.async variants.
+
+Note: This file expects the local CuTeDSL (cutlass) and SM100 helper modules
+to be available in the Python environment (e.g., `nvidia-cutlass-dsl` and
+`cuda-python`). It is shipped as part of the KernelAgent Oink vLLM plugin.
+"""
+
+from __future__ import annotations
+
+import ctypes
+import importlib.metadata
+import os
+import re
+import subprocess
+import sys
+import threading
+from typing import Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
+# `nvidia-cutlass-dsl` versions (e.g. 4.3.2 vs 4.3.4), and cross-version cache
+# sharing causes noisy "invalid section ID" warnings (and disables cache reuse).
+#
+# If the user has not pinned `CUTE_DSL_CACHE_DIR`, isolate by version so multiple
+# CuTeDSL envs can coexist on the same machine without stepping on each other.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
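+    # e.g. /tmp/<user>/cutlass_python_cache_4_3_4 for nvidia-cutlass-dsl 4.3.4
+    # (illustrative; the actual path depends on TMPDIR/USER).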
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.rmsnorm requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import torch # noqa: E402
+from torch import Tensor # noqa: E402
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python # noqa: E402
+
+import cutlass # noqa: E402
+import cutlass.cute as cute # noqa: E402
+from cutlass import Float32, Int32, const_expr # noqa: E402
+from cutlass.cute import runtime as rt # noqa: E402
+from cutlass.cute.runtime import from_dlpack # noqa: E402
+
+# Simple compile cache declared early so direct execution works
+_PTR_COMPILE_CACHE = {}
+
+# Thread-local cache for the fast-launch path. We keep per-thread packed args and
+# pointer/scalar storage so concurrent callers don't race on in-place updates.
+_PTR_FAST_LAUNCH_TLS = threading.local()
+
+
+# Cache a (1, sm_count) fp32 ones row used for GEMM-based dw/db partial reductions.
+#
+# On SM100, `dw_partial.sum(dim=0)` can be a double-digit microsecond tail for
+# Quack-suite small shapes (e.g. M=8192, N=4096). A cached GEMM-based reduction
+# is consistently faster and avoids per-call allocation overhead.
+_DW_REDUCE_ONES_CACHE: dict[tuple[int, int], Tensor] = {}
+
+
+def _get_dw_reduce_ones(device_index: int, sm_count: int) -> Tensor:
+ key = (int(device_index), int(sm_count))
+ ones = _DW_REDUCE_ONES_CACHE.get(key)
+ if ones is None or ones.shape != (1, sm_count) or ones.device.index != device_index:
+ ones = torch.ones(
+ (1, sm_count),
+ device=torch.device("cuda", device_index),
+ dtype=torch.float32,
+ )
+ _DW_REDUCE_ONES_CACHE[key] = ones
+ return ones
+
+
+def _reduce_partial_sum_fp32(partial: Tensor, *, device_index: int) -> Tensor:
+ """Reduce a (sm_count, N) fp32 partial buffer into an (N,) fp32 result."""
+ assert partial.dtype is torch.float32
+ assert partial.dim() == 2
+ ones = _get_dw_reduce_ones(device_index, int(partial.shape[0]))
+ return torch.mm(ones, partial).squeeze(0)
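+
+
+# Equivalence sketch (illustrative, not executed here): for an fp32 `partial` of
+# shape (sm_count, N),
+#     torch.mm(torch.ones(1, sm_count, device=partial.device), partial).squeeze(0)
+# matches `partial.sum(dim=0)` up to fp32 rounding order; the cached ones row
+# simply turns the column reduction into one GEMM with no per-call allocation.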
+
+
+def _env_flag(name: str, default: bool) -> bool:
+ val = os.environ.get(name)
+ if val is None:
+ return default
+ return val.strip().lower() not in {"0", "false", "no", "off", ""}
+
+
+# Fast-launch uses a few private-ish CuTeDSL internals (packed args plumbing and
+# runtime pointer descriptors). Keep it enabled by default for our pinned CuTeDSL
+# environment, but allow disabling it via env var and auto-disable it if those
+# internals are not present in a future upgrade.
+_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True)
+_FAST_LAUNCH_SUPPORTED = True
+
+# Fused-add RMSNorm schedule knobs (read once at import time; set env vars before
+# importing this module if you want to override).
+_DIRECT_GMEM_POLICY = (
+ os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto"
+)
+_COPY_BITS_POLICY = (
+ os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto"
+)
+_ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False)
+_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag(
+ "OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False
+)
+_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False)
+_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False)
+
+# Forward dispatch control:
+# - Default behavior: use the pointer-based path when safe, otherwise fall back
+# to the stage-2 module (then the torch reference).
+# - If you want to force stage-2 even when the pointer path is available (for
+# experimentation / A-B testing), set this env var **before** importing this
+# module.
+_FORCE_RMSNORM_STAGE2_FWD = _env_flag(
+ "KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False
+)
+
+# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule.
+#
+# Some CuTeDSL builds segfault during JIT compilation when combining:
+# - cluster launches (cluster_n>1) and
+# - direct-GMEM loads/stores (no staging SMEM tiles).
+#
+# We keep the schedule gated behind `OINK_RMSNORM_ENABLE_CLUSTER_ILP=1` +
+# `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`, and additionally run a one-time
+# out-of-process compile probe so we can safely fall back to the staged SMEM
+# path instead of crashing the parent process.
+#
+# This is (currently) sensitive to the vector width: we have observed
+# reproducible segfaults for the 256b universal-copy path, while the 128b path
+# can succeed. Cache the maximum supported copy width (0 = unsupported).
+_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS: Optional[int] = None
+_CLUSTER_DIRECT_GMEM_PROBE_LOCK = threading.Lock()
+_CLUSTER_DIRECT_GMEM_PROBE_WARNED = False
+
+
+def _probe_cluster_direct_gmem_max_copy_bits() -> int:
+ global _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+ global _CLUSTER_DIRECT_GMEM_PROBE_WARNED
+
+ override = os.environ.get("OINK_RMSNORM_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS")
+ if override is not None and override.strip() != "":
+ try:
+ value = int(override)
+ except ValueError:
+ value = 0
+ value = 256 if value >= 256 else 128 if value >= 128 else 0
+ _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = value
+ return value
+
+ if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None:
+ return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+
+ with _CLUSTER_DIRECT_GMEM_PROBE_LOCK:
+ if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None:
+ return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+
+ script_template = r"""
+import os
+
+os.environ["OINK_CUTEDSL_FAST_LAUNCH"] = "0"
+
+import cutlass
+import cutlass.cute as cute
+import cuda.bindings.driver as cuda
+from cutlass import Float32, Int32
+from cutlass.cute import runtime as rt
+
+from kernelagent_oink.blackwell import rmsnorm
+
+N = 7168
+dtype = cutlass.BFloat16
+
+copy_bits = int(os.environ["OINK_PROBE_COPY_BITS"])
+assumed_align = int(os.environ["OINK_PROBE_ASSUMED_ALIGN"])
+
+op = rmsnorm.RMSNormSM100(
+ N,
+ dtype,
+ stage=1,
+ copy_bits=copy_bits,
+ use_async=False,
+ direct_gmem=True,
+)
+op._cluster_n_override = 2 # 2 CTAs per row
+
+ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+
+_ = cute.compile(
+ op.launch_from_ptrs_fused_add_inplace,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(4096),
+ Int32(N),
+ Int32(N),
+ cuda.CUstream(0),
+ Float32(1e-6),
+)
+print(f"ok {copy_bits}")
+"""
+
+ env = os.environ.copy()
+ # The probe runs in a fresh subprocess, so it won't inherit any
+ # benchmark-harness sys.path tweaks. Ensure the in-tree Oink source is
+ # importable so `import kernelagent_oink...` works reliably.
+ oink_src = os.path.abspath(os.path.join(_HERE, "..", ".."))
+ if os.path.isdir(oink_src):
+ py_path = env.get("PYTHONPATH")
+ env["PYTHONPATH"] = oink_src + (os.pathsep + py_path if py_path else "")
+ env["PYTHONNOUSERSITE"] = "1"
+
+ def run_probe(copy_bits: int, assumed_align: int):
+ probe_env = env.copy()
+ probe_env["OINK_PROBE_COPY_BITS"] = str(copy_bits)
+ probe_env["OINK_PROBE_ASSUMED_ALIGN"] = str(assumed_align)
+ return subprocess.run(
+ [sys.executable, "-c", script_template],
+ env=probe_env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ timeout=120.0,
+ )
+
+ proc_256 = None
+ proc_128 = None
+ try:
+ proc_256 = run_probe(256, 32)
+ if proc_256.returncode == 0:
+ max_bits = 256
+ else:
+ proc_128 = run_probe(128, 16)
+ max_bits = 128 if proc_128.returncode == 0 else 0
+ except Exception:
+ max_bits = 0
+
+ if not _CLUSTER_DIRECT_GMEM_PROBE_WARNED and max_bits != 256:
+ _CLUSTER_DIRECT_GMEM_PROBE_WARNED = True
+ if max_bits == 128:
+ print(
+ "Oink: cluster_n>1 + direct_gmem 256b compile probe failed; "
+ "using 128b copies for the cluster ILP schedule.",
+ file=sys.stderr,
+ )
+ if proc_256 is not None and proc_256.stderr:
+ tail = "\n".join(proc_256.stderr.splitlines()[-12:])
+ print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr)
+ else:
+ rc = (
+ proc_128.returncode
+ if proc_128 is not None
+ else (proc_256.returncode if proc_256 is not None else "unknown")
+ )
+ print(
+ "Oink: cluster_n>1 + direct_gmem compile probe failed; "
+ f"falling back to staged SMEM path (returncode={rc}).",
+ file=sys.stderr,
+ )
+ failing_proc = proc_128 if proc_128 is not None else proc_256
+ if failing_proc is not None and failing_proc.stderr:
+ tail = "\n".join(failing_proc.stderr.splitlines()[-12:])
+ print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr)
+
+ _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits
+ return max_bits
+
+
+def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
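+
+
+# _parse_version_tuple examples (illustrative):
+#   "4.3.4"      -> (4, 3, 4)
+#   "4.3"        -> (4, 3, 0)
+#   "4.3.4.dev1" -> (4, 3, 4)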
+
+
+def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we
+# detect 4.3.4+ (or when version detection is unavailable).
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
+
+if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE:
+ # We have observed reproducible segfaults in some CuTeDSL builds when using
+ # cluster launches for this schedule. Require an explicit UNSAFE opt-in to
+ # avoid accidental crashes.
+ _ENABLE_CLUSTER_ILP = False
+ print(
+ "Oink: OINK_RMSNORM_ENABLE_CLUSTER_ILP requested but disabled by default due to "
+ "known instability; set OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1 to force-enable.",
+ file=sys.stderr,
+ )
+
+
+def _fast_launch_enabled() -> bool:
+ return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED
+
+
+def _direct_gmem_from_policy(*, default: bool) -> bool:
+ """Resolve the direct-GMEM schedule flag from the (import-time) policy string."""
+ if _DIRECT_GMEM_POLICY in {"0", "false", "no", "off"}:
+ return False
+ if _DIRECT_GMEM_POLICY in {"1", "true", "yes", "on"}:
+ return True
+ return default
+
+
+def _copy_bits_from_policy(*, default: int, can_use_256: bool) -> int:
+ """Resolve copy width (in bits) from the (import-time) policy string."""
+ if _COPY_BITS_POLICY in {"64"}:
+ return 64
+ if _COPY_BITS_POLICY in {"128"}:
+ return 128
+ if _COPY_BITS_POLICY in {"256"} and can_use_256:
+ return 256
+ return default
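+
+
+# Example (illustrative): exporting OINK_RMSNORM_COPY_BITS=256 before importing
+# this module makes _copy_bits_from_policy(default=128, can_use_256=True) return
+# 256, while the same setting with can_use_256=False keeps the 128-bit default.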
+
+
+class _StableI32Arg:
+ """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: int):
+ self._c_value = ctypes.c_int32(int(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: int) -> None:
+ self._c_value.value = int(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
+
+
+class _StableF32Arg:
+ """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: float):
+ self._c_value = ctypes.c_float(float(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: float) -> None:
+ self._c_value.value = float(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
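+
+
+# Usage sketch: each cached fast launcher creates these wrappers once, threads
+# their __c_pointers__ into the packed-args array, and afterwards only calls
+# .set(...) on them, so the compiled kernel sees updated scalar values without
+# any per-call ctypes allocations.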
+
+
+def _tls_fast_launch_cache() -> dict[tuple[object, ...], object]:
+ cache = getattr(_PTR_FAST_LAUNCH_TLS, "cache", None)
+ if cache is None:
+ cache = {}
+ _PTR_FAST_LAUNCH_TLS.cache = cache
+ return cache
+
+
+def _set_runtime_ptr(ptr: object, device_ptr: int) -> None:
+ # Runtime pointer objects cache a `ctypes.c_void_p` descriptor and pass
+ # its address to the compiled function. Updating `_desc.value` updates
+ # the device pointer without changing the address of the descriptor.
+ #
+ # This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`,
+ # etc.). If these internals change in a future CuTeDSL upgrade, callers
+ # should catch AttributeError and fall back to the regular launch path.
+ device_ptr = int(device_ptr)
+ ptr._pointer = device_ptr # type: ignore[attr-defined]
+ if getattr(ptr, "_c_pointer", None) is None:
+ ptr.__c_pointers__() # type: ignore[attr-defined]
+ ptr._desc.value = device_ptr # type: ignore[attr-defined]
+
+
+class _PtrRmsnormFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: Optional[object],
+ ptr_out: object,
+ arg_m: _StableI32Arg,
+ arg_n: _StableI32Arg,
+ arg_ld: _StableI32Arg,
+ arg_eps: _StableF32Arg,
+ stream: cuda.CUstream,
+ assumed_align: int,
+ weight_dtype: Optional[type[cutlass.Numeric]],
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_out = ptr_out
+ self._arg_m = arg_m
+ self._arg_n = arg_n
+ self._arg_ld = arg_ld
+ self._arg_eps = arg_eps
+ self._stream = stream
+ self._assumed_align = int(assumed_align)
+ self._weight_dtype = weight_dtype
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_out_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+ self._last_eps = float("nan")
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ out: Tensor,
+ M: int,
+ N: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ if not _fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
+ return
+
+ if self._ptr_w is not None:
+ w_ptr = weight.data_ptr() # type: ignore[union-attr]
+ if w_ptr != self._last_w_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
+ return
+
+ out_ptr = out.data_ptr()
+ if out_ptr != self._last_out_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_out, out_ptr)
+ self._last_out_ptr = out_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+ if eps != self._last_eps:
+ self._arg_eps.set(eps)
+ self._last_eps = eps
+
+ # Clear the error slot before launch (mirrors JitExecutor behavior).
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ self._use_fast_launch = False
+ _FAST_LAUNCH_SUPPORTED = False
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ out: Tensor,
+ M: int,
+ N: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ # If the packed-args or runtime pointer mutation path stops working
+ # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path.
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_out = rt.make_ptr(
+ dtype,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_w = (
+ rt.make_ptr(
+ self._weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ if weight is not None
+ else None
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ Int32(ld),
+ self._stream,
+ Float32(eps),
+ )
+
+
+class _PtrFusedAddRmsnormFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: object,
+ ptr_res: object,
+ arg_m: _StableI32Arg,
+ arg_n: _StableI32Arg,
+ arg_ld_x: _StableI32Arg,
+ arg_eps: _StableF32Arg,
+ stream: cuda.CUstream,
+ assumed_align: int,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_res = ptr_res
+ self._arg_m = arg_m
+ self._arg_n = arg_n
+ self._arg_ld_x = arg_ld_x
+ self._arg_eps = arg_eps
+ self._stream = stream
+ self._assumed_align = int(assumed_align)
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_res_ptr = -1
+ self._last_m = -1
+ self._last_ld_x = -1
+ self._last_eps = float("nan")
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+ M: int,
+ N: int,
+ ld_x: int,
+ eps: float,
+ ) -> None:
+ if not _fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ w_ptr = weight.data_ptr()
+ if w_ptr != self._last_w_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ res_ptr = residual.data_ptr()
+ if res_ptr != self._last_res_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_res, res_ptr)
+ self._last_res_ptr = res_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld_x != self._last_ld_x:
+ self._arg_ld_x.set(ld_x)
+ self._last_ld_x = ld_x
+ if eps != self._last_eps:
+ self._arg_eps.set(eps)
+ self._last_eps = eps
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ self._use_fast_launch = False
+ _FAST_LAUNCH_SUPPORTED = False
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+ M: int,
+ N: int,
+ ld_x: int,
+ eps: float,
+ ) -> None:
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_res = rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_w = rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(M),
+ Int32(N),
+ Int32(ld_x),
+ self._stream,
+ Float32(eps),
+ )
+
+
+class _PtrRmsnormBwdFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: Optional[object],
+ ptr_dout: object,
+ ptr_rstd: object,
+ ptr_dx: object,
+ ptr_dw_partial: Optional[object],
+ arg_m: _StableI32Arg,
+ arg_n: _StableI32Arg,
+ arg_ld: _StableI32Arg,
+ arg_sm_count: _StableI32Arg,
+ stream: cuda.CUstream,
+ assumed_align_x: int,
+ assumed_align_w: int,
+ assumed_align_dw: int,
+ weight_dtype: Optional[type[cutlass.Numeric]],
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_dout = ptr_dout
+ self._ptr_rstd = ptr_rstd
+ self._ptr_dx = ptr_dx
+ self._ptr_dw_partial = ptr_dw_partial
+ self._arg_m = arg_m
+ self._arg_n = arg_n
+ self._arg_ld = arg_ld
+ self._arg_sm_count = arg_sm_count
+ self._stream = stream
+ self._assumed_align_x = int(assumed_align_x)
+ self._assumed_align_w = int(assumed_align_w)
+ self._assumed_align_dw = int(assumed_align_dw)
+ self._weight_dtype = weight_dtype
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_dout_ptr = -1
+ self._last_rstd_ptr = -1
+ self._last_dx_ptr = -1
+ self._last_dw_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+ self._last_sm_count = -1
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dx: Tensor,
+ dw_partial: Optional[Tensor],
+ M: int,
+ N: int,
+ ld: int,
+ sm_count: int,
+ ) -> None:
+ if not _fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ if self._ptr_w is not None:
+ w_ptr = weight.data_ptr() # type: ignore[union-attr]
+ if w_ptr != self._last_w_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ dout_ptr = dout.data_ptr()
+ if dout_ptr != self._last_dout_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_dout, dout_ptr)
+ self._last_dout_ptr = dout_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ rstd_ptr = rstd.data_ptr()
+ if rstd_ptr != self._last_rstd_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_rstd, rstd_ptr)
+ self._last_rstd_ptr = rstd_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ dx_ptr = dx.data_ptr()
+ if dx_ptr != self._last_dx_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_dx, dx_ptr)
+ self._last_dx_ptr = dx_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ if self._ptr_dw_partial is not None:
+ dw_ptr = dw_partial.data_ptr() # type: ignore[union-attr]
+ if dw_ptr != self._last_dw_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_dw_partial, dw_ptr)
+ self._last_dw_ptr = dw_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld,
+ sm_count=sm_count,
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+ if sm_count != self._last_sm_count:
+ self._arg_sm_count.set(sm_count)
+ self._last_sm_count = sm_count
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ self._use_fast_launch = False
+ _FAST_LAUNCH_SUPPORTED = False
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dx: Tensor,
+ dw_partial: Optional[Tensor],
+ M: int,
+ N: int,
+ ld: int,
+ sm_count: int,
+ ) -> None:
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_x,
+ )
+ ptr_dout = rt.make_ptr(
+ dtype,
+ dout.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_x,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype,
+ dx.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_x,
+ )
+ ptr_rstd = rt.make_ptr(
+ TORCH2CUTE_DTYPE[rstd.dtype],
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_x,
+ )
+ ptr_w = (
+ rt.make_ptr(
+ self._weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_w,
+ )
+ if weight is not None
+ else None
+ )
+ ptr_dw_partial = (
+ rt.make_ptr(
+ TORCH2CUTE_DTYPE[dw_partial.dtype],
+ dw_partial.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align_dw,
+ )
+ if dw_partial is not None
+ else None
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ ptr_dout,
+ ptr_rstd,
+ ptr_dx,
+ ptr_dw_partial,
+ Int32(M),
+ Int32(N),
+ Int32(ld),
+ Int32(sm_count),
+ self._stream,
+ )
+
+
+def _get_fast_ptr_rmsnorm_bwd_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ weight_dtype: Optional[type[cutlass.Numeric]],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ has_weight: bool,
+ has_dw_partial: bool,
+ assumed_align_x: int,
+ assumed_align_w: int,
+ assumed_align_dw: int,
+) -> Optional[_PtrRmsnormBwdFastLaunch]:
+ if not _fast_launch_enabled():
+ return None
+ key = (
+ "ptr_bwd_fast",
+ id(compiled),
+ N,
+ dtype,
+ weight_dtype,
+ device_index,
+ int(stream_handle),
+ has_weight,
+ has_dw_partial,
+ int(assumed_align_x),
+ int(assumed_align_w),
+ int(assumed_align_dw),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ assumed_align_x = int(assumed_align_x)
+ assumed_align_w = int(assumed_align_w)
+ assumed_align_dw = int(assumed_align_dw)
+
+ ptr_x = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ 0,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_w,
+ )
+ if has_weight
+ else None
+ )
+ ptr_dout = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x
+ )
+ ptr_rstd = rt.make_ptr(
+ cutlass.Float32,
+ 0,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x
+ )
+ ptr_dw_partial = (
+ rt.make_ptr(
+ cutlass.Float32,
+ 0,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_dw,
+ )
+ if has_dw_partial
+ else None
+ )
+
+ arg_m = _StableI32Arg(0)
+ arg_n = _StableI32Arg(N)
+ arg_ld = _StableI32Arg(N)
+ arg_sm_count = _StableI32Arg(0)
+ stream = cuda.CUstream(int(stream_handle))
+
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ ptr_dout,
+ ptr_rstd,
+ ptr_dx,
+ ptr_dw_partial,
+ arg_m,
+ arg_n,
+ arg_ld,
+ arg_sm_count,
+ stream,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_dout,
+ ptr_rstd,
+ ptr_dx,
+ ptr_dw_partial,
+ arg_m,
+ arg_n,
+ arg_ld,
+ arg_sm_count,
+ stream,
+ *adapted_args,
+ )
+
+ launcher = _PtrRmsnormBwdFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_dout=ptr_dout,
+ ptr_rstd=ptr_rstd,
+ ptr_dx=ptr_dx,
+ ptr_dw_partial=ptr_dw_partial,
+ arg_m=arg_m,
+ arg_n=arg_n,
+ arg_ld=arg_ld,
+ arg_sm_count=arg_sm_count,
+ stream=stream,
+ assumed_align_x=assumed_align_x,
+ assumed_align_w=assumed_align_w,
+ assumed_align_dw=assumed_align_dw,
+ weight_dtype=weight_dtype if has_weight else None,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+def _get_fast_ptr_rmsnorm_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ weight_dtype: Optional[type[cutlass.Numeric]] = None,
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ has_weight: bool,
+ assumed_align: int = 16,
+ eps: float,
+) -> Optional[_PtrRmsnormFastLaunch]:
+ if not _fast_launch_enabled():
+ return None
+ # Keyed by the compiled object identity so schedule changes (e.g. copy width,
+ # async/staged variants, etc.) never alias in the fast-launch cache.
+ key = (
+ "ptr_fast",
+ id(compiled),
+ N,
+ dtype,
+ weight_dtype,
+ device_index,
+ int(stream_handle),
+ has_weight,
+ int(assumed_align),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ # Create stable runtime args and pointer descriptors once.
+ assumed_align = int(assumed_align)
+ ptr_x = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_out = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ 0,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if has_weight
+ else None
+ )
+
+ arg_m = _StableI32Arg(0)
+ arg_n = _StableI32Arg(N)
+ arg_ld = _StableI32Arg(N)
+ arg_eps = _StableF32Arg(eps)
+
+ stream = cuda.CUstream(int(stream_handle))
+
+ # Create an executor (loads the CUDA library once).
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ # Use generate_execution_args once to build the packed args array, and keep
+ # any adapted args alive for the lifetime of the cache entry.
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ arg_m,
+ arg_n,
+ arg_ld,
+ stream,
+ arg_eps,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_out,
+ arg_m,
+ arg_n,
+ arg_ld,
+ arg_eps,
+ stream,
+ *adapted_args,
+ )
+
+ launcher = _PtrRmsnormFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_out=ptr_out,
+ arg_m=arg_m,
+ arg_n=arg_n,
+ arg_ld=arg_ld,
+ arg_eps=arg_eps,
+ stream=stream,
+ assumed_align=assumed_align,
+ weight_dtype=weight_dtype if has_weight else None,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+def _get_fast_ptr_fused_add_rmsnorm_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ copy_bits: int,
+ use_async: bool,
+ tpr: int,
+ direct_gmem: bool,
+ assumed_align: int,
+ eps: float,
+) -> Optional[_PtrFusedAddRmsnormFastLaunch]:
+ if not _fast_launch_enabled():
+ return None
+ key = (
+ "ptr_fused_add_fast",
+ id(compiled),
+ N,
+ dtype,
+ device_index,
+ int(stream_handle),
+ int(copy_bits),
+ bool(use_async),
+ int(tpr),
+ bool(direct_gmem),
+ int(assumed_align),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ ptr_x = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_res = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_w = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+
+ arg_m = _StableI32Arg(0)
+ arg_n = _StableI32Arg(N)
+ arg_ld_x = _StableI32Arg(N)
+ arg_eps = _StableF32Arg(eps)
+
+ stream = cuda.CUstream(int(stream_handle))
+
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ arg_m,
+ arg_n,
+ arg_ld_x,
+ stream,
+ arg_eps,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ arg_m,
+ arg_n,
+ arg_ld_x,
+ arg_eps,
+ stream,
+ *adapted_args,
+ )
+
+ launcher = _PtrFusedAddRmsnormFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_res=ptr_res,
+ arg_m=arg_m,
+ arg_n=arg_n,
+ arg_ld_x=arg_ld_x,
+ arg_eps=arg_eps,
+ stream=stream,
+ assumed_align=assumed_align,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+# Local helpers for reduction, dtype mapping, and coordinate/predicate utilities.
+#
+# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may
+# mishandle that form (module=None in the AST). Use fully-qualified imports.
+from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402
+from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402
+ TORCH2CUTE_DTYPE,
+ RMSNormBackward as BaseRMSNormBackward,
+ convert_from_dlpack as convert_from_dlpack_cute,
+ get_sm_count,
+ row_reduce,
+)
+
+
+# -------------------------
+# Copy helpers (allow up to 256b)
+# -------------------------
+
+
+@cute.jit
+def get_copy_atom_bw(
+ dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False
+) -> cute.CopyAtom:
+ # cp.async (SIMT) supports up to 128b per op; use 256b for sync when possible
+ max_bits = const_expr(128 if is_async else 256)
+ num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
+ from cutlass.cute.nvgpu import cpasync
+
+ # Prefer GLOBAL cache policy for bulk streaming reads at large M.
+ copy_op = (
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL)
+ if is_async
+ else cute.nvgpu.CopyUniversalOp()
+ )
+ return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
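+
+
+# Example (illustrative): bf16 with num_copy_elems=16 resolves to
+# min(256, 16 * 16) = 256-bit universal copies when is_async=False, but is
+# capped at 128 bits when is_async=True because cp.async supports at most
+# 128 bits per copy operation.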
+
+
+@cute.jit
+def copy_tiled(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+) -> None:
+ atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async)
+ cute.copy(atom, src, dst, pred=pred)
+
+
+# -------------------------
+# RMSNorm Kernel (SM100)
+# -------------------------
+
+
+class RMSNormSM100:
+ def __init__(
+ self,
+ N: int,
+ dtype: type[cutlass.Numeric],
+ stage: Optional[int] = None,
+ *,
+ copy_bits: int = 128,
+ use_async: bool = True,
+ direct_gmem: bool = False,
+ ):
+ self.N = N
+ self.dtype = dtype
+ # Match Quack default for RMSNorm: stage = 1 unless explicitly overridden
+ self.stage = 1 if stage is None else stage
+ self.reduction_dtype = cutlass.Float32
+ self.copy_bits = int(copy_bits)
+ self.use_async = bool(use_async)
+ self.direct_gmem = bool(direct_gmem)
+
+ def _threads_per_row(self) -> int:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ pass
+ # Tune mid-size buckets for large-M rows.
+ N = self.N
+ # DSv3 MLA (padded/strided) hot shape. Prefer a threads-per-row that
+ # makes the tile width exactly match N with 128b vectors (bf16/fp16),
+ # avoiding the ~33% padded work from rounding 1536 -> 2048.
+ if N == 1536 and self.dtype.width == 16:
+ return 96
+ # DSv3 default hidden size (7168). Choose a threads-per-row that matches
+ # the selected vector width to avoid padded work. Using 224 threads/row
+ # yields exact tiles for all supported copy widths we use on SM100:
+ # - 64b copies (vec=4 for bf16/fp16): 7168/4 = 1792 = 8 * 224
+ # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 4 * 224
+ # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224
+ #
+ if N == 7168 and self.dtype.width == 16:
+ return 224
+ # DSv3-ish N buckets (6144/8192): use larger threads/row so each thread
+ # holds fewer elements in registers. For 256b vectors, pick a threads/row
+ # that yields an exact tile without padding.
+ if self.dtype.width == 16:
+ if N == 6144:
+ if self.copy_bits >= 256:
+ return 192
+ if self.copy_bits <= 128:
+ return 256
+ if N == 8192:
+ return 256
+ # For small-N, use at least one full warp per row. The kernel
+ # implementation assumes one row per CTA; returning <32 here can
+ # produce multi-row tiles (cols_per_block > 1) which is not supported.
+ if N <= 1024:
+ return 32
+ elif N <= 4096:
+ return 128
+ elif N <= 8192:
+ # Allow an override (used by 2-rows/CTA path for N≈6k/8k)
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ return 128
+ elif N <= 16384:
+ return 256
+ else:
+ return 256
+
+ def _cluster_n(self) -> int:
+ try:
+ return self._cluster_n_override # type: ignore[attr-defined]
+ except Exception:
+ pass
+ N = self.N
+ # Default policy
+ if N <= 8192:
+ return 1
+ if const_expr(self.dtype.width == 16):
+            if N <= 32 * 1024:
+                return 2
+ elif N <= 64 * 1024:
+ return 4
+ elif N <= 128 * 1024:
+ return 8
+ else:
+ return 16
+ else:
+ if N <= 32 * 1024:
+ return 1
+ elif N <= 64 * 1024:
+ return 2
+ elif N <= 128 * 1024:
+ return 4
+ elif N <= 256 * 1024:
+ return 8
+ else:
+ return 16
+
+ def _num_threads(self) -> int:
+ # Favor 128 threads up to N=16k to reduce per-row partitioning overhead.
+ # This keeps cols_per_block=1 at N=8192 (bf16), which benchmarks faster for large-M.
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ if self.N == 1536 and self.dtype.width == 16:
+ return 96
+ if self.N == 7168 and self.dtype.width == 16:
+ return 224
+ if self.dtype.width == 16:
+ if self.N == 6144:
+ if self.copy_bits >= 256:
+ return 192
+ if self.copy_bits <= 128:
+ return 256
+ if self.N == 8192:
+ return 256
+ if self.N <= 1024:
+ return 32
+ return 128 if self.N <= 16384 else 256
+
+ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ num_threads = self._num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+ tpr = self._threads_per_row()
+ cluster_n = self._cluster_n()
+ # Allow tails: compute number of vector columns with ceil
+ num_cols_vec = cute.ceil_div(self.N, vecsize)
+ num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n)
+ cols_per_block = num_threads // tpr
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
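+        # Worked example (illustrative): N=7168, bf16, 128b copies ->
+        #   vecsize=8, tpr=224, cluster_n=1, num_cols_vec=896, num_blocks_N=4,
+        #   cols_per_block=1, tiler_mn=(1, 7168): one row per CTA, no padded columns.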
+ tv_layout = cute.make_layout(
+ ((tpr, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * tpr),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ def _smem_bytes(self, tiler_mn, num_warps) -> int:
+ # smem for X tile (+ residual if present) + reduction buffers + mbar(s)
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+ + self.stage
+ * num_warps
+ * self._cluster_n()
+ * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8)
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ # Make last dim static (N)
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mRes, mO, mResO = [
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ copy_bits = int(self.copy_bits)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = (
+ tv_layout.shape[0][0]
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._threads_per_row()
+ )
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+
+ if const_expr(mW is not None):
+ mW = cute.make_tensor(
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+
+ # No SMEM reload mode switch; overlap is controlled in the K-loop path
+
+ # Compute smem usage considering staged buffers.
+ #
+ # In direct-gmem mode, we skip the gmem->smem tiles entirely and only
+ # keep the reduction buffers in shared memory.
+ stage_bufs = 2 if self.stage > 1 else 1
+ tile_bytes_x = (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ if const_expr(not self.direct_gmem)
+ else 0
+ )
+ tile_bytes_res = (
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn))
+ * stage_bufs
+ if const_expr(mRes is not None and not self.direct_gmem)
+ else 0
+ )
+ red_bytes = (
+ self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ )
+ # mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds
+ # require mbarrier state to be 16B-aligned in shared memory; account for
+ # the alignment padding when computing dynamic smem bytes.
+ smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes
+ if cluster_n > 1:
+ # Align up to 16B before placing the mbarrier array.
+ smem_bytes = ((smem_bytes + 15) // 16) * 16
+ smem_bytes += self.stage * (cutlass.Int64.width // 8)
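+        # e.g. N=7168 bf16 with residual, tiler_mn=(1, 7168), stage=1, cluster_n=1:
+        #   tile_bytes_x = tile_bytes_res = 7168 * 2 = 14336 B and red_bytes = 7 * 4
+        #   = 28 B, i.e. roughly 28 KiB of dynamic shared memory per CTA.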
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(cluster_n),
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=([1, cluster_n, 1] if cluster_n > 1 else None),
+ smem=smem_bytes,
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: Optional[cute.Pointer],
+ ptr_b: Optional[cute.Pointer],
+ ptr_res: Optional[cute.Pointer],
+ ptr_out: cute.Pointer,
+ ptr_res_out: Optional[cute.Pointer],
+ ptr_rstd: Optional[cute.Pointer],
+ M: Int32,
+ N_dyn: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ """Pointer-based entrypoint to reuse the existing RMSNorm schedule.
+
+ This reconstructs cute.Tensor views from raw pointers plus sizes,
+ avoiding any DLPack conversions at the Python boundary.
+ """
+ # Use a dynamic N for the leading-dimension stride so that the
+ # subsequent cute.assume(...) in __call__ sees a dynamic expression
+ # rather than a plain Python int.
+ # The compile-time N for the kernel (self.N) is still used to
+ # specialize the schedule.
+ # Assume row-major [M, N] with an arbitrary leading-dimension stride
+ # (common for padded-row / packed-attention layouts).
+ layout_mn = cute.make_layout((M, N_dyn), stride=(ld, 1))
+ layout_n = cute.make_layout((N_dyn,), stride=(1,))
+ layout_m = cute.make_layout((M,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+
+ mRes = (
+ cute.make_tensor(ptr_res, layout_mn)
+ if const_expr(ptr_res is not None)
+ else None
+ )
+ mResO = (
+ cute.make_tensor(ptr_res_out, layout_mn)
+ if const_expr(ptr_res_out is not None)
+ else None
+ )
+ mW = (
+ cute.make_tensor(ptr_w, layout_n) if const_expr(ptr_w is not None) else None
+ )
+ mB = (
+ cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
+ )
+ mRstd = (
+ cute.make_tensor(ptr_rstd, layout_m)
+ if const_expr(ptr_rstd is not None)
+ else None
+ )
+
+ # Reuse the main JIT entry to launch the scheduled kernel.
+ self.__call__(mX, mW, mB, mRes, mO, mResO, mRstd, stream, eps)
+
+ @cute.jit
+ def launch_from_ptrs_fused_add_inplace(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: cute.Pointer,
+ ptr_res: cute.Pointer,
+ M: Int32,
+ N_dyn: Int32,
+ ld_x: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ """Pointer-based entrypoint for vLLM-style fused_add_rms_norm semantics.
+
+ This specialized entrypoint supports:
+ - `x` / output with an arbitrary leading-dimension stride (`ld_x`), and
+ - `residual` / residual-out as a contiguous [M, N] tensor (ld_res = N).
+
+ Both `x` and `residual` are updated in-place:
+ residual <- x + residual
+ x <- RMSNorm(residual) * weight
+ """
+ layout_x = cute.make_layout((M, N_dyn), stride=(ld_x, 1))
+ layout_res = cute.make_layout((M, N_dyn), stride=(N_dyn, 1))
+ layout_n = cute.make_layout((N_dyn,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_x)
+ mO = cute.make_tensor(ptr_x, layout_x)
+ mRes = cute.make_tensor(ptr_res, layout_res)
+ mResO = cute.make_tensor(ptr_res, layout_res)
+ mW = cute.make_tensor(ptr_w, layout_n)
+
+ self.__call__(
+ mX,
+ mW,
+ None, # bias
+ mRes,
+ mO,
+ mResO,
+ None, # rstd
+ stream,
+ eps,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ cluster_n: cutlass.Constexpr[int],
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(cluster_n > 1):
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ else:
+ cta_rank_in_cluster = const_expr(0)
+ n_off = cta_rank_in_cluster * tiler_mn[1]
+
+ smem = cutlass.utils.SmemAllocator()
+ # Allocate one or two SMEM buffers depending on stage depth
+ sX0 = (
+ smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(not self.direct_gmem)
+ else None
+ )
+ sX1 = (
+ smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(self.stage > 1 and not self.direct_gmem)
+ else None
+ )
+ sRes0 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None and not self.direct_gmem)
+ else None
+ )
+ sRes1 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None and self.stage > 1 and not self.direct_gmem)
+ else None
+ )
+
+ # Reduction buffers + mbar for cluster reduce (reused by row_reduce helper)
+ red_layout = cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype, red_layout, byte_alignment=4
+ )
+ if const_expr(cluster_n > 1):
+ # Some CuTeDSL builds appear sensitive to the shared-memory alignment of
+ # mbarrier state. `SmemAllocator.allocate_array` does not currently
+ # expose an alignment parameter, so allocate an Int64 tensor with an
+ # explicit alignment and pass its iterator as the pointer.
+ mbar_tensor = smem.allocate_tensor(
+ cutlass.Int64,
+ cute.make_layout((self.stage,), stride=(1,)),
+ byte_alignment=16,
+ )
+ mbar_ptr = mbar_tensor.iterator
+ else:
+ mbar_ptr = None
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+ limit_k = shape[1] - n_off
+
+ # Tiled copy setup
+ num_copy_elems_X = tv_layout.shape[1][0]
+ use_async = const_expr(
+ self.use_async and self.N >= 1024 and not self.direct_gmem
+ )
+ copy_atom = get_copy_atom_bw(
+ mX.element_type, num_copy_elems_X, is_async=use_async
+ )
+ thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
+
+ # Tail predicate for the N dimension (when tile width > N). Reuse this
+ # for W/B loads so we never read past the end of those 1D tensors.
+ is_even_N_wb = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ if const_expr(not is_even_N_wb):
+ cX0 = cute.local_tile(idX, tiler_mn, (0, 0))
+ tXp_wb = qutils.predicate_k(thr_copy.partition_S(cX0), limit=limit_k)
+ else:
+ tXp_wb = None
+
+ # Weight/bias loads:
+ #
+ # - Direct-GMEM schedule: load weight/bias up front to hide latency.
+ # - Staged SMEM schedule: loading after the reduction reduces register
+ # pressure during the long-scoreboard reduction phase (better for large-M),
+ # but it measurably hurts small-M latency for the non-fused (no residual,
+ # no bias) case. For that specific case, prefetch weight up front as well.
+ tXrW = None
+ tXrB = None
+ prefetch_w_early = bool(
+ mW is not None and (self.direct_gmem or (mRes is None and mB is None))
+ )
+ if const_expr(prefetch_w_early):
+ gW = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)
+ )
+ tXgW = thr_copy.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if const_expr(not is_even_N_wb):
+ tXrW.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ pred=tXp_wb,
+ )
+ if const_expr(self.direct_gmem and mB is not None):
+ gB = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)
+ )
+ tXgB = thr_copy.partition_S(gB)
+ tXrB = cute.make_fragment_like(tXgB)
+ if const_expr(not is_even_N_wb):
+ tXrB.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ pred=tXp_wb,
+ )
+
+ # Non-persistent per-CTA execution (one tile in M)
+ self._init_cluster(tidx, mbar_ptr)
+
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t)
+ if t is not None
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((0, n_off), t) if t is not None else None
+ for t in (mX_i, mRes_i, mO_i, mResO_i)
+ ]
+ gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0))
+ gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0))
+ gRes_i = (
+ cute.local_tile(mRes_i, tiler_mn, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+ gResO_i = (
+ cute.local_tile(mResO_i, tiler_mn, (0, 0))
+ if const_expr(mResO is not None)
+ else None
+ )
+ gRstd_i = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, 0))
+ if const_expr(mRstd is not None)
+ else None
+ )
+ cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0))
+
+ # Common identity/row index partitions reused by both default and K-loop paths
+ tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
+ row_i = tXcX_i[0][0]
+ tXgRstd_i = (
+ thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ )
+
+ # Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces
+ # per-thread fragment size and can improve memory-latency hiding for
+ # N=7168 at large M. It is enabled by setting `stage=2` when constructing
+ # the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`).
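+        # With the default 128b copies for bf16 this resolves to vecsize=8,
+        # tpr=224, tile_factor=2, tile_n=3584, so N=7168 is covered by exactly
+        # two ping-pong tiles with no tail.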
+ if const_expr(
+ self.stage > 1
+ and not self.direct_gmem
+ and use_async
+ and cluster_n == 1
+ and shape[1] == 7168
+ ):
+ vecsize = tv_layout.shape[1][0]
+ tpr = threads_per_row
+ target_tile_n = const_expr(4096)
+ tile_factor = const_expr(target_tile_n // (vecsize * tpr))
+ if const_expr(tile_factor > 0):
+ tile_n = vecsize * tpr * tile_factor
+ num_tiles = cute.ceil_div(shape[1], tile_n)
+
+ tiler_mn_tile = (tiler_mn[0], tile_n)
+ sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0))
+ sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0))
+ sRes0_tile = (
+ cute.local_tile(sRes0, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+ sRes1_tile = (
+ cute.local_tile(sRes1, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+
+ tv_layout_tile = cute.make_layout(
+ ((tpr, tiler_mn[0]), (vecsize, tile_factor)),
+ stride=(
+ (vecsize * tiler_mn[0], 1),
+ (tiler_mn[0], tiler_mn[0] * vecsize * tpr),
+ ),
+ )
+ thr_copy_tile = cute.make_tiled_copy(
+ copy_atom, tv_layout_tile, tiler_mn_tile
+ ).get_slice(tidx)
+
+ # Accumulate per-thread partial sums across tiles; reduce once.
+ sum_sq_thread = cute.Float32(0.0)
+
+ # Preload tile 0 into sX0/sRes0.
+ k_off0 = const_expr(0) * tile_n
+ gX_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, 0)
+ )
+ tXgX_0 = thr_copy_tile.partition_S(gX_0)
+ tXsX_0 = thr_copy_tile.partition_D(sX0_tile)
+ cX_0 = cute.local_tile(
+ cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, 0)
+ )
+ tXc_0 = thr_copy_tile.partition_S(cX_0)
+ tXp_0 = qutils.predicate_k(tXc_0, limit=limit_k)
+
+ tXp_ping = tXp_0
+ tXp_pong = tXp_0
+
+ if row_i < shape[0]:
+ copy_tiled(
+ tXgX_0,
+ tXsX_0,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_0,
+ )
+ if const_expr(mRes is not None):
+ gRes_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_0 = thr_copy_tile.partition_S(gRes_0)
+ tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile)
+ copy_tiled(
+ tXgRes_0,
+ tXsRes_0,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_0,
+ )
+ cute.arch.cp_async_commit_group()
+
+ for t in cutlass.range_constexpr(num_tiles):
+ next_t = t + 1
+ if next_t < num_tiles:
+ k_off_n = next_t * tile_n
+ gX_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgX_n = thr_copy_tile.partition_S(gX_n)
+ cX_n = cute.local_tile(
+ cute.domain_offset((0, k_off_n), cX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXc_n = thr_copy_tile.partition_S(cX_n)
+ tXp_n = qutils.predicate_k(tXc_n, limit=limit_k)
+
+ if const_expr((t % 2) == 0):
+ tXsX_n = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ tXp_pong = tXp_n
+ else:
+ tXsX_n = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ tXp_ping = tXp_n
+
+ if row_i < shape[0]:
+ copy_tiled(
+ tXgX_n,
+ tXsX_n,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_n,
+ )
+ if const_expr(mRes is not None):
+ gRes_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_n = thr_copy_tile.partition_S(gRes_n)
+ copy_tiled(
+ tXgRes_n,
+ tXsRes_n,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_n,
+ )
+ cute.arch.cp_async_commit_group()
+
+ cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0)
+
+ # Current tile buffer (ping/pong).
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ pred_cur = tXp_ping
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ pred_cur = tXp_pong
+
+ k_off = t * tile_n
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX_t = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX_t)
+ x_t = tXrX_t.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes_t = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes_t)
+ x_t += tXrRes_t.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ gResO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mResO_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgResO_t = thr_copy_tile.partition_D(gResO_t)
+ tXrResO_t = cute.make_fragment_like(tXgResO_t)
+ tXrResO_t.store(x_t.to(tXrResO_t.element_type))
+ if row_i < shape[0]:
+ copy_tiled(
+ tXrResO_t,
+ tXgResO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=pred_cur,
+ )
+
+ sum_sq_thread = sum_sq_thread + (x_t * x_t).reduce(
+ cute.ReductionOp.ADD,
+ init_val=0.0,
+ reduction_profile=0,
+ )
+
+ sum_sq = row_reduce(
+ sum_sq_thread,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if tXcX_i[0][1] == 0 and row_i < shape[0]:
+ tXgRstd_i[0] = rstd
+
+ for t in cutlass.range_constexpr(num_tiles):
+ k_off = t * tile_n
+ cX_t = cute.local_tile(
+ cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0)
+ )
+ tXc_t = thr_copy_tile.partition_S(cX_t)
+ tXp_t = qutils.predicate_k(tXc_t, limit=limit_k)
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX_t = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX_t)
+ x_t = tXrX_t.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes_t = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes_t)
+ x_t += tXrRes_t.load().to(cute.Float32)
+
+ y_t = x_t * rstd
+ if const_expr(mW is not None):
+ gW_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mW),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tWgW_t = thr_copy_tile.partition_S(gW_t)
+ tWrW_t = cute.make_fragment_like(tWgW_t)
+ copy_tiled(
+ tWgW_t,
+ tWrW_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
+ y_t = y_t * tWrW_t.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ gB_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mB),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tWgB_t = thr_copy_tile.partition_S(gB_t)
+ tWrB_t = cute.make_fragment_like(tWgB_t)
+ copy_tiled(
+ tWgB_t,
+ tWrB_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
+ y_t = y_t + tWrB_t.load().to(cute.Float32)
+
+ gO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mO_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgO_t = thr_copy_tile.partition_D(gO_t)
+ tXrO_t = cute.make_fragment_like(tXgO_t)
+ tXrO_t.store(y_t.to(tXrO_t.element_type))
+ if row_i < shape[0]:
+ copy_tiled(
+ tXrO_t,
+ tXgO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
+
+ return
+
+ # Single-stage path: one-row-per-CTA
+ tXgX_i = thr_copy.partition_S(gX_i)
+ tXgRes_i = (
+ thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ )
+ tXgO_i = thr_copy.partition_D(gO_i)
+ tXgResO_i = (
+ thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ )
+ # tXgRstd_i / tXcX_i / row_i prepared above
+ is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ tXpX_i = (
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k)
+ if not is_even_N_i
+ else None
+ )
+
+ tXrX = cute.make_fragment_like(tXgX_i)
+ tXrRes = (
+ cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None
+ )
+ if const_expr(self.direct_gmem):
+ if const_expr(not is_even_N_i):
+ tXrX.fill(0)
+ if const_expr(tXrRes is not None):
+ tXrRes.fill(0)
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, tXrX, pred=tXpX_i)
+ if const_expr(tXrRes is not None):
+ cute.copy(copy_atom, tXgRes_i, tXrRes, pred=tXpX_i)
+ else:
+ # If N is not a multiple of the tile width, the predicated gmem->smem
+ # copies leave out-of-bounds lanes uninitialized. Clear the SMEM tile
+ # so masked lanes read as 0 for reduction/output.
+ if const_expr(not is_even_N_i):
+ thr_copy.partition_D(sX0).fill(0)
+ if const_expr(mRes is not None):
+ thr_copy.partition_D(sRes0).fill(0)
+
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i)
+ if const_expr(mRes is not None):
+ cute.copy(
+ copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i
+ )
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ cute.autovec_copy(thr_copy.partition_D(sX0), tXrX)
+ if const_expr(tXrRes is not None):
+ cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes)
+ x_red = tXrX.load().to(cute.Float32)
+ if const_expr(tXrRes is not None):
+ x_red += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ tXrResO = cute.make_fragment_like(tXgResO_i)
+ tXrResO.store(x_red.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(
+ tXrResO.element_type, num_copy_elems_X, is_async=False
+ ),
+ tXrResO,
+ tXgResO_i,
+ pred=tXpX_i,
+ )
+
+ sum_sq = row_reduce(
+ x_red * x_red,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ if const_expr(not self.direct_gmem and (mRes is not None or mB is not None)):
+ # Load weight/bias after the reduction so they don't inflate register
+ # pressure during the long-scoreboard reduction phase (helping occupancy
+ # when registers are the limiting factor).
+ if const_expr(mW is not None):
+ gW = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)
+ )
+ tXgW = thr_copy.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if const_expr(not is_even_N_wb):
+ tXrW.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ pred=tXp_wb,
+ )
+ if const_expr(mB is not None):
+ gB = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)
+ )
+ tXgB = thr_copy.partition_S(gB)
+ tXrB = cute.make_fragment_like(tXgB)
+ if const_expr(not is_even_N_wb):
+ tXrB.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ pred=tXp_wb,
+ )
+
+ # Reuse `x_red` (x + residual, in fp32) for the output path so we don't
+ # keep both `tXrX` and `tXrRes` fragments live across the reduction.
+ y = x_red * rstd
+ if const_expr(mW is not None):
+ y = y * tXrW.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ y = y + tXrB.load().to(cute.Float32)
+
+ tXrO = cute.make_fragment_like(tXgO_i)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False),
+ tXrO,
+ tXgO_i,
+ pred=tXpX_i,
+ )
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ cluster_n: cutlass.Constexpr[int],
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ cluster_n,
+ num_warps,
+ warps_per_row,
+ threads_per_row,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ copy_bits = int(self.copy_bits)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = self._num_threads()
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = self._threads_per_row()
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(cluster_n),
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+
+ @cute.jit
+ def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]):
+ if const_expr(mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
+def _can_use_ptr_path(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+) -> bool:
+ """Fast path precondition for the pointer-based CuTeDSL entry.
+
+ We require a row-major 2D layout where the last dimension is
+ contiguous (stride(1) == 1). The leading dimension (stride(0))
+ may be larger than N (padded-row / packed-attention layouts),
+ and is passed to the kernel as `ld`.
+ """
+ if x.stride(1) != 1:
+ return False
+ # All participating tensors are interpreted as the same element type
+ # (derived from x.dtype) in the pointer-based path. If dtypes differ,
+ # we'd read the wrong bit patterns and silently produce incorrect output.
+ if residual is not None and residual.dtype != x.dtype:
+ return False
+ if weight is not None and weight.dtype != x.dtype:
+ # Allow the common "Quack-style" API where weights are fp32 even when
+ # activations are bf16/fp16. The pointer path constructs a weight tensor
+ # view with the correct element type (fp32) inside the compiled graph.
+ if weight.dtype is not torch.float32:
+ return False
+ if x.dtype not in (torch.float16, torch.bfloat16):
+ return False
+ if bias is not None and bias.dtype != x.dtype:
+ return False
+ # The kernel assumes `ld` satisfies a divisibility constraint used by
+ # cute.assume(..., divby=...) for vectorization.
+ elem_bits = TORCH2CUTE_DTYPE[x.dtype].width
+ divby = 256 // elem_bits
+ if (x.stride(0) % divby) != 0:
+ return False
+ # The kernel uses 128-bit vectorized copies (16B). Require at least 16B
+ # alignment on all participating tensors to avoid misaligned global loads.
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if residual is not None and residual.stride(1) != 1:
+ return False
+ if residual is not None and residual.stride(0) != x.stride(0):
+ return False
+ if residual is not None and (residual.data_ptr() % 16) != 0:
+ return False
+ if weight is not None and not weight.is_contiguous():
+ return False
+ if bias is not None and not bias.is_contiguous():
+ return False
+ if weight is not None:
+ # For fp32 weights we use 256b universal copies (32B) by default.
+ # Require 32B alignment so the compiler can safely vectorize loads.
+ if weight.dtype is torch.float32:
+ if (weight.data_ptr() % 32) != 0:
+ return False
+ else:
+ if (weight.data_ptr() % 16) != 0:
+ return False
+ if bias is not None and (bias.data_ptr() % 16) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_fused_add_inplace(
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+) -> bool:
+ """Fast-path precondition for fused_add_rmsnorm_forward_inplace.
+
+ We allow the common vLLM layout where:
+ - `x` is strided/padded row-major (stride(1) == 1, stride(0) >= N)
+ - `residual` is contiguous row-major (stride(0) == N)
+ """
+ if x.stride(1) != 1:
+ return False
+ if residual.dtype != x.dtype:
+ return False
+ if weight.dtype != x.dtype:
+ return False
+ if residual.stride(1) != 1:
+ return False
+ if not residual.is_contiguous():
+ return False
+ if not weight.is_contiguous():
+ return False
+
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 256 // dtype.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ if (residual.stride(0) % divby) != 0:
+ return False
+
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if (residual.data_ptr() % 16) != 0:
+ return False
+ if (weight.data_ptr() % 16) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_bwd(
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+) -> bool:
+ """Fast-path precondition for the pointer-based RMSNorm backward entry.
+
+ This path is only used for the common Quack-style signature:
+ - no bias gradient
+ - no residual / dresidual_out
+ - weight is either the same dtype as x, or fp32 for bf16/fp16 activations
+ """
+ if x.dim() != 2 or dout.dim() != 2:
+ return False
+ if rstd.dim() != 1:
+ return False
+ if x.shape != dout.shape:
+ return False
+ if rstd.numel() != x.shape[0]:
+ return False
+ # SM100 backward kernel assumes N is divisible by 8 (for 256b fp32 stores
+ # into dw_partial rows).
+ if (x.shape[1] % 8) != 0:
+ return False
+ if x.stride(1) != 1 or dout.stride(1) != 1:
+ return False
+ if dout.stride(0) != x.stride(0):
+ return False
+ if dout.dtype != x.dtype:
+ return False
+ if rstd.dtype != torch.float32 or not rstd.is_contiguous():
+ return False
+ if weight is None:
+ return False
+ if weight.dim() != 1 or weight.shape[0] != x.shape[1]:
+ return False
+ if not weight.is_contiguous():
+ return False
+ if weight.dtype != x.dtype:
+ if weight.dtype is not torch.float32:
+ return False
+ if x.dtype not in (torch.float16, torch.bfloat16):
+ return False
+
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 256 // dtype.width
+ if (x.stride(0) % divby) != 0:
+ return False
+
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if (dout.data_ptr() % 16) != 0:
+ return False
+ # Torch CUDA allocations are typically >=256B aligned, but keep the check
+ # explicit so we never assume tighter alignment than is true.
+ if (rstd.data_ptr() % 4) != 0:
+ return False
+ if (weight.data_ptr() % (32 if weight.dtype is torch.float32 else 16)) != 0:
+ return False
+ return True
+
+
+def _rmsnorm_forward_ptr(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+ eps: float,
+ store_rstd: bool,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """Pointer-based RMSNorm forward that bypasses DLPack entirely.
+
+ This path reconstructs cute.Tensor views from raw device pointers
+ and explicit layouts inside the JIT graph, avoiding any runtime
+ DLPack conversions while reusing the tuned RMSNormSM100 schedule.
+ """
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+
+ # Preserve the input's 2D stride so downstream users that rely on
+ # padded-row layouts (stride0 > N) continue to see the expected layout.
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ residual_out: Optional[Tensor] = None
+ rstd: Optional[Tensor] = None
+
+ if residual is not None:
+ residual_out = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ if store_rstd:
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32)
+
+ _rmsnorm_forward_ptr_into(
+ x=x,
+ weight=weight,
+ bias=bias,
+ residual=residual,
+ out=out,
+ residual_out=residual_out,
+ rstd=rstd,
+ eps=eps,
+ )
+ return out, rstd, residual_out
+
+
+def _rmsnorm_forward_ptr_into(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+ out: Tensor,
+ residual_out: Optional[Tensor],
+ rstd: Optional[Tensor],
+ eps: float,
+) -> None:
+ """Internal helper that launches the pointer-based kernel into preallocated outputs.
+
+ This enables integration into frameworks like vLLM that manage their
+ own buffers and prefer in-place or out-parameter semantics.
+ """
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+ device_index = x.get_device()
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ if bias is None and residual is None and residual_out is None and rstd is None:
+ # Fast-launch path: cache packed args and update pointers/scalars in-place to
+ # avoid Python-side argument marshalling overhead that dominates small-batch cases.
+ #
+ # If fast-launch is disabled (or CuTeDSL internals changed), we fall back
+ # to calling the compiled function directly.
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ has_weight = weight is not None
+
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if has_weight else None
+
+ # Schedule selection (pointer fast path).
+ #
+ # Goals:
+ # - Keep vLLM inference fast path (contiguous/padded row-major) fast.
+ # - Enable higher vector widths when all participating pointers are 32B-aligned.
+ # - Prefer direct-GMEM for SM100-friendly hidden sizes to reduce SMEM/barrier
+ # overhead, especially for small/medium-M cases.
+ direct_gmem = _direct_gmem_from_policy(
+ default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192})
+ )
+ use_async = not direct_gmem
+
+ can_use_256 = bool(
+ dtype.width == 16
+ and (x.data_ptr() % 32) == 0
+ and (out.data_ptr() % 32) == 0
+ and (not has_weight or (weight.data_ptr() % 32) == 0) # type: ignore[union-attr]
+ )
+ default_copy_bits = 256 if can_use_256 else 128
+ # Quack-style fp32-weight policy: cap the *widest* dtype to 128b, so when
+ # weights are fp32 we use 64b activation vectors (helps register pressure).
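+        # (e.g. bf16 activations + fp32 weights: 64b activation copies (vec=4)
+        # pair with 128b fp32 weight copies over the same 4-element vector.)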
+ if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32:
+ default_copy_bits = 64
+ copy_bits = _copy_bits_from_policy(
+ default=default_copy_bits, can_use_256=can_use_256
+ )
+ assumed_align = 32 if copy_bits >= 256 else 16
+
+ stage = 1
+ if (
+ _ENABLE_STAGE2
+ and dtype.width == 16
+ and N == 7168
+ and (not direct_gmem)
+ and M >= 4096
+ ):
+ stage = 2
+
+ compiled_key = (
+ "ptr",
+ N,
+ dtype,
+ weight_dtype,
+ False, # residual
+ has_weight,
+ False, # bias
+ False, # residual_out
+ False, # rstd
+ stage,
+ int(copy_bits),
+ bool(use_async),
+ bool(direct_gmem),
+ int(assumed_align),
+ device_index,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(compiled_key)
+ if compiled is None:
+ op = RMSNormSM100(
+ N,
+ dtype,
+ stage=stage,
+ copy_bits=int(copy_bits),
+ use_async=bool(use_async),
+ direct_gmem=bool(direct_gmem),
+ )
+ ld_val = int(x.stride(0))
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_out = rt.make_ptr(
+ dtype,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if has_weight
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(ld_val)
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[compiled_key] = compiled
+
+ launcher = _get_fast_ptr_rmsnorm_launcher(
+ compiled=compiled,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ N=N,
+ device_index=device_index,
+ stream_handle=stream_handle,
+ has_weight=has_weight,
+ assumed_align=assumed_align,
+ eps=eps,
+ )
+ ld_val = int(x.stride(0))
+ if launcher is not None:
+ launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps)
+ return
+
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_out = rt.make_ptr(
+ dtype,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if has_weight
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(ld_val)
+ compiled(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ return
+
+ # General path (supports bias/residual/rstd, but is slower to launch).
+ #
+ # Keep the same schedule-selection policy as the fast path so correctness-only
+ # features (bias/residual/rstd) don't accidentally fall off a performance cliff.
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if weight is not None else None
+ direct_gmem = _direct_gmem_from_policy(
+ default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192})
+ )
+ use_async = not direct_gmem
+ can_use_256 = bool(
+ dtype.width == 16
+ and (x.data_ptr() % 32) == 0
+ and (out.data_ptr() % 32) == 0
+ and (weight is None or (weight.data_ptr() % 32) == 0)
+ and (bias is None or (bias.data_ptr() % 32) == 0)
+ and (residual is None or (residual.data_ptr() % 32) == 0)
+ and (residual_out is None or (residual_out.data_ptr() % 32) == 0)
+ )
+ default_copy_bits = 256 if can_use_256 else 128
+ if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32:
+ default_copy_bits = 64
+ copy_bits = _copy_bits_from_policy(
+ default=default_copy_bits, can_use_256=can_use_256
+ )
+ assumed_align = 32 if copy_bits >= 256 else 16
+
+ stage = 1
+ if (
+ _ENABLE_STAGE2
+ and dtype.width == 16
+ and N == 7168
+ and (not direct_gmem)
+ and M >= 4096
+ ):
+ stage = 2
+
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ key = (
+ "ptr",
+ N,
+ dtype,
+ weight_dtype,
+ residual is not None,
+ weight is not None,
+ bias is not None,
+ residual_out is not None,
+ rstd is not None,
+ stage,
+ int(copy_bits),
+ bool(use_async),
+ bool(direct_gmem),
+ int(assumed_align),
+ device_index,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = RMSNormSM100(
+ N,
+ dtype,
+ stage=stage,
+ copy_bits=int(copy_bits),
+ use_async=bool(use_async),
+ direct_gmem=bool(direct_gmem),
+ )
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_out = rt.make_ptr(
+ dtype,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_res = (
+ rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if residual is not None
+ else None
+ )
+ ptr_res_out = (
+ rt.make_ptr(
+ dtype,
+ residual_out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if residual_out is not None
+ else None
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if weight is not None
+ else None
+ )
+ ptr_b = (
+ rt.make_ptr(
+ dtype,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_res,
+ ptr_out,
+ ptr_res_out,
+ ptr_rstd,
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_out = rt.make_ptr(
+ dtype,
+ out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_res = (
+ rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if residual is not None
+ else None
+ )
+ ptr_res_out = (
+ rt.make_ptr(
+ dtype,
+ residual_out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if residual_out is not None
+ else None
+ )
+ ptr_w = (
+ rt.make_ptr(
+ weight_dtype or dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if weight is not None
+ else None
+ )
+ ptr_b = (
+ rt.make_ptr(
+ dtype,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(int(x.stride(0)))
+ compiled(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_res,
+ ptr_out,
+ ptr_res_out,
+ ptr_rstd,
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+
+
+def _fused_add_rmsnorm_forward_ptr_inplace(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float,
+) -> None:
+ """Pointer-based fused_add_rmsnorm that updates `x` and `residual` in-place."""
+ assert x.is_cuda
+ assert x.dim() == 2
+ assert residual.is_cuda
+ assert residual.dim() == 2
+ assert x.shape == residual.shape
+
+ M, N = x.shape
+ device_index = x.get_device()
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ stage = 1
+
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+
+ # Latency-optimized schedule for small-M cases: avoid the gmem->smem
+ # staging path (large dynamic smem + extra barriers) and load directly
+ # from gmem into registers.
+ copy_bits = 128
+    # Use a direct-GMEM schedule (no staging SMEM tiles) for the DSv3 hidden size
+    # (7168, bf16/fp16) at small/medium M: it cuts barrier overhead and dynamic
+    # shared memory, and enables 256b vectorization when all pointers are
+    # 32B-aligned. At large M the staged cp.async path tends to win on sustained
+    # bandwidth, hence the M <= 16384 cutoff in the default below.
+    #
+    # This is a policy decision tuned for DSv3's N=7168. To benchmark other
+    # models/shapes, override it with:
+    #   - OINK_RMSNORM_DIRECT_GMEM=0 (force the staging/cp.async path)
+    #   - OINK_RMSNORM_DIRECT_GMEM=1 (force the direct-gmem path)
+ direct_gmem = _direct_gmem_from_policy(
+ default=bool(dtype.width == 16 and N == 7168 and M <= 16384)
+ )
+ use_async = not direct_gmem
+ tpr_override: Optional[int] = None
+ nt_override: Optional[int] = None
+ cluster_n_override: Optional[int] = None
+ direct_gmem_max_copy_bits: Optional[int] = None
+
+ # Experimental stage-2 cp.async path (2-tile ping-pong) for N=7168. This is
+ # primarily about improving memory-latency hiding / reducing long-scoreboard
+ # stalls for large-M workloads.
+ if _ENABLE_STAGE2 and dtype.width == 16 and N == 7168 and M >= 4096:
+ stage = 2
+ direct_gmem = False
+ use_async = True
+
+ # Experimental ILP variant (clusters): split each row across 2 CTAs.
+ #
+ # NOTE: This is currently opt-in because some CuTeDSL builds exhibit
+ # instability with cluster launches for this specific schedule. To reduce
+ # the chance of accidental crashes, we require an additional explicit
+ # opt-in via `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`.
+ if _ENABLE_CLUSTER_ILP and not _ENABLE_STAGE2:
+ if dtype.width == 16 and N == 7168 and M >= 4096:
+ cluster_n_override = 2
+ if direct_gmem:
+ # Cluster launches + direct-GMEM has exhibited reproducible compiler
+ # instability (segfaults) in some CuTeDSL builds, especially for the
+ # 256b vector path. Probe it out-of-process once so we can safely
+ # select a working copy width (or fall back to the staged SMEM path)
+ # instead of crashing the parent process.
+ max_bits = _probe_cluster_direct_gmem_max_copy_bits()
+ if max_bits == 0:
+ direct_gmem = False
+ use_async = True
+ else:
+ direct_gmem_max_copy_bits = max_bits
+
+ # Experimental per-row partitioning: use 256 threads/row for N=7168 to
+ # increase concurrency/ILP (accepts a small tail-predicate region).
+ if _ENABLE_TPR256 and cluster_n_override is None and not _ENABLE_STAGE2:
+ if dtype.width == 16 and N == 7168 and M >= 4096:
+ tpr_override = 256
+ nt_override = 256
+
+ can_use_256 = bool(
+ direct_gmem
+ and (direct_gmem_max_copy_bits is None or direct_gmem_max_copy_bits >= 256)
+ and dtype.width == 16
+ and (x.data_ptr() % 32) == 0
+ and (residual.data_ptr() % 32) == 0
+ and (weight.data_ptr() % 32) == 0
+ )
+ assumed_align = 32 if can_use_256 else 16
+ if can_use_256:
+ copy_bits = 256
+
+ copy_bits = _copy_bits_from_policy(default=copy_bits, can_use_256=can_use_256)
+ if copy_bits == 128:
+ assumed_align = 16
+ elif copy_bits == 256 and can_use_256:
+ assumed_align = 32
+ else:
+ copy_bits = 128
+ assumed_align = 16
+
+ key = (
+ "ptr_fused_add_inplace",
+ N,
+ dtype,
+ stage,
+ device_index,
+ copy_bits,
+ use_async,
+ tpr_override,
+ nt_override,
+ direct_gmem,
+ cluster_n_override,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = RMSNormSM100(
+ N,
+ dtype,
+ stage=stage,
+ copy_bits=copy_bits,
+ use_async=use_async,
+ direct_gmem=direct_gmem,
+ )
+ if tpr_override is not None:
+ op._tpr_override = tpr_override # type: ignore[attr-defined]
+ if nt_override is not None:
+ op._nt_override = nt_override # type: ignore[attr-defined]
+ if cluster_n_override is not None:
+ op._cluster_n_override = cluster_n_override # type: ignore[attr-defined]
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_res = rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_w = rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld_x = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs_fused_add_inplace,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(M),
+ Int32(N),
+ ld_x,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+ launcher = _get_fast_ptr_fused_add_rmsnorm_launcher(
+ compiled=compiled,
+ dtype=dtype,
+ N=N,
+ device_index=device_index,
+ stream_handle=stream_handle,
+ copy_bits=copy_bits,
+ use_async=use_async,
+ tpr=tpr_override or 0,
+ direct_gmem=direct_gmem,
+ assumed_align=assumed_align,
+ eps=eps,
+ )
+ if launcher is not None:
+ launcher.launch(
+ x=x,
+ weight=weight,
+ residual=residual,
+ M=M,
+ N=N,
+ ld_x=int(x.stride(0)),
+ eps=eps,
+ )
+ return
+
+ # Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back
+ # to calling the compiled function directly.
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_res = rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ ptr_w = rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld_x = Int32(int(x.stride(0)))
+ compiled(ptr_x, ptr_w, ptr_res, Int32(M), Int32(N), ld_x, stream, Float32(eps))
+
+
+# -------------------------
+# Public API (forward + verify)
+# -------------------------
+
+
+def rmsnorm_forward(
+ x: Tensor,
+ weight: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ store_rstd: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
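+    """RMSNorm forward over a 2D (M, N) tensor.
+
+    Returns ``(y, rstd, residual_out)``: ``rstd`` is populated only when
+    ``store_rstd=True``, and ``residual_out`` only when ``residual`` is given.
+    Minimal usage sketch (illustrative)::
+
+        y, _, _ = rmsnorm_forward(x, weight=w, eps=1e-6)
+    """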
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+
+ # Fast path: use the pointer-based entry whenever we can represent the
+ # inputs as a row-major [M, N] view with stride(1) == 1 and dtype contracts
+ # are satisfied (vLLM uses this in inference).
+ #
+    # When the pointer path can't be used (e.g. unsupported dtype combinations or
+    # non-standard layouts), fall back to the CuTeDSL stage-2 module (ported from
+    # `/tmp/oink_main/Blackwell`) before using the slow torch reference
+    # implementation.
+ force_stage2 = _FORCE_RMSNORM_STAGE2_FWD
+
+ use_ptr = (not force_stage2) and _can_use_ptr_path(x, weight, bias, residual)
+
+ if use_ptr:
+ return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd)
+
+ # CuTeDSL fallback for cases that aren't safe for the pointer path.
+ # Import lazily to keep vLLM plugin startup and common inference fast paths
+ # lightweight.
+ try:
+ import importlib
+
+ rms2 = importlib.import_module(
+ ".rmsnorm_with_stage2",
+ package=__package__ or "kernelagent_oink.blackwell",
+ )
+ except Exception:
+ rms2 = None # type: ignore[assignment]
+ if rms2 is not None:
+ y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2(
+ x,
+ weight=weight,
+ bias=bias,
+ residual=residual,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ # Preserve stride contracts for torch.compile consistency, even
+ # when using the optional stage-2 implementation.
+ if y.stride() != x.stride():
+ y_strided = torch.empty_strided(
+ x.shape, x.stride(), device=x.device, dtype=x.dtype
+ )
+ y_strided.copy_(y)
+ y = y_strided
+ if residual is not None and residual_out is not None:
+ if residual_out.stride() != residual.stride():
+ residual_out_strided = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ residual_out_strided.copy_(residual_out)
+ residual_out = residual_out_strided
+ return y, rstd, residual_out
+
+ # Safe fallback (correctness-first). This is expected to be rare in vLLM.
+ y = rmsnorm_ref(x, weight, bias, residual, eps)
+ # Preserve the input stride contract even on the fallback path so
+ # torch.compile sees a consistent output layout across all branches.
+ if y.stride() != x.stride():
+ y_strided = torch.empty_strided(
+ x.shape, x.stride(), device=x.device, dtype=x.dtype
+ )
+ y_strided.copy_(y)
+ y = y_strided
+ rstd = None
+ if store_rstd:
+ xf = x.float()
+ if residual is not None:
+ xf = xf + residual.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32)
+ residual_out = None
+ if residual is not None:
+ residual_out = (x.float() + residual.float()).to(x.dtype)
+ if residual_out.stride() != residual.stride():
+ residual_out_strided = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ residual_out_strided.copy_(residual_out)
+ residual_out = residual_out_strided
+ return y, rstd, residual_out
+
+
+def rmsnorm_ref(
+ x: Tensor,
+ w: Optional[Tensor] = None,
+ b: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+) -> Tensor:
+ xf = x.float()
+ if residual is not None:
+ xf = xf + residual.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
+ y = xf * rstd
+ if w is not None:
+ y = y * w.float()
+ if b is not None:
+ y = y + b.float()
+ return y.to(x.dtype)
+
+
+def fused_add_rmsnorm_forward(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor]:
+ """Fused residual-add + RMSNorm for SM100 in CuteDSL.
+
+ This is a convenience wrapper around ``rmsnorm_forward`` that matches the
+ semantics of vLLM's ``fused_add_rms_norm``:
+
+ z = x + residual
+ y = RMSNorm(z, weight, eps)
+
+ It returns ``(y, z)`` where ``z`` has the same dtype/shape as the inputs.
+ """
+ assert x.is_cuda and residual.is_cuda
+ assert x.shape == residual.shape
+ assert x.dtype == residual.dtype
+
+ orig_shape = x.shape
+ N = orig_shape[-1]
+
+ x_2d = x.view(-1, N)
+ res_2d = residual.view(-1, N)
+
+ y_2d, _rstd, z_2d = rmsnorm_forward(
+ x_2d,
+ weight=weight,
+ bias=None,
+ residual=res_2d,
+ eps=eps,
+ store_rstd=False,
+ )
+
+ y = y_2d.view(orig_shape)
+ z = z_2d.view(orig_shape)
+ return y, z
+
+
+def fused_add_rmsnorm_forward_inplace(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor]:
+ """In-place fused residual-add + RMSNorm matching vLLM semantics.
+
+ This variant writes:
+
+ z = x + residual (stored into ``residual``)
+ y = RMSNorm(z, w) (stored into ``x``)
+
+ i.e., it uses ``x`` as the normalized output buffer and ``residual`` as
+ the residual-out buffer, mirroring vLLM's fused_add_rms_norm kernel.
+ """
+ fused_add_rmsnorm_inplace_(x, residual, weight, eps=eps)
+ return x, residual
+
+
+def fused_add_rmsnorm_inplace_(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> None:
+ """In-place fused residual-add + RMSNorm matching vLLM semantics.
+
+ This is the lowest-overhead Python entrypoint (returns `None`) intended
+ for performance-critical call sites like `torch.ops.oink.fused_add_rms_norm`.
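+
+    A minimal usage sketch (shapes and dtype are illustrative); afterwards ``x``
+    holds the normalized output and ``residual`` holds ``x + residual``::
+
+        x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+        res = torch.randn_like(x)
+        w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+        fused_add_rmsnorm_inplace_(x, res, w, eps=1e-6)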
+ """
+ assert x.is_cuda and residual.is_cuda
+ assert x.shape == residual.shape
+ assert x.dtype == residual.dtype
+
+ N = x.shape[-1]
+ x_2d = x if x.dim() == 2 else x.view(-1, N)
+ res_2d = residual if residual.dim() == 2 else residual.view(-1, N)
+
+ # Fast path: vLLM-compatible layout where x may be strided/padded but
+ # residual is contiguous. This updates both tensors in-place without
+ # additional allocations.
+ if _can_use_ptr_path_fused_add_inplace(x_2d, weight, res_2d):
+ _fused_add_rmsnorm_forward_ptr_inplace(x_2d, res_2d, weight, eps)
+ return None
+
+ # Fallback: allocate via the regular fused path, then copy results into
+ # the user-provided buffers so that semantics remain identical.
+ y, z = fused_add_rmsnorm_forward(x, residual, weight, eps)
+ x.copy_(y)
+ residual.copy_(z)
+ return None
+
+
+# -------------------------
+# Backward kernel (SM100)
+# -------------------------
+
+
+class RMSNormBackwardSM100(BaseRMSNormBackward):
+ """SM100-tuned RMSNorm backward.
+
+ This is a thin wrapper around the generic `lite_quack.RMSNormBackward`
+    base implementation, with SM100-specific tiling heuristics (threads per
+    row, thread count, and cluster growth) plus optional per-shape overrides
+    applied by the dispatch code.
+ """
+
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ super().__init__(dtype, N)
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: cute.Pointer,
+ ptr_dout: cute.Pointer,
+ ptr_rstd: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ ptr_dw_partial: cute.Pointer,
+ M: Int32,
+ N_dyn: Int32,
+ ld: Int32,
+ sm_count: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions.
+
+ This is the performance-critical path used by the benchmark harness
+ (and any future training integrations) for the common case:
+ - weight gradient enabled (dw_partial is provided)
+ - no bias/residual gradients
+ """
+ # Weight-grad stores use vectorized float32 copies. For the SM100
+        # schedule we want to allow up to 256-bit (8x f32) stores, which requires
+ # the leading dimension to be divisible by 8 to prove 32B alignment for
+ # every row in `dw_partial`.
+ N_assumed = cute.assume(N_dyn, divby=8)
+
+ layout_mn = cute.make_layout((M, N_assumed), stride=(ld, 1))
+ layout_n = cute.make_layout((N_assumed,), stride=(1,))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ # Default: write a full (sm_count, N) partial buffer (Quack-style),
+ # then reduce on the host with `torch.sum(dim=0)`.
+ #
+ # Optional: atomic-reduce directly into a single (N,) buffer by using
+ # a broadcasted leading dimension (stride0 = 0). This avoids the extra
+ # reduction kernel launch and is primarily used for tiny-M regimes.
+ if const_expr(self.atomic_dw):
+ layout_partial = cute.make_layout((sm_count, N_assumed), stride=(0, 1))
+ else:
+ layout_partial = cute.make_layout(
+ (sm_count, N_assumed), stride=(N_assumed, 1)
+ )
+
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mW = cute.make_tensor(ptr_w, layout_n)
+ mdO = cute.make_tensor(ptr_dout, layout_mn)
+ mRstd = cute.make_tensor(ptr_rstd, layout_m)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ mdW = cute.make_tensor(ptr_dw_partial, layout_partial)
+
+ self.__call__(
+ mX,
+ mW,
+ mdO,
+ None, # dresidual_out
+ mRstd,
+ mdX,
+ mdW,
+ None, # dresidual
+ None, # db_partial
+ sm_count,
+ stream,
+ )
+
+ def _get_num_threads(self) -> int:
+ # Keep 128 threads only up to N=4k; use 256 for larger rows to ensure
+ # threads_per_row <= num_threads across buckets.
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ return 128 if self.N <= 4096 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ pass
+ # Match Quack's backward tiling: use 256 threads/row for N > 4096.
+ #
+ # The earlier "mirror forward" policy (128 threads/row for N<=8192)
+ # regresses DSv3 backward at N=6144/7168/8192 on SM100.
+ N = self.N
+ for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]:
+ if N <= limit:
+ return threads
+        return 256
+
+ def _set_cluster_n(self) -> None:
+ # Reuse the SM100 forward cluster growth policy so large-N shapes can
+ # fan out across multiple CTAs in the same row.
+ try:
+ self.cluster_n = self._cluster_n_override # type: ignore[attr-defined]
+ return
+ except Exception:
+ pass
+
+ N = self.N
+ if N <= 8192:
+ cluster_n = 1
+ elif self.dtype.width == 16:
+ if N <= 16 * 1024:
+ cluster_n = 2
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ sm_count: Int32,
+ stream: cuda.CUstream,
+ ):
+ # Match forward's 32B alignment on the leading dimension to unlock
+ # wider vectorization when legal.
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mdO, mdResO, mdX, mdRes = [
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mdO, mdResO, mdX, mdRes)
+ ]
+
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ )
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ num_blocks = sm_count
+ kernel = (
+ self.kernel(
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
+ )
+ kernel.launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(
+ tiler_mn, num_warps, do_dtype=mdO.element_type
+ ),
+ stream=stream,
+ )
+
+
+_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_BWD_PTR_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+def _rmsnorm_bwd_sm100(
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dx: Tensor,
+ dw_partial: Optional[Tensor],
+ db_partial: Optional[Tensor] = None,
+ dresidual_out: Optional[Tensor] = None,
+ dresidual: Optional[Tensor] = None,
+ sm_count: Optional[int] = None,
+) -> None:
+ """SM100-specific RMSNorm backward dispatch.
+
+ Mirrors Quack's `quack.rmsnorm._rmsnorm_bwd`, but instantiates
+ `RMSNormBackwardSM100` (SM100-tuned heuristics).
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert x.is_cuda, "Input tensor must be on CUDA device"
+ assert x.dtype in (torch.float16, torch.bfloat16, torch.float32)
+
+ if weight is not None:
+ assert weight.dim() == 1
+ assert x.shape[-1] == weight.shape[0]
+ assert weight.is_cuda
+ assert weight.dtype in (torch.float32, torch.bfloat16, torch.float16)
+ if dresidual_out is not None:
+ assert dresidual_out.shape == x.shape
+ assert dresidual_out.is_cuda
+ assert dresidual_out.dtype in (torch.float16, torch.bfloat16, torch.float32)
+ if dresidual is not None:
+ assert dresidual.shape == x.shape
+ assert dresidual.is_cuda
+ assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32)
+
+ M, N = x.size(0), x.size(1)
+ if dw_partial is None and db_partial is None:
+ assert sm_count is not None
+ else:
+ sm_count = (
+ dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+ )
+
+ # Match Quack's conversion strategy for activations/gradients: keep the
+ # (M, N) layout dynamic without enforcing additional compact-shape
+ # constraints. This reduces per-call Python overhead for small-M shapes.
+ def _convert_layout_dynamic(t: Tensor) -> cute.Tensor:
+ return from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(
+ leading_dim=1
+ )
+
+ x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
+ _convert_layout_dynamic(t) if t is not None else None
+ for t in (x, dout, dresidual_out, dx, dresidual)
+ ]
+
+ if weight is not None:
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype]
+ weight_tensor = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ divisibility=128 // weight_dtype.width,
+ )
+ else:
+ weight_tensor = None
+
+ dw_partial_tensor = (
+ from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if dw_partial is not None
+ else None
+ )
+ db_partial_tensor = (
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if db_partial is not None
+ else None
+ )
+ rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (
+ M,
+ N,
+ x_tensor.element_type,
+ weight_tensor.element_type if weight is not None else None,
+ db_partial.dtype if db_partial is not None else None,
+ dresidual.dtype if dresidual is not None else None,
+ dresidual_out.dtype if dresidual_out is not None else None,
+ )
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = RMSNormBackwardSM100(x_tensor.element_type, N)
+
+ # Shape-specific tuning overrides for DSv3-style N=8192 rows.
+ if isinstance(op, RMSNormBackwardSM100) and N == 8192:
+ if M >= 65536:
+ op._tpr_override = 256 # type: ignore[attr-defined]
+ op._nt_override = 256 # type: ignore[attr-defined]
+ elif M >= 16384:
+ op._tpr_override = 256 # type: ignore[attr-defined]
+
+ kernel = cute.compile(
+ op,
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ dres_out_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ dres_tensor,
+ db_partial_tensor,
+ Int32(sm_count if sm_count is not None else 0),
+ current_stream,
+ )
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ dres_out_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ dres_tensor,
+ db_partial_tensor,
+ Int32(sm_count if sm_count is not None else 0),
+ current_stream,
+ )
+
+
+def _rmsnorm_bwd_sm100_ptr(
+ x: Tensor,
+ weight: Tensor,
+ dout: Tensor,
+ rstd: Tensor,
+ dx: Tensor,
+ dw_partial: Tensor,
+ sm_count: int,
+ *,
+ atomic_dw: bool = False,
+) -> None:
+ """Pointer-based SM100 RMSNorm backward launch (no DLPack conversions).
+
+ When `atomic_dw=True`, `dw_partial` is treated as a single (N,) fp32 buffer
+ and the kernel atomically accumulates weight gradients into it (avoids the
+ extra `dw_partial.sum(dim=0)` reduction kernel).
+ """
+ assert _can_use_ptr_path_bwd(x, weight, dout, rstd)
+ assert dx.shape == x.shape
+ assert dx.dtype == x.dtype
+ assert dw_partial.dtype == torch.float32
+
+ M, N = x.size(0), x.size(1)
+ if atomic_dw:
+ assert dw_partial.dim() == 1 and dw_partial.numel() == N
+ assert dw_partial.is_contiguous()
+ else:
+ assert dw_partial.dim() == 2 and dw_partial.shape[1] == N
+ device_index = x.get_device()
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype]
+ assumed_align_x = 16
+ assumed_align_w = 32 if weight.dtype is torch.float32 else 16
+ assumed_align_dw = 32
+ assert (dw_partial.data_ptr() % assumed_align_dw) == 0
+
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ ld_val = int(x.stride(0))
+ key = (
+ "bwd_ptr",
+ N,
+ dtype,
+ weight_dtype,
+ int(assumed_align_x),
+ int(assumed_align_w),
+ int(assumed_align_dw),
+ device_index,
+ bool(atomic_dw),
+ )
+ compiled = _BWD_PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = RMSNormBackwardSM100(dtype, N)
+ op.atomic_dw = bool(atomic_dw)
+ # 16-bit activations + 16-bit weights (vLLM-style) backward at N=4096:
+ # Use a 1-row/CTA schedule with 256 threads/row. This reduces per-thread
+ # work and improves bandwidth on large-M shapes on SM100.
+ if (
+ (not atomic_dw)
+ and N == 4096
+ and dtype.width == 16
+ and weight_dtype.width == 16
+ ):
+ op._tpr_override = 256 # type: ignore[attr-defined]
+ op._nt_override = 256 # type: ignore[attr-defined]
+ # 16-bit activations + fp32 weights backward at N=4096:
+ # Use a 256-thread schedule (tpr=256) to improve throughput.
+ if (
+ (not atomic_dw)
+ and N == 4096
+ and dtype.width == 16
+ and weight_dtype is cutlass.Float32
+ ):
+ op._tpr_override = 256 # type: ignore[attr-defined]
+ op._nt_override = 256 # type: ignore[attr-defined]
+ # FP16 + fp32-weight DSv3 backward: Quack's default (1 row/CTA with
+ # 256 threads/row) underperforms. Use a 2-rows/CTA schedule (256 threads
+ # total, 128 threads/row) to improve memory-level parallelism.
+ if (
+ (not atomic_dw)
+ and N == 6144
+ and dtype is cutlass.Float16
+ and weight_dtype is cutlass.Float32
+ ):
+ op._tpr_override = 128 # type: ignore[attr-defined]
+ op._nt_override = 256 # type: ignore[attr-defined]
+
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_w = rt.make_ptr(
+ weight_dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_w,
+ )
+ ptr_dout = rt.make_ptr(
+ dtype,
+ dout.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_rstd = rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype,
+ dx.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_dw = rt.make_ptr(
+ cutlass.Float32,
+ dw_partial.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_dw,
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ ptr_dout,
+ ptr_rstd,
+ ptr_dx,
+ ptr_dw,
+ Int32(M),
+ Int32(N),
+ Int32(ld_val),
+ Int32(int(sm_count)),
+ stream,
+ )
+ _BWD_PTR_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_rmsnorm_bwd_launcher(
+ compiled=compiled,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ N=N,
+ device_index=device_index,
+ stream_handle=stream_handle,
+ has_weight=True,
+ has_dw_partial=True,
+ assumed_align_x=assumed_align_x,
+ assumed_align_w=assumed_align_w,
+ assumed_align_dw=assumed_align_dw,
+ )
+ if launcher is not None:
+ launcher.launch(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ M=M,
+ N=N,
+ ld=ld_val,
+ sm_count=int(sm_count),
+ )
+ return
+
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_w = rt.make_ptr(
+ weight_dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_w,
+ )
+ ptr_dout = rt.make_ptr(
+ dtype,
+ dout.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_rstd = rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype,
+ dx.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_x,
+ )
+ ptr_dw = rt.make_ptr(
+ cutlass.Float32,
+ dw_partial.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align_dw,
+ )
+ compiled(
+ ptr_x,
+ ptr_w,
+ ptr_dout,
+ ptr_rstd,
+ ptr_dx,
+ ptr_dw,
+ Int32(M),
+ Int32(N),
+ Int32(ld_val),
+ Int32(int(sm_count)),
+ stream,
+ )
+
+
+def rmsnorm_backward(
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dresidual_out: Optional[Tensor] = None,
+ has_bias: bool = False,
+ has_residual: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+ """Public SM100 RMSNorm backward entry point.
+
+ Signature mirrors `quack.rmsnorm.rmsnorm_bwd` for easy comparisons.
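+
+    A minimal usage sketch (shapes and dtype are illustrative), pairing with
+    ``rmsnorm_forward(..., store_rstd=True)``::
+
+        x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
+        w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+        y, rstd, _ = rmsnorm_forward(x, weight=w, store_rstd=True)
+        dout = torch.randn_like(y)
+        dx, dw, db, dres = rmsnorm_backward(x, w, dout, rstd)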
+ """
+ device = x.device
+ M, N = x.size(0), x.size(1)
+ dx = torch.empty_like(x)
+ if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
+ dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
+ else:
+ dresidual = None
+
+ # Shared SM100 tuning policy (used by both RMSNorm and LayerNorm).
+ sm_count = get_sm_count(N, device, M=M, dtype=x.dtype)
+
+ # Quack-suite smallest case (M=8192, N=4096) is extremely sensitive to
+ # Python/allocator overhead because the kernel itself is very fast.
+ #
+ # The default `lite_quack.get_sm_count` adds a small-M occupancy boost for
+ # N=4096, which increases `dw_partial` size and can amplify allocator
+ # pressure in benchmark/verify loops. Clamp to Quack's baseline policy
+ # (`sm_count = num_sms * 2` for N=4096) for this regime.
+ if N == 4096 and M <= 8192 and x.dtype in (torch.float16, torch.bfloat16):
+ num_sms = qutils.get_num_sms(device)
+ sm_count = min(int(sm_count), int(num_sms) * 2)
+
+ use_atomic_dw = False
+ # DSv3 backward (N=6144/7168/8192) is dominated by the (sm_count, N) partial
+ # write + reduction for dW. Use the atomic-dW path to accumulate directly
+ # into a single (N,) fp32 buffer (no separate reduction kernel).
+ if (
+ weight is not None
+ and (not has_bias)
+ and (not has_residual)
+ and dresidual_out is None
+ and dresidual is None
+ and N == 8192
+ and weight.dtype is torch.float32
+ and M >= 65536
+ and x.dtype in (torch.float16, torch.bfloat16)
+ and _can_use_ptr_path_bwd(x, weight, dout, rstd)
+ ):
+ use_atomic_dw = True
+
+ if weight is not None:
+ if use_atomic_dw:
+ dw_partial = torch.zeros(N, device=device, dtype=torch.float32)
+ else:
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ else:
+ dw_partial = None
+ db_partial = (
+ torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ if has_bias
+ else None
+ )
+
+ if (
+ weight is not None
+ and dw_partial is not None
+ and (not has_bias)
+ and (not has_residual)
+ and dresidual_out is None
+ and dresidual is None
+ and _can_use_ptr_path_bwd(x, weight, dout, rstd)
+ ):
+ _rmsnorm_bwd_sm100_ptr(
+ x=x,
+ weight=weight,
+ dout=dout,
+ rstd=rstd,
+ dx=dx,
+ dw_partial=dw_partial,
+ sm_count=int(sm_count),
+ atomic_dw=bool(use_atomic_dw),
+ )
+ else:
+ _rmsnorm_bwd_sm100(
+ x,
+ weight,
+ dout,
+ rstd,
+ dx,
+ dw_partial,
+ db_partial,
+ dresidual_out,
+ dresidual,
+ sm_count,
+ )
+
+ if weight is not None and dw_partial is not None:
+ if use_atomic_dw:
+ dw_fp32 = dw_partial
+ else:
+ dw_fp32 = _reduce_partial_sum_fp32(dw_partial, device_index=x.get_device())
+ dw = dw_fp32 if weight.dtype is torch.float32 else dw_fp32.to(weight.dtype)
+ else:
+ dw = None
+ db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
+ if has_residual and dresidual is None:
+ dresidual = dx
+ return dx, dw, db, dresidual
+
+
+# Quack-style alias for benchmarks
+rmsnorm_bwd = rmsnorm_backward
+
+
+if __name__ == "__main__":
+ # Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness.
+ if not torch.cuda.is_available():
+ print("CUDA not available; functional test skipped.")
+ sys.exit(0)
+ M, N = 1024, 8192
+ dtype = torch.bfloat16
+ x = torch.randn(M, N, device="cuda", dtype=dtype)
+ w = torch.randn(N, device="cuda", dtype=dtype)
+ y_ref = rmsnorm_ref(x, w)
+ y, _, _ = rmsnorm_forward(x, w)
+ torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
+ print("RMSNormSM100 correctness check passed.")
+
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
new file mode 100644
index 0000000..2b5b36d
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
@@ -0,0 +1,1007 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+RMSNorm kernel for SM100 (Blackwell) in CuTeDSL, with the experimental
+stage-2 cp.async ping-pong path preserved for N≈6k/8k.
+
+This file is a fork of rmsnorm.py that keeps the K-loop cp.async path
+behind `self.stage > 1` while the main implementation has been simplified
+to a single-stage schedule.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import re
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, const_expr
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell import lite_quack as qutils
+from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce
+
+_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+def _parse_version_tuple(version: str) -> tuple[int, int, int]:
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
+
+
+def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when
+# we detect 4.3.4+ (or when version detection is unavailable).
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
+
+
+@cute.jit
+def get_copy_atom_bw(
+ dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False
+) -> cute.CopyAtom:
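+    # Cap the copy width at 256 bits for synchronous copies and 128 bits for
+    # cp.async (G2S) copies, based on the requested elements-per-copy and the
+    # element width.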
+ max_bits = const_expr(128 if is_async else 256)
+ num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
+ from cutlass.cute.nvgpu import cpasync
+
+ copy_op = (
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL)
+ if is_async
+ else cute.nvgpu.CopyUniversalOp()
+ )
+ return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@cute.jit
+def copy_tiled(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+) -> None:
+ atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async)
+ cute.copy(atom, src, dst, pred=pred)
+
+
+class RMSNormSM100WithStage2:
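+    """SM100 RMSNorm forward with an optional two-stage cp.async schedule.
+
+    ``stage=1`` runs the single-stage schedule used by ``rmsnorm.py``;
+    ``stage=2`` enables the intra-row cp.async ping-pong K-loop kept for
+    N≈6k/8k rows.
+    """
+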
+ def __init__(
+ self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None
+ ):
+ self.N = N
+ self.dtype = dtype
+ self.stage = 1 if stage is None else stage
+ self.reduction_dtype = cutlass.Float32
+
+ def _threads_per_row(self) -> int:
+ N = self.N
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 1024:
+ return 32
+ elif N <= 4096:
+ return 128
+ elif N <= 8192:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ return 128
+ elif N <= 16384:
+ return 256
+ else:
+ return 256
+
+ def _cluster_n(self) -> int:
+ N = self.N
+ if N <= 8192:
+ return 1
+ if const_expr(self.dtype.width == 16):
+ if N <= 16 * 1024:
+ return 2
+ elif N <= 32 * 1024:
+ return 2
+ elif N <= 64 * 1024:
+ return 4
+ elif N <= 128 * 1024:
+ return 8
+ else:
+ return 16
+ else:
+ if N <= 32 * 1024:
+ return 1
+ elif N <= 64 * 1024:
+ return 2
+ elif N <= 128 * 1024:
+ return 4
+ elif N <= 256 * 1024:
+ return 8
+ else:
+ return 16
+
+ def _num_threads(self) -> int:
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ return 128 if self.N <= 16384 else 256
+
+ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ num_threads = self._num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+ tpr = self._threads_per_row()
+ cluster_n = self._cluster_n()
+ num_cols_vec = cute.ceil_div(self.N, vecsize)
+ num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n)
+ cols_per_block = num_threads // tpr
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
+ tv_layout = cute.make_layout(
+ ((tpr, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * tpr),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mRes, mO, mResO = [
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ copy_bits = const_expr(128)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = (
+ tv_layout.shape[0][0]
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._threads_per_row()
+ )
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+
+ if const_expr(mW is not None):
+ mW = cute.make_tensor(
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+
+ stage_bufs = 2 if self.stage > 1 else 1
+ tile_bytes_x = (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ )
+ tile_bytes_res = (
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn))
+ * stage_bufs
+ if const_expr(mRes is not None)
+ else 0
+ )
+ red_bytes = (
+ self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ )
+ mbar_bytes = self.stage * (cutlass.Int64.width // 8)
+ smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ )
+ )
+
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=([1, cluster_n, 1] if const_expr(cluster_n > 1) else None),
+ smem=smem_bytes,
+ stream=stream,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ cluster_n = self._cluster_n()
+ cluster_y = (
+ const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1]
+ )
+
+ smem = cutlass.utils.SmemAllocator()
+ sX0 = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ sX1 = (
+ smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(self.stage > 1)
+ else None
+ )
+ sRes0 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None)
+ else None
+ )
+ sRes1 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None and self.stage > 1)
+ else None
+ )
+
+ reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(
+ smem, num_warps, warps_per_row
+ )
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ num_copy_elems_X = tv_layout.shape[1][0]
+ use_async = const_expr(self.N >= 1024)
+ copy_atom = get_copy_atom_bw(
+ mX.element_type, num_copy_elems_X, is_async=use_async
+ )
+ thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
+
+ gW, gB = [
+ cute.local_tile(t, tiler_mn, (0, cluster_y))
+ if const_expr(t is not None)
+ else None
+ for t in (mW, mB)
+ ]
+ tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None
+ tXgB = thr_copy.partition_S(gB) if const_expr(mB is not None) else None
+ tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+ tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+ if const_expr(mW is not None):
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ )
+ if const_expr(mB is not None):
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ )
+
+ self._init_cluster(tidx, mbar_ptr)
+
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t)
+ if t is not None
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y))
+ gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y))
+ gRes_i = (
+ cute.local_tile(mRes_i, tiler_mn, (0, cluster_y))
+ if const_expr(mRes is not None)
+ else None
+ )
+ gResO_i = (
+ cute.local_tile(mResO_i, tiler_mn, (0, cluster_y))
+ if const_expr(mResO is not None)
+ else None
+ )
+ gRstd_i = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if const_expr(mRstd is not None)
+ else None
+ )
+ cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
+ row_i = tXcX_i[0][0]
+ tXgRstd_i = (
+ thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ )
+
+ # Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2)
+ if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)):
+ vecsize = tv_layout.shape[1][0]
+ tpr = threads_per_row
+ target_tile_n = const_expr(4096 if shape[1] == 6144 else 8192)
+ tile_factor = const_expr(target_tile_n // (vecsize * tpr))
+ tile_n = vecsize * tpr * tile_factor
+ num_tiles = cute.ceil_div(shape[1], tile_n)
+
+ tiler_mn_tile = (tiler_mn[0], tile_n)
+ sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0))
+ sX1_tile = (
+ cute.local_tile(sX1, tiler_mn_tile, (0, 0))
+ if const_expr(self.stage > 1)
+ else None
+ )
+ sRes0_tile = (
+ cute.local_tile(sRes0, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+ sRes1_tile = (
+ cute.local_tile(sRes1, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None and self.stage > 1)
+ else None
+ )
+
+ tv_layout_tile = cute.make_layout(
+ ((tpr, tiler_mn[0]), (vecsize, tile_factor)),
+ stride=(
+ (vecsize * tiler_mn[0], 1),
+ (tiler_mn[0], tiler_mn[0] * vecsize * tpr),
+ ),
+ )
+ thr_copy_tile = cute.make_tiled_copy(
+ copy_atom, tv_layout_tile, tiler_mn_tile
+ ).get_slice(tidx)
+
+ sum_sq_acc = cute.Float32(0.0)
+ k_off0 = const_expr(0) * tile_n
+ gX_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgX_0 = thr_copy_tile.partition_S(gX_0)
+ tXsX_0 = thr_copy_tile.partition_D(sX0_tile)
+ cX_0 = cute.local_tile(
+ cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y)
+ )
+ tXc_0 = thr_copy_tile.partition_S(cX_0)
+ tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1])
+ tXp_ping = tXp_0
+ tXp_pong = tXp_0
+ if row_i < shape[0]:
+ copy_tiled(
+ tXgX_0,
+ tXsX_0,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_0,
+ )
+ if const_expr(mRes is not None):
+ gRes_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgRes_0 = thr_copy_tile.partition_S(gRes_0)
+ tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile)
+ copy_tiled(
+ tXgRes_0,
+ tXsRes_0,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_0,
+ )
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+
+ for t in cutlass.range_constexpr(num_tiles):
+ next_t = t + 1
+ if next_t < num_tiles:
+ k_off_n = next_t * tile_n
+ gX_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgX_n = thr_copy_tile.partition_S(gX_n)
+ cX_n = cute.local_tile(
+ cute.domain_offset((0, k_off_n), cX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXc_n = thr_copy_tile.partition_S(cX_n)
+ tXp_n = qutils.predicate_k(tXc_n, limit=shape[1])
+ if const_expr((t % 2) == 0):
+ tXsX_n = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ tXp_pong = tXp_n
+ else:
+ tXsX_n = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ tXp_ping = tXp_n
+ if row_i < shape[0]:
+ copy_tiled(
+ tXgX_n,
+ tXsX_n,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_n,
+ )
+ if const_expr(mRes is not None):
+ gRes_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgRes_n = thr_copy_tile.partition_S(gRes_n)
+ copy_tiled(
+ tXgRes_n,
+ tXsRes_n,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_n,
+ )
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ if const_expr(use_async):
+ cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0)
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ pred_cur = tXp_ping
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ pred_cur = tXp_pong
+ qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero)
+ if const_expr(mRes is not None):
+ qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero)
+
+ k_off = t * tile_n
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ gResO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mResO_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgResO_t = thr_copy_tile.partition_D(gResO_t)
+ tXrResO = cute.make_fragment_like(tXgResO_t)
+ tXrResO.store(x.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ copy_tiled(
+ tXrResO,
+ tXgResO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=pred_cur,
+ )
+
+ sum_sq_tile = row_reduce(
+ x * x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ hook_fn=(
+ cute.arch.cluster_wait if const_expr(cluster_n > 1) else None
+ ),
+ )
+ sum_sq_acc = sum_sq_acc + sum_sq_tile
+
+ rstd = cute.math.rsqrt(sum_sq_acc / shape[1] + eps, fastmath=True)
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ for t in cutlass.range_constexpr(num_tiles):
+ k_off = t * tile_n
+ cX_t = cute.local_tile(
+ cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y)
+ )
+ tXc_t = thr_copy_tile.partition_S(cX_t)
+ tXp_t = qutils.predicate_k(tXc_t, limit=shape[1])
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
+
+ qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero)
+ if const_expr(mRes is not None):
+ qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero)
+
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ y = x * rstd
+ if const_expr(mW is not None):
+ gW_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mW),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tWgW_t = thr_copy_tile.partition_S(gW_t)
+ tWrW_t = cute.make_fragment_like(tWgW_t)
+ copy_tiled(
+ tWgW_t,
+ tWrW_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
+ y = y * tWrW_t.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ gB_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mB),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tWgB_t = thr_copy_tile.partition_S(gB_t)
+ tWrB_t = cute.make_fragment_like(tWgB_t)
+ copy_tiled(
+ tWgB_t,
+ tWrB_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
+ y = y + tWrB_t.load().to(cute.Float32)
+
+ gO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mO_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
+ tXgO_t = thr_copy_tile.partition_D(gO_t)
+ tXrO = cute.make_fragment_like(tXgO_t)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ copy_tiled(
+ tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t
+ )
+
+ return
+
+ # Fallback: single-stage path identical to current rmsnorm.py
+ tXgX_i = thr_copy.partition_S(gX_i)
+ tXgRes_i = (
+ thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ )
+ tXgO_i = thr_copy.partition_D(gO_i)
+ tXgResO_i = (
+ thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ )
+ is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ tXpX_i = (
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1])
+ if not is_even_N_i
+ else None
+ )
+
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i)
+ if const_expr(mRes is not None):
+ cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i)
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ tXrX = cute.make_fragment_like(tXgX_i)
+ cute.autovec_copy(thr_copy.partition_D(sX0), tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ tXrRes = cute.make_fragment_like(tXgRes_i)
+ cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ tXrResO = cute.make_fragment_like(tXgResO_i)
+ tXrResO.store(x.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(
+ tXrResO.element_type, num_copy_elems_X, is_async=False
+ ),
+ tXrResO,
+ tXgResO_i,
+ )
+
+ sum_sq = row_reduce(
+ x * x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None),
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ y = x * rstd
+ if const_expr(mW is not None):
+ y = y * tXrW.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ y = y + tXrB.load().to(cute.Float32)
+
+ tXrO = cute.make_fragment_like(tXgO_i)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False),
+ tXrO,
+ tXgO_i,
+ )
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ num_warps,
+ warps_per_row,
+ threads_per_row,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ copy_bits = const_expr(128)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = self._num_threads()
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = self._threads_per_row()
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+
+ @cute.jit
+ def _alloc_reduction_and_mbar(
+ self,
+ smem: cutlass.utils.SmemAllocator,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+ cluster_n = self._cluster_n()
+ red_layout = cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype, red_layout, byte_alignment=4
+ )
+ if const_expr(cluster_n > 1):
+ mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+ else:
+ mbar_ptr = None
+ return reduction_buffer, mbar_ptr
+
+ @cute.jit
+ def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]):
+ if const_expr(mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
+def rmsnorm_forward_with_stage2(
+ x: Tensor,
+ weight: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ store_rstd: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
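+    """RMSNorm forward that can opt into the stage-2 cp.async K-loop.
+
+    Returns ``(out, rstd, residual_out)``. ``rstd`` is populated only when
+    ``store_rstd=True`` and ``residual_out`` only when ``residual`` is given.
+    The two-stage ping-pong schedule is selected automatically for very large
+    M (>= 65536 rows) with N in {6144, 8192}; all other shapes use the
+    single-stage path.
+    """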
+ assert x.is_cuda
+ assert x.dim() == 2
+ M, N = x.shape
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ def _convert_x(t: Tensor) -> cute.Tensor:
+ return from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(
+ leading_dim=1
+ )
+
+ mX = _convert_x(x)
+ mRes = _convert_x(residual) if residual is not None else None
+ out = torch.empty_like(x, dtype=x.dtype)
+ mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
+
+ mW = (
+ from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(
+ leading_dim=0
+ )
+ if weight is not None
+ else None
+ )
+ mB = (
+ from_dlpack(bias.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0)
+ if bias is not None
+ else None
+ )
+ if store_rstd:
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32)
+ mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+ else:
+ rstd = None
+ mRstd = None
+
+ residual_out = None
+ mResO = None
+ if residual is not None:
+ residual_out = torch.empty_like(residual)
+ mResO = from_dlpack(
+ residual_out.detach(), assumed_align=32
+ ).mark_layout_dynamic(leading_dim=1)
+
+ # Enable the intra-row cp.async K-loop only for DSv3-style large-N rows
+ # with very large M, where there is enough work per row to amortize the
+ # pipeline start-up cost. Mid-size M shapes are better served by the
+ # simpler single-stage schedule.
+ use_kloop = bool(M >= 65536 and N in (6144, 8192))
+ stage = 2 if use_kloop else 1
+ op = RMSNormSM100WithStage2(N, dtype, stage=stage)
+ if use_kloop:
+ op._tpr_override = 128 # type: ignore[attr-defined]
+ # Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match
+ # the original tuning there.
+ op._nt_override = 128 if N == 6144 else 256 # type: ignore[attr-defined]
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (
+ N,
+ dtype,
+ mRes is not None,
+ mW is not None,
+ mB is not None,
+ mResO is not None,
+ mRstd is not None,
+ stage,
+ )
+ compiled = _COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(
+ op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)
+ )
+ _COMPILE_CACHE[key] = compiled
+ compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps))
+ return out, rstd, residual_out
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
new file mode 100644
index 0000000..394ab48
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -0,0 +1,1582 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Softmax forward + backward kernels for SM100 (Blackwell) in CuTeDSL.
+
+This module implements numerically stable softmax over the last dimension of
+2D tensors (M, N) and its backward pass, targeting SM100 with Quack-style
+tiling, cp.async pipelines, and cluster reductions, but without depending on
+the `quack` package at runtime.
+
+The kernels are self-contained and use only local helpers in
+`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import os
+import re
+from typing import Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
+# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
+# warnings (and disables cache reuse).
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.softmax requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell.fast_launch import (
+ StableI32Arg,
+ disable_fast_launch,
+ fast_launch_enabled,
+ set_runtime_ptr,
+ tls_cache as _tls_fast_launch_cache,
+)
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase,
+ fill_oob,
+ online_softmax_reduce,
+ predicate_k,
+ row_reduce,
+)
+
+_FWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {}
+_BWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {}
+_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+class _PtrSoftmaxFastLaunch:
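+    """Low-overhead re-launcher for the pointer-based softmax kernels.
+
+    Caches a pre-packed argument buffer for a compiled kernel and, between
+    calls, patches only the data pointers and the (M, ld) scalars before
+    invoking ``capi_func`` directly. If any of the fast-launch internals are
+    unavailable, it falls back to the regular compiled-callable path and
+    disables fast launch.
+    """
+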
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_a: object,
+ ptr_b: object,
+ ptr_c: object | None,
+ arg_m: StableI32Arg,
+ arg_ld: StableI32Arg,
+ stream: cuda.CUstream,
+ assumed_align: int,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_a = ptr_a
+ self._ptr_b = ptr_b
+ self._ptr_c = ptr_c
+ self._arg_m = arg_m
+ self._arg_ld = arg_ld
+ self._stream = stream
+ self._assumed_align = int(assumed_align)
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_a_ptr = -1
+ self._last_b_ptr = -1
+ self._last_c_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+
+ def launch(
+ self,
+ *,
+ a_ptr: int,
+ b_ptr: int,
+ c_ptr: int | None,
+ M: int,
+ ld: int,
+ stream_handle: int,
+ dtype: type[cutlass.Numeric],
+ ) -> None:
+ if not fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ a_ptr=a_ptr,
+ b_ptr=b_ptr,
+ c_ptr=c_ptr,
+ M=M,
+ ld=ld,
+ stream_handle=stream_handle,
+ dtype=dtype,
+ )
+ return
+
+ if a_ptr != self._last_a_ptr:
+ try:
+ set_runtime_ptr(self._ptr_a, a_ptr)
+ self._last_a_ptr = a_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ a_ptr=a_ptr,
+ b_ptr=b_ptr,
+ c_ptr=c_ptr,
+ M=M,
+ ld=ld,
+ stream_handle=stream_handle,
+ dtype=dtype,
+ )
+ return
+
+ if b_ptr != self._last_b_ptr:
+ try:
+ set_runtime_ptr(self._ptr_b, b_ptr)
+ self._last_b_ptr = b_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ a_ptr=a_ptr,
+ b_ptr=b_ptr,
+ c_ptr=c_ptr,
+ M=M,
+ ld=ld,
+ stream_handle=stream_handle,
+ dtype=dtype,
+ )
+ return
+
+ if self._ptr_c is not None and c_ptr is not None:
+ if c_ptr != self._last_c_ptr:
+ try:
+ set_runtime_ptr(self._ptr_c, c_ptr)
+ self._last_c_ptr = c_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ a_ptr=a_ptr,
+ b_ptr=b_ptr,
+ c_ptr=c_ptr,
+ M=M,
+ ld=ld,
+ stream_handle=stream_handle,
+ dtype=dtype,
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ self._use_fast_launch = False
+ disable_fast_launch()
+
+ def _fallback_launch(
+ self,
+ *,
+ a_ptr: int,
+ b_ptr: int,
+ c_ptr: int | None,
+ M: int,
+ ld: int,
+ stream_handle: int,
+ dtype: type[cutlass.Numeric],
+ ) -> None:
+ stream = cuda.CUstream(int(stream_handle))
+ ptr_a = rt.make_ptr(
+ dtype,
+ a_ptr,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_b = rt.make_ptr(
+ dtype,
+ b_ptr,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ if self._ptr_c is not None and c_ptr is not None:
+ ptr_c = rt.make_ptr(
+ dtype,
+ c_ptr,
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ self._compiled(ptr_a, ptr_b, ptr_c, Int32(int(M)), Int32(int(ld)), stream)
+ else:
+ self._compiled(ptr_a, ptr_b, Int32(int(M)), Int32(int(ld)), stream)
+
+
+def _get_fast_ptr_softmax_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ assumed_align: int,
+ is_bwd: bool,
+) -> _PtrSoftmaxFastLaunch | None:
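+    """Build (and cache via ``tls_cache()``) a fast pointer-based launcher.
+
+    Returns ``None`` when fast launch is disabled or when the compiled object
+    does not expose the executor internals the fast path relies on.
+    """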
+ if not fast_launch_enabled():
+ return None
+ key = (
+ "ptr_fast_bwd" if is_bwd else "ptr_fast_fwd",
+ id(compiled),
+ int(N),
+ dtype,
+ int(device_index),
+ int(stream_handle),
+ int(assumed_align),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ assumed_align = int(assumed_align)
+ ptr_a = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_b = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_c = (
+ rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ if is_bwd
+ else None
+ )
+
+ arg_m = StableI32Arg(0)
+ arg_ld = StableI32Arg(N)
+ stream = cuda.CUstream(int(stream_handle))
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+ try:
+ if ptr_c is not None:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_a,
+ ptr_b,
+ ptr_c,
+ arg_m,
+ arg_ld,
+ stream,
+ )
+ else:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_a,
+ ptr_b,
+ arg_m,
+ arg_ld,
+ stream,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ disable_fast_launch()
+ return None
+
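+    # Hold Python references to every object captured by packed_args so the
+    # garbage collector cannot free memory the C launcher still points at.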
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_a,
+ ptr_b,
+ ptr_c,
+ arg_m,
+ arg_ld,
+ stream,
+ *adapted_args,
+ )
+ launcher = _PtrSoftmaxFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_a=ptr_a,
+ ptr_b=ptr_b,
+ ptr_c=ptr_c,
+ arg_m=arg_m,
+ arg_ld=arg_ld,
+ stream=stream,
+ assumed_align=assumed_align,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+class SoftmaxFwdSM100(ReductionBase):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # One-stage online reduction: pack (max, sum_exp) into Int64 reduction buffer.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _get_num_threads(self) -> int:
+ # SM100 tuning note:
+ # For N=4096, we use 32 threads per row (1 warp) and run 1 row per CTA
+ # (32 threads total). This keeps the reduction fully warp-local and
+        # improves throughput on GB200 versus Quack's default 2-rows-per-CTA
+ # schedule with 64 threads per row (4 warps total).
+ if self.N == 4096:
+ return 32
+ return super()._get_num_threads()
+
+ def _calculate_threads_per_row(self) -> int:
+        # Quack's bucketed policy for Softmax, with SM100-specific overrides
+        # for N=4096 and N=6144 handled first.
+ N = self.N
+ if N == 4096:
+ return 32
+ if N == 6144:
+ return 128
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 3072:
+ return 32
+ if N <= 6144:
+ return 64
+ if N <= 16384:
+ return 128
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ # Quack-style growth of cluster_n with N and dtype.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> None:
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+ # Use the generic ReductionBase tiling with 128-bit vectorization.
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(mX, mO, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mO)
+ )
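+        # Grid: one CTA per tiler_mn[0] rows along M; when cluster_n > 1, a
+        # thread-block cluster of cluster_n CTAs cooperates across the N axis.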
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_out: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions.
+
+ Reconstructs cute.Tensor views from raw pointers + explicit layouts
+ inside the JIT graph, matching the existing SM100 schedule.
+ """
+ # Mirror Quack/LayerNorm contracts: assume 16B alignment and an LD that
+ # preserves 128-bit vectorized copies for every row start.
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+ self.__call__(mX, mO, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Quack-style CTA tiling.
+ gX, gO, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)
+ ]
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
+
+        # Copy atoms: 128-bit cp.async for the gmem -> smem load and a 128-bit
+        # vectorized universal copy for the register -> gmem store.
+ copy_atom_load = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mX.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gO.element_type,
+ num_bits_per_copy=128,
+ )
+
+ num_copy_elems = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems))
+ thr_copy_load = cute.make_tiled_copy_tv(
+ copy_atom_load, thr_layout, val_layout
+ ).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy_tv(
+ copy_atom_store, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tXgX = thr_copy_load.partition_S(gX)
+ tXsX = thr_copy_load.partition_D(sX)
+ tXgO = thr_copy_store.partition_D(gO)
+ tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+ # Register fragments.
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ # Predicate and cp.async pipeline for potential tail tiles.
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ # Online softmax reduction: compute max and sum_exp in a single pass, with
+ # optional cluster-wide aggregation via an Int64 reduction buffer.
+ max_x, denom, exp_x = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=True,
+ )
+
+ y = exp_x * cute.arch.rcp_approx(denom)
+ tXrO.store(y.to(tXrO.element_type))
+
+ tOpO = (
+ predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_store, tXrO, tXgO, pred=tOpO)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(mX, mO, tv_layout, tiler_mn)
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ ) -> None:
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(mX, mO, tv_layout, tiler_mn)
+
+
+class SoftmaxBwdSM100(ReductionBase):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # One stage for dot(dy, y) per row.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+ def _calculate_threads_per_row(self) -> int:
+        # Quack's backward softmax buckets, with a 128-thread override for
+        # N=4096 and N=6144.
+ N = self.N
+ if N in (4096, 6144):
+ return 128
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 3072:
+ return 32
+ if N <= 6144:
+ return 64
+ if N <= 8192:
+ return 128
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ def _get_num_threads(self) -> int:
+ # Slightly more aggressive threading for large N than the base class.
+ return 128 if self.N <= 8192 else 256
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
+ # Store both y and dy tiles plus reduction buffers and mbarriers.
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+ + self.stage
+ * num_warps
+ * self.cluster_n
+ * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8)
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mdY.element_type == self.dtype
+ assert mY.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ # Use the generic ReductionBase tiling with 128-bit vectorization.
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(mdY, mY, mdX, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mdY, mY, mdX)
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_dy: cute.Pointer,
+ ptr_y: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ mdY = cute.make_tensor(ptr_dy, layout_mn)
+ mY = cute.make_tensor(ptr_y, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ self.__call__(mdY, mY, mdX, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mdY.shape
+ idX = cute.make_identity_tensor(shape)
+
+ gdY, gY, gdX, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y))
+ for mT in (mdY, mY, mdX, idX)
+ ]
+
+ smem = cutlass.utils.SmemAllocator()
+ sdY = smem.allocate_tensor(
+ mdY.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ sY = smem.allocate_tensor(
+ mY.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
+
+ copy_atom_load = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mdY.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=128,
+ )
+
+ num_copy_elems = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems))
+ thr_copy_load = cute.make_tiled_copy_tv(
+ copy_atom_load, thr_layout, val_layout
+ ).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy_tv(
+ copy_atom_store, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tdYgdY = thr_copy_load.partition_S(gdY)
+ tdYsdY = thr_copy_load.partition_D(sdY)
+ tYgY = thr_copy_load.partition_S(gY)
+ tYsY = thr_copy_load.partition_D(sY)
+ tdXgdX = thr_copy_store.partition_D(gdX)
+ tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+ tdYrdY, tYrY, tdXrdX = [
+ cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)
+ ]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n)
+ tdYpdY = (
+ predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+            cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+
+        if const_expr(not is_even_N):
+            # Zero-fill OOB columns so they do not pollute the dot(dy, y)
+            # reduction (mirrors the forward and fused kernels).
+            fill_oob(tdYsdY, tdYpdY, 0.0)
+            fill_oob(tYsY, tdYpdY, 0.0)
+
+        cute.autovec_copy(tdYsdY, tdYrdY)
+        cute.autovec_copy(tYsY, tYrY)
+ dy = tdYrdY.load().to(Float32)
+ y = tYrY.load().to(Float32)
+ dot = row_reduce(
+ dy * y,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ )
+
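+        # Softmax VJP: dx_i = y_i * (dy_i - sum_j dy_j * y_j).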
+ dx = y * (dy - dot)
+ tdXrdX.store(dx.to(tdXrdX.element_type))
+
+ tdXpdX = (
+ predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn)
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ ) -> None:
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn)
+
+
+class SoftmaxFwdBwdSM100(ReductionBase):
+ """Fused softmax forward+backward producing dx from (x, dy).
+
+ Computes:
+ y = softmax(x)
+ dot = sum(dy * y)
+ dx = y * (dy - dot)
+
+ This avoids materializing the intermediate `y` in global memory, which is
+ the dominant overhead in a naive `softmax_backward(dy, softmax_forward(x))`
+ composition.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # Online softmax reduction uses an Int64 reduction buffer packing
+ # (max, sum_exp) pairs. We allocate a separate Float32 reduction buffer
+ # for dot(dy, y).
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _calculate_threads_per_row(self) -> int:
+ # Favor the backward bucket policy (better for the dot reduction).
+ N = self.N
+ if N in (4096, 6144):
+ return 128
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 3072:
+ return 32
+ if N <= 6144:
+ return 64
+ if N <= 8192:
+ return 128
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ # Quack-style growth of cluster_n with N and dtype.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ def _get_num_threads(self) -> int:
+ # Keep in sync with _calculate_threads_per_row.
+ return 128 if self.N <= 8192 else 256
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
+ # Allocation order:
+ # 1) sX (16B aligned)
+ # 2) sdY (16B aligned)
+ # 3) reduction_buffer_stats (8B aligned)
+ # 4) reduction_buffer_dot (8B aligned)
+ # 5) optional mbarrier array (8B aligned)
+ def _align_up(x: int, align: int) -> int:
+ return ((x + align - 1) // align) * align
+
+ tile_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)))
+ reduction_stats_bytes = int(
+ num_warps * self.cluster_n * (cutlass.Int64.width // 8)
+ )
+ reduction_dot_bytes = int(
+ num_warps * self.cluster_n * (cutlass.Float32.width // 8)
+ )
+ mbar_bytes = (
+ int(2 * (cutlass.Int64.width // 8)) if const_expr(self.cluster_n > 1) else 0
+ )
+
+ offset = _align_up(tile_bytes, 16)
+ offset = _align_up(offset, 16) + tile_bytes
+ offset = _align_up(offset, 8) + reduction_stats_bytes
+ offset = _align_up(offset, 8) + reduction_dot_bytes
+ offset = _align_up(offset, 8) + mbar_bytes
+ return int(offset)
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mdY: cute.Tensor,
+ mdX: cute.Tensor,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ assert mdY.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(mX, mdY, mdX, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mdY, mdX)
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_dy: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mdY = cute.make_tensor(ptr_dy, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ self.__call__(mX, mdY, mdX, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mdY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ cluster_y = (
+ const_expr(0)
+ if const_expr(self.cluster_n == 1)
+ else cute.arch.block_idx()[1]
+ )
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ gX, gdY, gdX, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y))
+ for mT in (mX, mdY, mdX, idX)
+ ]
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ sdY = smem.allocate_tensor(
+ mdY.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+
+ reduction_layout = self._get_reduction_buffer_layout(tv_layout, self.cluster_n)
+ reduction_buffer_stats = smem.allocate_tensor(
+ cutlass.Int64, reduction_layout, byte_alignment=8
+ )
+ reduction_buffer_dot = smem.allocate_tensor(
+ cutlass.Float32, reduction_layout, byte_alignment=8
+ )
+
+ if const_expr(self.cluster_n > 1):
+ mbar_ptr_base = smem.allocate_array(cutlass.Int64, num_elems=2)
+ mbar_ptr_stats = mbar_ptr_base
+ mbar_ptr_dot = mbar_ptr_base + Int32(1)
+ else:
+ mbar_ptr_stats = None
+ mbar_ptr_dot = None
+
+ copy_atom_load = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mX.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=128,
+ )
+
+ num_copy_elems = (
+ tv_layout.shape[1]
+ if const_expr(cute.rank(tv_layout.shape[1]) == 1)
+ else tv_layout.shape[1][0]
+ )
+ threads_per_row = (
+ tv_layout.shape[0]
+ if const_expr(cute.rank(tv_layout.shape[0]) == 1)
+ else tv_layout.shape[0][0]
+ )
+ thr_layout = cute.make_ordered_layout(
+ (tiler_mn[0], threads_per_row), order=(1, 0)
+ )
+ val_layout = cute.make_layout((1, num_copy_elems))
+ thr_copy_load = cute.make_tiled_copy_tv(
+ copy_atom_load, thr_layout, val_layout
+ ).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy_tv(
+ copy_atom_store, thr_layout, val_layout
+ ).get_slice(tidx)
+
+ tXgX = thr_copy_load.partition_S(gX)
+ tXsX = thr_copy_load.partition_D(sX)
+ tdYgdY = thr_copy_load.partition_S(gdY)
+ tdYsdY = thr_copy_load.partition_D(sdY)
+ tdXgdX = thr_copy_store.partition_D(gdX)
+ tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+ tXrX, tdYrdY, tdXrdX = [
+ cute.make_fragment_like(thr) for thr in (tXgX, tdYgdY, tdXgdX)
+ ]
+
+ if const_expr(
+ self.cluster_n > 1
+ and mbar_ptr_stats is not None
+ and mbar_ptr_dot is not None
+ ):
+ if tidx < 2:
+ cute.arch.mbarrier_init(mbar_ptr_stats + tidx, 1)
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX)
+ cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+ fill_oob(tdYsdY, tXpX, 0.0)
+
+ cute.autovec_copy(tXsX, tXrX)
+ cute.autovec_copy(tdYsdY, tdYrdY)
+ x = tXrX.load().to(Float32)
+ dy = tdYrdY.load().to(Float32)
+
+ _, denom, exp_x = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer_stats[None, None, 0],
+ mbar_ptr_stats,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=True,
+ )
+ assert exp_x is not None
+ y = exp_x * cute.arch.rcp_approx(denom)
+
+ dot = row_reduce(
+ dy * y,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer_dot[None, None, 0],
+ mbar_ptr_dot,
+ phase=None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ )
+
+ dx = y * (dy - dot)
+ tdXrdX.store(dx.to(tdXrdX.element_type))
+
+ tOpO = (
+ predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tOpO)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mdY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn)
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mdY: cute.Tensor,
+ mdX: cute.Tensor,
+ ) -> None:
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn)
+
+
+def _convert_2d_tensor(x: Tensor) -> cute.Tensor:
+ # Match Quack's Softmax conversion exactly: assume 16B alignment and mark
+ # the shape compact with row-major stride order (0, 1), with mode=0 (batch).
+ # We intentionally do not call mark_layout_dynamic here to avoid the
+ # leading_dim stride==1 constraint used in RMSNorm.
+ return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+
+
+def _can_use_ptr_path_2d(x: Tensor) -> bool:
+ """Conservative guard for the pointer-based fast path."""
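+    # For example, a contiguous bf16 (M, 4096) tensor qualifies (divby = 8 and
+    # 4096 % 8 == 0), while N = 4095 would fall back to the DLPack path.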
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.dtype not in TORCH2CUTE_DTYPE:
+ return False
+ # Require row-major last-dim contiguous.
+ if x.stride(1) != 1:
+ return False
+ # Require 16B alignment (matches from_dlpack(..., assumed_align=16)).
+ if (x.data_ptr() % 16) != 0:
+ return False
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
+    # Softmax uses ReductionBase's default num_copy_bits=128, so N must be
+    # divisible by the per-copy element count (128 // dtype width).
+ if (x.shape[1] % divby) != 0:
+ return False
+ # Ensure each row start remains aligned for 128-bit vectorized copies.
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None:
+ """Launch the pointer-based Softmax forward kernel into preallocated `out`."""
+ assert x.is_cuda and x.dim() == 2
+ assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype
+ assert out.stride() == x.stride(), "Pointer path expects out to match x strides"
+
+ M, N = x.shape
+ device_index = x.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ key = ("ptr_fwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = SoftmaxFwdSM100(dtype_x, int(N))
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_out,
+ Int32(int(M)),
+ ld,
+ stream,
+ )
+ _PTR_FWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_softmax_launcher(
+ compiled=compiled,
+ dtype=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ assumed_align=16,
+ is_bwd=False,
+ )
+ if launcher is not None:
+ launcher.launch(
+ a_ptr=int(x.data_ptr()),
+ b_ptr=int(out.data_ptr()),
+ c_ptr=None,
+ M=int(M),
+ ld=int(x.stride(0)),
+ stream_handle=stream_handle,
+ dtype=dtype_x,
+ )
+ return
+
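+    # Fast launch unavailable (or disabled): rebuild typed pointers and invoke
+    # the compiled JIT callable directly.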
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream)
+
+
+def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None:
+ """Launch the pointer-based Softmax backward kernel into preallocated `dx`."""
+ assert dy.is_cuda and dy.dim() == 2
+ assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype
+ assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype
+ assert dy.stride() == y.stride() == dx.stride(), (
+ "Pointer path expects matching strides"
+ )
+
+ M, N = dy.shape
+ device_index = dy.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[dy.dtype]
+ key = ("ptr_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_BWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = SoftmaxBwdSM100(dtype_x, int(N))
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_y = rt.make_ptr(
+ dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ld = Int32(int(dy.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_dy,
+ ptr_y,
+ ptr_dx,
+ Int32(int(M)),
+ ld,
+ stream,
+ )
+ _PTR_BWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_softmax_launcher(
+ compiled=compiled,
+ dtype=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ assumed_align=16,
+ is_bwd=True,
+ )
+ if launcher is not None:
+ launcher.launch(
+ a_ptr=int(dy.data_ptr()),
+ b_ptr=int(y.data_ptr()),
+ c_ptr=int(dx.data_ptr()),
+ M=int(M),
+ ld=int(dy.stride(0)),
+ stream_handle=stream_handle,
+ dtype=dtype_x,
+ )
+ return
+
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_y = rt.make_ptr(
+ dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream)
+
+
+def _softmax_fwd_bwd_ptr_into(*, x: Tensor, dy: Tensor, dx: Tensor) -> None:
+ """Launch the fused pointer-based Softmax fwd+bwd kernel into preallocated `dx`."""
+ assert x.is_cuda and x.dim() == 2
+ assert dy.is_cuda and dy.shape == x.shape and dy.dtype == x.dtype
+ assert dx.is_cuda and dx.shape == x.shape and dx.dtype == x.dtype
+ assert x.stride() == dy.stride() == dx.stride(), (
+ "Pointer path expects matching strides"
+ )
+
+ M, N = x.shape
+ device_index = x.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = SoftmaxFwdBwdSM100(dtype_x, int(N))
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_dy,
+ ptr_dx,
+ Int32(int(M)),
+ ld,
+ stream,
+ )
+ _PTR_FWDBWD_COMPILE_CACHE[key] = compiled
+
+ launcher = _get_fast_ptr_softmax_launcher(
+ compiled=compiled,
+ dtype=dtype_x,
+ N=int(N),
+ device_index=int(device_index),
+ stream_handle=stream_handle,
+ assumed_align=16,
+ is_bwd=True,
+ )
+ if launcher is not None:
+ launcher.launch(
+ a_ptr=int(x.data_ptr()),
+ b_ptr=int(dy.data_ptr()),
+ c_ptr=int(dx.data_ptr()),
+ M=int(M),
+ ld=int(x.stride(0)),
+ stream_handle=stream_handle,
+ dtype=dtype_x,
+ )
+ return
+
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ compiled(ptr_x, ptr_dy, ptr_dx, Int32(int(M)), Int32(int(x.stride(0))), stream)
+
+
+def softmax_forward(x: Tensor) -> Tensor:
+ """SM100 CuteDSL softmax forward pass: y = softmax(x, dim=-1)."""
+ assert x.dim() == 2, "Input must be 2D (M, N)"
+ assert x.is_cuda, "Input must be on CUDA device"
+ assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype"
+
+ N = x.size(1)
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ if _can_use_ptr_path_2d(x):
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ _softmax_forward_ptr_into(x=x, out=out)
+ return out
+
+ out = torch.empty_like(x)
+
+ x_tensor = _convert_2d_tensor(x)
+ out_tensor = _convert_2d_tensor(out)
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ kernel = _FWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = SoftmaxFwdSM100(dtype, N)
+ kernel = cute.compile(op, x_tensor, out_tensor, current_stream)
+ _FWD_COMPILE_CACHE[compile_key] = kernel
+ kernel(x_tensor, out_tensor, current_stream)
+ return out
+
+
+def softmax_backward(dy: Tensor, y: Tensor) -> Tensor:
+ """SM100 CuteDSL softmax backward pass."""
+ assert dy.dim() == 2 and y.dim() == 2, "dy and y must be 2D (M, N)"
+ assert dy.shape == y.shape, "dy and y must have the same shape"
+ assert dy.is_cuda and y.is_cuda, "dy and y must be on CUDA device"
+ assert dy.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype"
+ assert y.dtype == dy.dtype, "dy and y must have the same dtype"
+
+ N = dy.size(1)
+ dtype = TORCH2CUTE_DTYPE[dy.dtype]
+ if (
+ _can_use_ptr_path_2d(dy)
+ and _can_use_ptr_path_2d(y)
+ and dy.stride() == y.stride()
+ ):
+ dx = torch.empty_strided(
+ dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype
+ )
+ _softmax_backward_ptr_into(dy=dy, y=y, dx=dx)
+ return dx
+
+ dx = torch.empty_like(dy)
+
+ dy_tensor = _convert_2d_tensor(dy)
+ y_tensor = _convert_2d_tensor(y)
+ dx_tensor = _convert_2d_tensor(dx)
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = SoftmaxBwdSM100(dtype, N)
+ kernel = cute.compile(op, dy_tensor, y_tensor, dx_tensor, current_stream)
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+ kernel(dy_tensor, y_tensor, dx_tensor, current_stream)
+ return dx
+
+
+def softmax_fwd_bwd(dy: Tensor, x: Tensor) -> Tensor:
+ """Fused softmax forward+backward producing ``dx`` from ``(x, dy)``.
+
+    This is intended for benchmarks and training-like use cases where the
+ intermediate ``y = softmax(x)`` is not needed outside the backward pass.
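+
+    For example::
+
+        dx = softmax_fwd_bwd(dy, x)  # same result as softmax_backward(dy, softmax_forward(x))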
+ """
+ assert x.dim() == 2 and dy.dim() == 2, "x and dy must be 2D (M, N)"
+ assert x.shape == dy.shape, "x and dy must have the same shape"
+ assert x.is_cuda and dy.is_cuda, "x and dy must be on CUDA device"
+ assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype"
+ assert dy.dtype == x.dtype, "x and dy must have the same dtype"
+
+ if (
+ _can_use_ptr_path_2d(x)
+ and _can_use_ptr_path_2d(dy)
+ and x.stride() == dy.stride()
+ ):
+ dx = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ _softmax_fwd_bwd_ptr_into(x=x, dy=dy, dx=dx)
+ return dx
+
+ with torch.no_grad():
+ return softmax_backward(dy, softmax_forward(x))
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ y = softmax_forward(x)
+ ctx.save_for_backward(y)
+ return y
+
+ @staticmethod
+    def backward(ctx, dy: Tensor) -> Tensor:
+ (y,) = ctx.saved_tensors
+ dx = softmax_backward(dy, y)
+ return dx
+
+
+def softmax(x: Tensor) -> Tensor:
+ """Autograd-friendly softmax using the SM100 CuteDSL kernel."""
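+    # Example (illustrative; assumes a 2D CUDA tensor of a supported dtype):
+    #   x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16,
+    #                   requires_grad=True)
+    #   y = softmax(x)
+    #   y.backward(torch.randn_like(y))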
+ return SoftmaxFunction.apply(x)
+
+
+def _torch_softmax_reference(x: Tensor) -> Tensor:
+ return torch.nn.functional.softmax(x, dim=-1)
+
+
+def verify_softmax_parity(
+ M: int,
+ N: int,
+ dtype: torch.dtype = torch.bfloat16,
+ atol: float = 5e-2,
+ rtol: float = 5e-2,
+) -> None:
+ """Compare SM100 CuteDSL softmax against PyTorch for a single shape."""
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ x.requires_grad_(True)
+
+ # Forward parity
+ y_ref = _torch_softmax_reference(x)
+ y = softmax(x)
+ torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
+
+ # Backward parity
+ dy = torch.randn_like(y)
+ (dx_ref,) = torch.autograd.grad(y_ref, x, dy, retain_graph=False)
+ dx = softmax_backward(dy, y)
+ torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)