57 changes: 57 additions & 0 deletions oink/README.md
@@ -0,0 +1,57 @@
# KernelAgent Oink (vLLM plugin)

This subproject provides an **out-of-tree vLLM plugin** that registers
`torch.library.custom_op` entrypoints under the `oink::` namespace:

- `torch.ops.oink.rmsnorm`
- `torch.ops.oink.fused_add_rms_norm`

The implementation is backed by a CuTeDSL (CUTLASS) RMSNorm kernel tuned for
**NVIDIA Blackwell (SM100)**.

## Install (editable)

From the `KernelAgent` repo root:

```bash
pip install -e ./oink
```

This plugin requires the CuTeDSL stack:

```bash
pip install nvidia-cutlass-dsl cuda-python
```
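An optional sanity check (a sketch using only the standard library; `cutlass` and `cuda` are the import names of the two packages above) to confirm the CuTeDSL modules are importable in the target environment:

```shell
# Optional: report whether the CuTeDSL modules are importable (illustrative).
python - <<'PY'
import importlib.util
for mod in ("cutlass", "cuda"):
    spec = importlib.util.find_spec(mod)
    print(f"{mod}: {'found' if spec else 'missing'}")
PY
```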

## Use with vLLM

1. Enable the vLLM integration:

```bash
export VLLM_USE_OINK_RMSNORM=1
```

2. Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` /
CUDA graphs. In Python:

```python
from vllm import LLM

llm = LLM(
    model=...,
    tensor_parallel_size=...,
    enforce_eager=False,
    compilation_config={"custom_ops": ["none", "+rms_norm"]},
)
```

Without `+rms_norm`, Inductor may fuse RMSNorm into larger Triton kernels and
neither vLLM's CUDA RMSNorm nor Oink will run.
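As a sketch, you can confirm at startup that the op actually resolved. The helper below is ours for illustration, not part of vLLM or Oink, and it returns `False` gracefully when `torch` is absent:

```python
# Hypothetical helper (illustrative): check that the Oink RMSNorm op resolved
# under torch.ops before serving traffic.
import importlib.util


def oink_rmsnorm_registered() -> bool:
    """True only if torch is installed and torch.ops.oink.rmsnorm resolves."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    try:
        _ = torch.ops.oink.rmsnorm  # raises if the op was never registered
        return True
    except (AttributeError, RuntimeError):
        return False


print(oink_rmsnorm_registered())
```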

## Notes

- This plugin is designed to be **safe to import even when disabled**; it only
  registers ops when `VLLM_USE_OINK_RMSNORM` is truthy (`"1"`, `"true"`,
  `"yes"`, or `"on"`, case-insensitive).
- The ops preserve **padded-row layouts** for 2D tensors (shape `[M, N]`,
`stride(1) == 1`, and potentially `stride(0) > N`), which is required for
`torch.compile` stride verification on some models (e.g., MLA padded inputs).
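The padded-row requirement can be illustrated with plain stride arithmetic. This is a pure-Python sketch of the layout; real tensors expose the same numbers through `torch.Tensor.stride()`:

```python
# Pure-Python sketch of a padded-row [M, N] layout: stride(1) == 1 and
# stride(0) > N, so each row ends with unused padding in the flat buffer.


def flat_offset(i: int, j: int, row_stride: int) -> int:
    """Offset of logical element (i, j) in the flat buffer."""
    return i * row_stride + j


M, N, row_stride = 2, 3, 8  # row_stride > N -> padded rows
buf = [0.0] * (M * row_stride)
for i in range(M):
    for j in range(N):
        buf[flat_offset(i, j, row_stride)] = float(i * N + j)

# Row 1 begins at offset 8 (not 3); an op that preserves this layout must
# write its output with the same row_stride, leaving the padding untouched.
print(buf[flat_offset(1, 0, row_stride)])  # -> 3.0
```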
29 changes: 29 additions & 0 deletions oink/pyproject.toml
@@ -0,0 +1,29 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "kernelagent-oink"
version = "0.1.0"
description = "vLLM plugin that registers Oink Blackwell RMSNorm custom ops"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
authors = [{name = "PyTorch Labs"}]

# Keep dependencies minimal, but include the CuTeDSL stack required by the
# Blackwell RMSNorm implementation.
#
# We intentionally do NOT depend on `torch` here because vLLM already pins and
# provides a compatible PyTorch build.
dependencies = [
    "nvidia-cutlass-dsl",
    "cuda-python",
]

[project.entry-points."vllm.general_plugins"]
oink = "kernelagent_oink:register"

[tool.setuptools.packages.find]
where = ["src"]
include = ["kernelagent_oink*"]
111 changes: 111 additions & 0 deletions oink/src/kernelagent_oink/__init__.py
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
import os

logger = logging.getLogger(__name__)

_OPS_REGISTERED = False


def _env_truthy(name: str) -> bool:
    val = os.environ.get(name)
    if val is None:
        return False
    return val.strip().lower() in ("1", "true", "yes", "on")
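# For example, with the parsing above (illustrative values):
#   VLLM_USE_OINK_RMSNORM set to "1", "true", "YES", or " on " enables the
#   plugin; unset, "0", "false", or any other value leaves it disabled.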


def _infer_cuda_device_index() -> int:
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        try:
            return int(local_rank)
        except ValueError:
            pass
    return 0
Review comment (Contributor) on lines +45 to +51:

  Ignore this suggestion if we want to guard/enable on "off"/"on" and
  "yes"/"no".

  Suggested change (replace the LOCAL_RANK parsing with):

      rank = os.environ.get("LOCAL_RANK", "0")
      return int(rank)



def _compute_cutedsl_arch(major: int, minor: int) -> str:
    # CuTeDSL uses an "a" suffix for >= Hopper (SM90).
    suffix = "a" if major >= 9 else ""
    # Match cutlass/base_dsl/env_manager.py: map sm_110 -> sm_101.
    if major == 11 and minor == 0:
        major, minor = 10, 1
    return f"sm_{major}{minor}{suffix}"
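# Illustrative values produced by the mapping above:
#   _compute_cutedsl_arch(9, 0)  -> "sm_90a"   (Hopper; "a" suffix)
#   _compute_cutedsl_arch(10, 0) -> "sm_100a"  (Blackwell SM100)
#   _compute_cutedsl_arch(11, 0) -> "sm_101a"  (sm_110 remapped to sm_101)
#   _compute_cutedsl_arch(8, 6)  -> "sm_86"    (pre-Hopper: no suffix)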


def register() -> None:
    """vLLM plugin entrypoint.

    This function must be safe to call multiple times and must not raise.
    vLLM executes it in multiple processes (engine + workers).
    """
    global _OPS_REGISTERED

    if _OPS_REGISTERED:
        return

    # Gate on the vLLM integration flag so installing the package does not
    # change behavior unless explicitly enabled.
    if not _env_truthy("VLLM_USE_OINK_RMSNORM"):
        return

    try:
        import torch
    except Exception as e:  # pragma: no cover
        logger.debug("Oink plugin: torch import failed: %s", e)
        return

    try:
        if not torch.cuda.is_available():
Review comment (Contributor):

  Suggested change (log before returning):

      if not torch.cuda.is_available():
          logger.debug("Oink plugin: torch.cuda not found")
          return
            return
        device_index = _infer_cuda_device_index()
        major, minor = torch.cuda.get_device_capability(device_index)
        sm = 10 * int(major) + int(minor)
        if sm < 100:
            return

        # Ensure required deps are importable before registering ops so that
        # vLLM doesn't detect ops that would later fail at first use.
        try:
            import cutlass  # noqa: F401
            import cuda.bindings.driver as _cuda  # noqa: F401
        except Exception as e:
            logger.warning(
                "Oink plugin: CuTeDSL deps missing; skipping op registration. "
                "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s",
                e,
            )
            return

        # Ensure CuTeDSL sees a target arch early. If the user has already
        # set it, respect their choice.
        os.environ.setdefault(
            "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))
        )

        # Importing this module registers the ops via the
        # torch.library.custom_op decorators.
        from .blackwell import oink_custom_ops  # noqa: F401
    except Exception as e:  # pragma: no cover
        # Do not raise: vLLM's plugin loader does not guard plugin execution.
        logger.exception("Oink plugin: failed to register ops: %s", e)
        return

    _OPS_REGISTERED = True


__all__ = ["register"]
17 changes: 17 additions & 0 deletions oink/src/kernelagent_oink/blackwell/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

__all__ = []