[Bug] Potential issue when load_weights #8073

Open
@Tavish9

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

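Currently, load_weights builds its name-to-tensor lookup from self.named_parameters(). However, this method returns only registered parameters (nn.Parameter objects, including frozen ones). It does not return buffers registered with register_buffer, such as BatchNorm running statistics.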

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "up_proj", 1),
        ("gate_up_proj", "gate_proj", 0),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))

Any checkpoint tensor that is not a parameter (for example, a buffer) is therefore either silently skipped or raises a loading error.
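The distinction is between parameters and buffers, not between trainable and non-trainable tensors. A minimal sketch in plain PyTorch, using a hypothetical Toy module (not part of SGLang), shows that a frozen parameter still appears in named_parameters(), while a buffer appears only in named_buffers() and state_dict():

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical module: one trainable param, one frozen param, one buffer."""

    def __init__(self):
        super().__init__()
        self.trainable = nn.Parameter(torch.ones(2))
        # Frozen, but still an nn.Parameter, so named_parameters() returns it.
        self.frozen = nn.Parameter(torch.zeros(2), requires_grad=False)
        # A buffer: saved in state_dict(), but not a parameter.
        self.register_buffer("stats", torch.zeros(2))


toy = Toy()
param_names = [name for name, _ in toy.named_parameters()]
buffer_names = [name for name, _ in toy.named_buffers()]
state_names = list(toy.state_dict().keys())

print(param_names)   # ['trainable', 'frozen']
print(buffer_names)  # ['stats']
print(state_names)   # ['trainable', 'frozen', 'stats']
```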

Reproduction

All registered models load correctly, but custom models that contain buffers do not.

Consider a simple BatchNorm1d layer. Although LLMs no longer use it, some vision towers still do.

>>> import torch
>>> norm = torch.nn.BatchNorm1d(4)
>>> norm
BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
>>> list(norm.named_parameters())
[('weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
>>> norm.state_dict()
OrderedDict([('weight', tensor([1., 1., 1., 1.])), ('bias', tensor([0., 0., 0., 0.])), ('running_mean', tensor([0., 0., 0., 0.])), ('running_var', tensor([1., 1., 1., 1.])), ('num_batches_tracked', tensor(0))])

When a checkpoint containing such a layer is saved and then loaded with SGLang, loading fails with:

KeyError: 'custom_model.num_batches_tracked'
KeyError: 'custom_model.running_mean'
KeyError: 'custom_model.running_var'
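The failure can be reproduced without SGLang. The sketch below uses a hypothetical CustomModel wrapper (so keys carry the custom_model. prefix) and a hypothetical load_like_named_parameters function that only mimics the dictionary lookup pattern; it is not SGLang's actual loader:

```python
import torch
from torch import nn


class CustomModel(nn.Module):
    """Hypothetical wrapper whose submodule is named custom_model."""

    def __init__(self):
        super().__init__()
        self.custom_model = nn.BatchNorm1d(4)


def load_like_named_parameters(model, weights):
    """Simplified stand-in for a named_parameters()-based load_weights loop."""
    params_dict = dict(model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in weights:
        param = params_dict[name]  # buffers are absent, so this raises KeyError
        param.data.copy_(loaded_weight)


checkpoint = CustomModel().state_dict()
try:
    load_like_named_parameters(CustomModel(), checkpoint.items())
    error = None
except KeyError as exc:
    error = exc.args[0]

# The first buffer in state_dict() order is the first key to fail.
print(error)  # custom_model.running_mean
```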

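If these buffer keys are skipped instead, loading succeeds, but the buffers keep their initial values rather than the values stored in the checkpoint: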

print(self.state_dict()["custom_model.running_mean"])
tensor([0., 0., 0., 0.], device='cuda:0')
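The silent-skip variant can be sketched the same way. Here load_skipping_unknown and CustomModel are hypothetical stand-ins, and the source model's running_mean is filled with 0.5 purely to simulate trained statistics:

```python
import torch
from torch import nn


class CustomModel(nn.Module):
    """Hypothetical wrapper whose submodule is named custom_model."""

    def __init__(self):
        super().__init__()
        self.custom_model = nn.BatchNorm1d(4)


def load_skipping_unknown(model, weights):
    """Hypothetical loader that silently skips keys missing from named_parameters()."""
    params_dict = dict(model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in weights:
        if name not in params_dict:
            continue  # running_mean, running_var, num_batches_tracked land here
        params_dict[name].data.copy_(loaded_weight)


source = CustomModel()
source.custom_model.running_mean.fill_(0.5)  # pretend these stats were trained

target = CustomModel()
load_skipping_unknown(target, source.state_dict().items())

# The target still holds its freshly initialized zeros, not the checkpoint's 0.5.
print(torch.equal(target.custom_model.running_mean, torch.zeros(4)))  # True
```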

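Should params_dict = dict(self.named_parameters()) be replaced with params_dict = self.state_dict()? If so, note that state_dict() returns plain, detached tensors rather than the original nn.Parameter objects, so the entries would also need to be converted (or mapped back) to nn.Parameter.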
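One possible direction, offered only as a sketch and not as SGLang's implementation: keep the original parameter objects and add the buffers to the same lookup table, so that no conversion to nn.Parameter is needed. build_params_dict, load_weights and CustomModel below are hypothetical names; a real fix would also have to go through SGLang's per-parameter weight loaders and stacked-parameter mapping, which this sketch omits.

```python
import torch
from torch import nn


class CustomModel(nn.Module):
    """Hypothetical wrapper whose submodule is named custom_model."""

    def __init__(self):
        super().__init__()
        self.custom_model = nn.BatchNorm1d(4)


def build_params_dict(model: nn.Module) -> dict:
    """Hypothetical helper: parameters plus buffers, keyed like state_dict().

    Parameters stay the original nn.Parameter objects, so any attributes
    attached to them are preserved; buffers are the registered tensors.
    """
    params_dict = dict(model.named_parameters(remove_duplicate=False))
    params_dict.update(model.named_buffers(remove_duplicate=False))
    return params_dict


def load_weights(model, weights):
    """Simplified loop: copy every checkpoint tensor into its target in place."""
    params_dict = build_params_dict(model)
    for name, loaded_weight in weights:
        with torch.no_grad():
            params_dict[name].copy_(loaded_weight)


source = CustomModel()
source.custom_model.running_mean.fill_(0.5)
source.custom_model.num_batches_tracked.fill_(3)

target = CustomModel()
load_weights(target, source.state_dict().items())
print(torch.equal(target.custom_model.running_mean, source.custom_model.running_mean))  # True
```

An alternative with a similar effect is self.state_dict(keep_vars=True), which skips the detach step and therefore returns the nn.Parameter objects themselves for parameters alongside the buffers.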

Environment

CUDA available: True
GPU 0: NVIDIA A800-SXM4-80GB
GPU 0 Compute Capability: 8.0
CUDA_HOME: /mnt/petrelfs/share/cuda-12.4
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.90.07
PyTorch: 2.6.0+cu124
sglang: 0.4.6.post5
sgl_kernel: 0.1.4
flashinfer_python: 0.2.5+cu124torch2.6
triton: 3.2.0
transformers: 4.51.1
torchao: 0.9.0
numpy: 2.2.6
aiohttp: 3.12.14
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.33.4
interegular: 0.3.3
modelscope: 1.28.0
orjson: 3.10.18
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.0
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.19
openai: 1.95.1
tiktoken: 0.9.0
anthropic: 0.57.1
litellm: 1.74.3
decord: 0.6.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV8     NV8     NV8     NV8     NV8     NV8     NV8     PXB     NODE    SYS     SYS     NODE    0-31,64-95      0               N/A
GPU1    NV8      X      NV8     NV8     NV8     NV8     NV8     NV8     PXB     NODE    SYS     SYS     NODE    0-31,64-95      0               N/A
GPU2    NV8     NV8      X      NV8     NV8     NV8     NV8     NV8     NODE    PXB     SYS     SYS     NODE    0-31,64-95      0               N/A
GPU3    NV8     NV8     NV8      X      NV8     NV8     NV8     NV8     NODE    PXB     SYS     SYS     NODE    0-31,64-95      0               N/A
GPU4    NV8     NV8     NV8     NV8      X      NV8     NV8     NV8     SYS     SYS     PXB     NODE    SYS     32-63,96-127    1               N/A
GPU5    NV8     NV8     NV8     NV8     NV8      X      NV8     NV8     SYS     SYS     PXB     NODE    SYS     32-63,96-127    1               N/A
GPU6    NV8     NV8     NV8     NV8     NV8     NV8      X      NV8     SYS     SYS     NODE    PXB     SYS     32-63,96-127    1               N/A
GPU7    NV8     NV8     NV8     NV8     NV8     NV8     NV8      X      SYS     SYS     NODE    PXB     SYS     32-63,96-127    1               N/A
NIC0    PXB     PXB     NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    SYS     SYS     NODE
NIC1    NODE    NODE    PXB     PXB     SYS     SYS     SYS     SYS     NODE     X      SYS     SYS     NODE
NIC2    SYS     SYS     SYS     SYS     PXB     PXB     NODE    NODE    SYS     SYS      X      NODE    SYS
NIC3    SYS     SYS     SYS     SYS     NODE    NODE    PXB     PXB     SYS     SYS     NODE     X      SYS
NIC4    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_2
  NIC1: mlx5_3
  NIC2: mlx5_4
  NIC3: mlx5_5
  NIC4: mlx5_bond_1


ulimit soft: 1048576
