Conversation

@wuisawesome
Contributor

@wuisawesome wuisawesome commented Apr 18, 2025

This PR adds support for loading custom transformers models with named parameters.

This is generally helpful for models that have something of the form:

import torch
import torch.nn as nn

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    # a scalar hyperparameter registered directly on the model
    self.my_scalar_hyperparam = nn.Parameter(torch.zeros(1))

  def forward(self, x):
    return x * self.my_scalar_hyperparam

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@wuisawesome wuisawesome changed the title from "[WIP] minor transformers bug fixes" to "Support loading transformers models with named parameters" Apr 18, 2025
Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds support for loading custom transformers models with named parameters by initializing parameters from the meta device onto the execution device.

  • Added a new method init_parameters to recursively initialize parameters on the meta device.
  • Adjusted model initialization to call init_parameters right after setting up buffers.
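
For reference, a minimal sketch of what such a recursive init_parameters could look like, assuming it mirrors the structure of the meta_to_empty and init_buffers helpers quoted later in the thread (written here as a free function with an explicit device argument; the in-tree helpers are methods that read the device from self.device_config). This is illustrative, not the exact PR diff:

import torch
import torch.nn as nn

def init_parameters(module: nn.Module, device: torch.device) -> None:
    # Replace any parameter still on the meta device with an empty tensor
    # allocated on the target device, then recurse into child modules.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            new_param = nn.Parameter(
                torch.empty_like(param, device=device),
                requires_grad=param.requires_grad,
            )
            setattr(module, name, new_param)
    for child in module.children():
        init_parameters(child, device)

Called on the root model after the tensor-parallel substitutions, this sweeps up any parameter that never left the meta device.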

Comment on lines +324 to +320
Copilot AI Apr 18, 2025

Consider using module.register_parameter(name, new_param) instead of setattr(module, name, new_param) to ensure the parameter is correctly registered in the module's parameter dictionary.

Suggested change
-    setattr(module, name, new_param)
+    module.register_parameter(name, new_param)
     for child in module.children():
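
As a quick aside (not from the PR or the review): assigning an nn.Parameter via setattr goes through nn.Module.__setattr__, which records it in the module's parameter dict, so both spellings end up registered; register_parameter is simply more explicit. A tiny check:

import torch
import torch.nn as nn

m = nn.Module()
m.weight_a = nn.Parameter(torch.zeros(2))                      # plain attribute assignment
m.register_parameter("weight_b", nn.Parameter(torch.ones(2)))  # explicit registration
print(sorted(dict(m.named_parameters())))                      # ['weight_a', 'weight_b']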

Member

@hmellor hmellor left a comment


If this is necessary, I'd prefer to parameterise init_buffers to work with buffers and parameters instead of duplicating the logic.

However, I'd like to find out why meta_to_empty missed some parameters that your init_parameters didn't.
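
To illustrate the parameterisation suggested above, a single recursive helper covering both buffers and parameters might look roughly like this (a sketch with hypothetical names, not vLLM's actual code):

import torch
import torch.nn as nn

def init_meta_tensors(module: nn.Module, device: torch.device) -> None:
    # Materialise buffers that are still on the meta device.
    for name, buffer in module.named_buffers(recurse=False):
        if buffer.device == torch.device("meta"):
            module.register_buffer(name, torch.empty_like(buffer, device=device))
    # Materialise parameters that are still on the meta device.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            module.register_parameter(
                name,
                nn.Parameter(torch.empty_like(param, device=device),
                             requires_grad=param.requires_grad))
    for child in module.children():
        init_meta_tensors(child, device)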

@wuisawesome
Contributor Author

However, I'd like to find out why meta_to_empty missed some parameters that your init_parameters didn't.

I guess a parameter initialized on a meta device just doesn't have an associated buffer?

I think this simplified repro illustrates what's going on

import torch # torch 2.6
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.my_param = nn.Parameter(torch.randn(10, 10))

with torch.device("meta"):
    model = Model()
    print(model.my_param)


for name, buffer in model.named_buffers():
    print("buffer ", name, buffer.device)  # never prints: the model has no buffers

for name, param in model.named_parameters():
    print("param ", name, param.device)  # prints: param  my_param meta

print(f"Total buffers {len(list(model.buffers()))}")  # 0
print(f"Total params {len(list(model.parameters()))}")  # 1

@hmellor
Member

hmellor commented Apr 22, 2025

But meta_to_empty iterates through all buffers and parameters then moves them to the device if they are still on the meta device:

def meta_to_empty(self, module: nn.Module):
    tensors = list(chain(module.buffers(), module.parameters()))
    if tensors and all(t.device == torch.device("meta") for t in tensors):
        module.to_empty(device=self.device_config.device)
        return  # We can stop recursing because to_empty is recursive
    for child in module.children():
        self.meta_to_empty(child)

If I add the following to the end of your repro, then my_param is correctly moved to the device:

from itertools import chain

device = "cuda"

def meta_to_empty(module: nn.Module):
    tensors = list(chain(module.buffers(), module.parameters()))
    if tensors and all(t.device == torch.device("meta") for t in tensors):
        module.to_empty(device=device)
        return  # We can stop recursing because to_empty is recursive
    for child in module.children():
        meta_to_empty(child)

meta_to_empty(model)
print(model.my_param)

@wuisawesome
Contributor Author

wuisawesome commented Apr 24, 2025

So it seems like meta_to_empty() doesn't handle this case properly because the tensor parallel code replaces some of the meta parameters (corresponding to the tp-ed linear layers) with parameters on-device (not sure if this is desired behavior...).

Using print statements, I verified that all the non-meta parameters in my model corresponded to the ColumnParallelLinear and RowParallelLinear layers.

In my particular case, there actually aren't any named buffers so init_buffers does nothing.

This illustrates what seems to be happening

import torch # torch 2.6
import torch.nn as nn

from itertools import chain

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.my_param = nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
        self.linear = nn.Linear(10, 10)

def init_buffers(module: nn.Module):
    # The model in this repro has no buffers, so this loop body never runs.
    for name, buffer in module.named_buffers(recurse=False):
        if buffer.device == torch.device("meta"):
            new_buffer = torch.empty_like(buffer, device="cuda")
            setattr(module, name, new_buffer)
    for child in module.children():
        init_buffers(child)


def meta_to_empty(module: nn.Module):
    tensors = list(chain(module.buffers(), module.parameters()))
    all_meta = all(t.device == torch.device("meta") for t in tensors)
    if tensors and all_meta:
        module.to_empty(device="cuda")
        return  # We can stop recursing because to_empty is recursive
    for child in module.children():
        meta_to_empty(child)


with torch.device("meta"):
    model = Model()

# Simulate the `replace_linear_class` calls in `tensor_parallel()`
model.linear = nn.Linear(10, 10, device="cuda", bias=False)

init_buffers(model)
meta_to_empty(model)

# my_param stays on the meta device: the model's tensors are not all on meta
# (linear is on cuda), so meta_to_empty recurses into children(), which never
# visits a Parameter registered directly on the root module.
assert model.my_param.device == torch.device("meta")

@hmellor
Member

hmellor commented Apr 25, 2025

Yeah it's expected that the TP substitutions create tensors on device. This way the tensor parallel linear layers are instantiated like normal rather than being created on the meta device and potentially initialised incorrectly.

meta_to_empty is meant to sweep the remaining parameters and buffers to move them over if they're still on the meta device. Clearly something isn't working as expected. I'll look into it tomorrow.

Member

@hmellor hmellor left a comment


LGTM!


The original issue was that meta_to_empty specifically doesn't work for nn.Parameters that are direct children of the model. This is because they don't appear in module.children() (because they're not nn.Modules) and therefore never reach module.to_empty (which would fail as nn.Parameter does not have this method).
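
A self-contained illustration of that root cause (hypothetical model, not from the PR):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_param = nn.Parameter(torch.empty(1, device="meta"))
        self.linear = nn.Linear(10, 10)

model = Model()
# The parameter is not a submodule, so children() never yields it...
print([name for name, _ in model.named_children()])                 # ['linear']
# ...yet it is still a (meta-device) parameter of the root module.
print([name for name, _ in model.named_parameters(recurse=False)])  # ['my_param']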

After this PR I will follow up on the following two points:

  • Is meta_to_empty even necessary anymore?
  • We need to better handle buffers that are direct children of the model (right now, init_buffers would try to instantiate the whole model on one GPU, which would almost certainly cause OOM since most of the model is already on the GPU)

@wuisawesome
Contributor Author

Updated the docstring. At first glance, the CI failure looks unrelated, but I'll defer to you.

Is meta_to_empty even necessary anymore?

I can't think of any real-world models where it is needed, but in theory you could have modules that don't have parameters (some rope, norm, and quantization-specific modules could look like this). In all of those cases, though, I think most people would explicitly set the device on any tensors they create to match the device of the input tensor.
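
A hypothetical sketch of that pattern: a parameter-free module that builds its tensors on the input's device at call time, so nothing is ever left on the meta device:

import torch
import torch.nn as nn

class PositionScale(nn.Module):
    # No parameters or buffers: the auxiliary tensor is created per call,
    # on whatever device the input already lives on.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        positions = torch.arange(x.shape[-1], device=x.device, dtype=x.dtype)
        return x * positions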

@hmellor
Member

hmellor commented Apr 25, 2025

The main tests have not run yet. Could you please make the DCO check pass? Then I'll trigger the CI suite.

@wuisawesome wuisawesome force-pushed the alex-vllm-transformers branch 2 times, most recently from 7a1205a to baaf5dd Compare April 28, 2025 17:52
@hmellor hmellor enabled auto-merge (squash) April 28, 2025 17:53
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 28, 2025
auto-merge was automatically disabled April 28, 2025 19:22

Head branch was pushed to by a user without write access

@mergify mergify bot added the documentation (Improvements or additions to documentation), ci/build, frontend, multi-modality (Related to multi-modality, #4194), structured-output, and speculative-decoding labels Apr 28, 2025
@mergify mergify bot added the v1, tpu (Related to Google TPUs), and tool-calling labels Apr 28, 2025
Alex added 5 commits April 28, 2025 19:24
Signed-off-by: Alex <[email protected]>
Signed-off-by: Alex <[email protected]>
Signed-off-by: Alex <[email protected]>
@wuisawesome wuisawesome force-pushed the alex-vllm-transformers branch from ffa2235 to 6026465 Compare April 28, 2025 19:24
@mergify mergify bot removed the tpu (Related to Google TPUs) label Apr 28, 2025
@wuisawesome
Contributor Author

I messed up a rebase; apologies for spamming everyone.

@hmellor hmellor merged commit 6e74fd4 into vllm-project:main Apr 28, 2025
45 checks passed
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025