Conversation

@bryant1410
Contributor

@bryant1410 bryant1410 commented Aug 6, 2025

I guess this is a borderline bug fix / feature addition. It's common to transitively load a model under a context manager that puts it on the meta device by default (such as accelerate.init_empty_weights, or what transformers has used by default in from_pretrained for the last couple of months). The problem is that open_clip.factory.create_model moves the model to a device, so it's incompatible with this. I added a check for this case and a log message for it.

Here is a code block that reproduces the error (it works fine after the fix):

import open_clip
import transformers


class Config(transformers.PretrainedConfig):
    pass


class Model(transformers.PreTrainedModel):
    config_class = Config

    def __init__(self, config):
        super().__init__(config)
        # create_model_and_transforms moves the model to a device, which is
        # incompatible with the meta-device context that from_pretrained uses.
        self.another_model = open_clip.create_model_and_transforms("ViT-B-32-quickgelu", pretrained="openai")[0]


# Save an instance, then reload it: from_pretrained initializes the model on
# the meta device before loading the saved weights, which triggers the error.
a = Model(Config())
a.save_pretrained("/tmp/abc")

b = Model.from_pretrained("/tmp/abc")

@rwightman
Collaborator

So, I understand your aim here, but wouldn't that conflict with someone wanting to init on the meta device AND load OpenCLIP pretrained weights?

I'm aware both OpenCLIP and timm need proper support for this; it hasn't been a priority among the other things. But we need a solution that works in the various cases...

@rwightman
Collaborator

rwightman commented Aug 6, 2025

To provide more detail on my thoughts: I have some ideas to implement a two-phase init approach for OpenCLIP / timm, where the factories will allow the model to be created on the meta device (under a meta-device context) with no weight init, and then either load weights or trigger the final init via a consistent _init_weights / init_weights call if no pretrained weights are loaded.
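To make the two-phase idea concrete, here is a minimal sketch of what such a factory flow could look like (the function name, the init_weights hook, and the load path are illustrative assumptions, not the actual OpenCLIP / timm API):

import torch


def create_model_two_phase(model_cls, pretrained_state_dict=None, device="cpu"):
    # Phase 1: build the module graph. If this runs under `with torch.device("meta"):`,
    # the parameters are allocated as meta tensors and no real weight init happens.
    model = model_cls()

    if next(model.parameters()).device.type == "meta":
        # A parent (e.g. transformers' from_pretrained) is expected to materialize
        # and load the weights later, so the factory stops here.
        return model

    # Phase 2: either load pretrained weights or run the consistent init pass.
    if pretrained_state_dict is not None:
        model.load_state_dict(pretrained_state_dict)
    elif hasattr(model, "init_weights"):
        model.init_weights()
    return model.to(device)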

Now, in your use case it looks like you're using an OpenCLIP model as a submodule in a larger one that has a transformers base. I already hit this hiccup while running through scenarios. If a timm or OpenCLIP model factory is called from within a parent context, we need to detect that and adapt the two-phase approach, assuming that the parent will take the final action re weight loading or triggering init. Does that make sense, or did I get the intention of your use case wrong?

@bryant1410
Contributor Author

So, I understand your aim here, but wouldn't that conflict with someone wanting to init on the meta device AND load OpenCLIP pretrained weights?

In my use case, the actual weights are loaded at another point (the second from_pretrained call, from the file system), so I don't need open_clip to do it for me.

I'm aware both OpenCLIP and timm need proper support for this; it hasn't been a priority among the other things. But we need a solution that works in the various cases...

Makes total sense, and it's completely understandable that it hasn't been a priority. I don't know all the popular use cases, but I'd assume this wouldn't break anything, because if the model is on the meta device the current code doesn't work at all. I think the patch I'm proposing is a minimal-intervention solution to make it work when the model is on the meta device (to prevent a crash), and it shouldn't break the cases where it isn't.

Note that the model can only end up on the meta device because of meta-programming: the method torch.nn.Module.register_parameter has been monkey-patched. So the root cause of this issue is a "non-elegant" solution from accelerate/transformers/others, which is fine when we want performance, but it modifies the underlying API many people rely on, including in this library (i.e., you expect self.weight = nn.Parameter(...) to behave predictably). And I guess library maintainers are then left to decide whether to support these use cases.
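For context, here is a rough, simplified sketch of how such a context manager can work; this is not accelerate's actual implementation, and the name is made up for illustration:

import contextlib

import torch.nn as nn


@contextlib.contextmanager
def init_empty_weights_sketch():
    # Monkey-patch register_parameter so every parameter registered while the
    # context is active is re-created on the meta device (no memory, no init).
    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            meta_param = module._parameters[name].to("meta")
            module._parameters[name] = nn.Parameter(meta_param, requires_grad=param.requires_grad)

    nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        # Restore the original method once the context exits.
        nn.Module.register_parameter = old_register_parameter

With such a patch active, self.weight = nn.Parameter(...) silently ends up on the meta device, which is why the later device move in create_model breaks.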

To provide more detail on my thoughts: I have some ideas to implement a two-phase init approach for OpenCLIP / timm, where the factories will allow the model to be created on the meta device (under a meta-device context) with no weight init, and then either load weights or trigger the final init via a consistent _init_weights / init_weights call if no pretrained weights are loaded.

I think this is great, but it aims to solve a slightly different issue from the one I'm raising: loading weights efficiently with OpenCLIP in isolation.

Now, in your use case it looks like you're using an OpenCLIP model as a submodule in a larger one that has a transformers base. I already hit this hiccup while running through scenarios. If a timm or OpenCLIP model factory is called from within a parent context, we need to detect that and adapt the two-phase approach, assuming that the parent will take the final action re weight loading or triggering init. Does that make sense, or did I get the intention of your use case wrong?

Yeah, agreed on detecting. But, at least in my use case, we don't need to load the weights from the OpenCLIP side, as they will later be loaded from the user side ("my" side in this case).

Wanted to also link a related PR: #501

@rwightman
Collaborator

@bryant1410 yeah, I got that you're loading the weights outside of OpenCLIP and are already in a meta-device context via the hacky patching approach.

I was thinking about both making OpenCLIP and timm models meta-device compatible (timm especially is not right now) AND adding the two-phase init approach mentioned above, in the hopes of avoiding the monkey patching altogether, as I REALLY don't like that approach.

I'd like any interim modifications like this one to be compatible with the end approach and not require possibly breaking changes... your proposed changes are fairly minimal, but I need to run through whether there are any unforeseen consequences.

@rwightman
Collaborator

rwightman commented Sep 25, 2025

@bryant1410
I've started working on better init schemes for timm, and eventually here... in the meantime, taking another peek at this, I think the following would cover your use case without adding too much logic that would be invalidated by future changes:

    model_is_in_meta_device = next(model.parameters()).device.type == "meta"
    if not model_is_in_meta_device:
        _set_model_device_and_precision(model, device, precision, is_timm_model)
        model_is_in_meta_device = device.type == 'meta'

And then have the branch emit a non-warning log that the model is uninitialized / on the meta device.
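Putting the check and the log branch together, a minimal sketch of how it could sit in the factory; the _set_model_device_and_precision stand-in, the function name, and the log wording are illustrative, not the actual open_clip code:

import logging

import torch


def _set_model_device_and_precision(model, device, precision, is_timm_model):
    # Stand-in for the helper referenced in the snippet above; the real helper
    # would also handle precision casting, not just the device move.
    model.to(device=device)


def _move_model_or_skip_meta(model, device, precision, is_timm_model):
    device = torch.device(device)
    # Skip the move entirely when the model was built under a meta-device context.
    model_is_in_meta_device = next(model.parameters()).device.type == "meta"
    if not model_is_in_meta_device:
        _set_model_device_and_precision(model, device, precision, is_timm_model)
        model_is_in_meta_device = device.type == "meta"
    if model_is_in_meta_device:
        # Non-warning log: the model is left uninitialized on the meta device and
        # the caller is expected to materialize / load the weights later.
        logging.info("Model is on the meta device; weights were not initialized or loaded.")
    return model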

I'd prefer to leave the checkpoint loading logic alone. load_weights=False should disable loading the weights and there will be future code where meta device use will involve materialization and loading in the factory.

FWIW, the code does currently work fine if you do:

with torch.device('meta'):
    self.another_model = open_clip.create_model_and_transforms("ViT-B-32-quickgelu", pretrained="openai", device="meta", load_weights=False)[0]

It actually works without the load_weights too.

@bryant1410
Contributor Author

And then have the branch emit a non-warning log that the model is uninitialized / on the meta device.

I'd prefer to leave the checkpoint loading logic alone. load_weights=False should disable loading the weights and there will be future code where meta device use will involve materialization and loading in the factory.

I just changed the code as per your first paragraph, but I don't fully follow you on the rest, sorry (it's been a while and I lost some context). What other changes do you think I should make here?
