Support loading a model in a non-init-weights context #1105
|
So, I understand your aim here, but wouldn't that conflict with someone wanting to init on the meta device AND load OpenCLIP pretrained weights? I'm aware both OpenCLIP and timm need proper support for this; it hasn't been a priority with the other things going on. But we need a solution that works in the various cases... |
|
To provide more detail on my thoughts, I have some ideas to implement a two-phase init approach for OpenCLIP / timm, where the factories will allow the model to be created on the meta device (under a meta-device context) w/ no weight init, and then can either load weights or trigger final init via a consistent Now, in your use case it looks like you're using an OpenCLIP model as a submodule in a larger one that has a transformers base. So, I already had this hiccup running through scenarios. If a timm or OpenCLIP model factory is called from within a parent context, we need to detect that and adapt the two-phase approach, assuming that the parent will be taking the final action re weight loading or triggering init. Does that make sense or did I get the intention in your use case wrong? |
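For illustration only, the two-phase idea described above could look roughly like the following in plain PyTorch. This is a sketch under assumptions, not the planned OpenCLIP / timm design: `TinyEncoder`, `create_model`, `init_weights`, and the `finalize` flag are hypothetical names, and it assumes PyTorch 2.0+ so that `torch.device("meta")` can be used as a context manager.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a timm / OpenCLIP model."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self) -> None:
        # Final init; only meaningful once the parameters have real storage.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)


def create_model(dim=8, device="cpu", state_dict=None, finalize=True):
    """Hypothetical factory: build on meta first, then load or init."""
    # Phase 1: create the module structure without allocating weights.
    with torch.device("meta"):
        model = TinyEncoder(dim)

    if not finalize:
        # A parent model takes the final action (loading or init) later.
        return model

    # Phase 2: allocate uninitialized storage on the target device...
    model.to_empty(device=device)
    # ...then either load a checkpoint or trigger the final init.
    if state_dict is not None:
        model.load_state_dict(state_dict)
    else:
        model.init_weights()
    return model
```

With `finalize=False` the returned model stays on the meta device, which is the case where a parent (for example a transformers-based wrapper) is expected to load the weights itself.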
In my use case, the actual weights are loaded at another point (the 2nd call
Makes total sense. It's completely understandable that it hasn't been a priority. I don't know all the popular use cases, but I'd assume it wouldn't break stuff because if the model is in the meta device the current code doesn't work at all. I think the patch I'm proposing is a minimal-intervention solution to make it work when it's in the meta device (to prevent a crash), and shouldn't break the cases when the model is not. Note that the model can only be in the meta device because of meta-programming; the method
I think this is great but it aims to solve a slightly different issue to the one I'm raising: loading weights efficiently with OpenCLIP in isolation.
Yeah, agree on detecting. But, at least in my use case, we don't need to load the weights from the OpenCLIP side, as they will later be loaded from the user side ("my" side in this case). Wanted to also link a related PR: #501 |
|
@bryant1410 yeah, I got that you're loading the weights outside of OpenCLIP and are already in a meta device context with the hacky patchy approach. I was thinking about both making OpenCLIP and timm models meta device compatible (timm especially is not right now) AND adding the two-phase init approach I mentioned, in the hopes of avoiding the monkey patching altogether, as I REALLY don't like that approach. I'd like any interim modifications like this one to be compatible with the end approach, and not require possibly breaking changes... your proposed changes are fairly minimal, but I need to run through whether there are any unforeseen consequences |
|
@bryant1410 And then have the branch for a non-warning log that the model is uninitialized / on the meta device. I'd prefer to leave the checkpoint loading logic alone. FWIW the code does currently work fine if you do actually works without the load_weights too |
I just changed the code as per your first paragraph but I don't fully follow you for the rest, sorry (it's been a while and I lost some context). What other changes do you think I should do here? |
I guess this is a borderline bug fix/feature addition. A common pattern is to use a context manager that sets the model on the meta device by default when transitively loading a model (such as `accelerate.init_empty_weights`, or what transformers has used by default for `from_pretrained` for a couple of months). The problem is that `open_clip.factory.create_model` moves the model to a device, and thus it's incompatible with this. I added a check for this case and a log message for it.

This is a code block to reproduce this error (it works fine after the fix):
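Separately from that reproduction snippet, the kind of check described here can be sketched in plain PyTorch. This is a rough illustration under assumptions, not the actual change to `open_clip.factory.create_model`: `move_unless_meta` is a hypothetical helper, and an `nn.Linear` created with `device="meta"` stands in for a model built under a meta-device context.

```python
import logging

import torch
import torch.nn as nn


def move_unless_meta(model: nn.Module, device) -> nn.Module:
    """Hypothetical helper: skip the device move for meta-device models."""
    if any(p.is_meta for p in model.parameters()):
        # Meta tensors carry no data, so model.to(device) would raise.
        # Leave the model as-is and let the caller materialize it later.
        logging.info("Model is on the meta device; skipping move to %s.", device)
        return model
    return model.to(device)


# Stand-in for a model created inside a meta-device context.
meta_model = nn.Linear(4, 4, device="meta")
meta_model = move_unless_meta(meta_model, "cpu")  # stays on meta, logs a message

# A regular model is moved as usual.
real_model = move_unless_meta(nn.Linear(4, 4), "cpu")
```

Calling `.to("cpu")` directly on `meta_model` raises `NotImplementedError`, because there is no data to copy out of a meta tensor; the guard avoids that call and leaves materialization to whoever loads the weights.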