Refactor weight loading from deepforest.main to models #1083
Conversation
I've restored the call to `create_model`. Almost all tests pass now. The two odd ones are saving/reloading weights via Lightning and saving/loading model state via deepforest. The loaded state dictionaries are identical (I added an assertion for that) and the predictions are almost identical (55 vs 56), so it looks like there's a rounding error somewhere, or perhaps a very subtle config difference. Not sure what to make of this yet, but I'm working on it.
I removed the test for creating a retinanet backbone, as the function is no longer there. We don't expose any way for users to adjust this (it's always a resnet50_fpn with COCO_V1 weights). Still TODO: support Faster-RCNN in the same way, and update the docs where necessary. Alongside the other PR, I would like to do a training run to verify everything works as expected.
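For reference, the state-dict assertion mentioned above amounts to a tensor-wise comparison along these lines (a minimal sketch, not the actual test code):

```python
import torch

def state_dicts_match(a: dict, b: dict) -> bool:
    """Return True when two PyTorch state dicts share keys and identical tensors."""
    if a.keys() != b.keys():
        return False
    return all(torch.equal(a[k], b[k]) for k in a)

# e.g. assert state_dicts_match(model_a.state_dict(), model_b.state_dict())
```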
```diff
@@ -16,6 +16,9 @@ model:
   name: 'weecology/deepforest-tree'
   revision: 'main'

+label_dict:
```
Added this, but I'm not sure we still need it here; possibly not. There was some logic in the main constructor that needed changing.
@ethanwhite feel free to take a look. The gist of the PR remains the same. My assumption is that the save/load difference should be a simple fix once we figure out what state is or isn't being recovered (the model is clearly working, so it's not like we're loading garbage). However, as it may affect checkpointing for training, it'd be good to solve it before merging. The minor fixes to the logging code could be pulled out into a separate PR.
Had a look, and it's looking good for the moment. If we are downgrading to OmegaConf, do we still need hydra-core, or do we use both?
@henrykironde we still support Hydra, but I made some of the docs a bit clearer that the objects being passed around are DictConfigs. I think of Hydra as providing a powerful command-line interface for overriding configurations (and sweeping, etc.), built on top of omegaconf.
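A minimal sketch of that pattern using Hydra's `@hydra.main` decorator (illustrative only, not DeepForest's actual CLI or config layout):

```python
import hydra
from omegaconf import DictConfig, OmegaConf

# Hydra hands the function a plain OmegaConf DictConfig and layers
# command-line overrides on top, e.g.:
#   python run.py model.name=weecology/deepforest-tree
@hydra.main(config_path="conf", config_name="config", version_base=None)
def run(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    run()
```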
Nice work on this @jveitchmichaelis. Everything looks good to me. I've made a couple of minor suggestions and asked a couple of questions on things I wasn't sure about.
From the tests it looks like there's probably a small bug in the changes to the logging code. If you want to pull that into a separate PR like you mentioned, I have no objection, but I'm also fine with cleaning it up here. The different number of predictions is definitely super weird and I agree that we want to understand it before merging. I agree that there's nothing obvious here that should result in this change.
```python
if model_name is None:
    model_name = self.config.model.name

if revision is None:
    revision = self.config.model.revision
```
I don't know how I feel about this, so this is really just a question: what do you think about having a fallback where, if `model_name`/`revision` is `None` and these values aren't specified in the config file, they default back to the current defaults? Upside: not a backwards-breaking change in one of our major functions. If the downside is just another bit of crufty code, then it's probably worth it, since this will otherwise break all existing versions. If there are cascading downsides, then it's probably worth going ahead and breaking it.
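A rough sketch of that fallback chain (explicit argument, then config value, then a packaged default); the constant names and the helper are hypothetical, not current deepforest API:

```python
# Hypothetical defaults; deepforest would choose its own.
DEFAULT_MODEL_NAME = "weecology/deepforest-tree"
DEFAULT_REVISION = "main"

def resolve_model_source(model_name=None, revision=None, config=None):
    """Fall back from explicit arguments to config values to packaged defaults."""
    model_cfg = getattr(config, "model", None)
    if model_name is None:
        model_name = getattr(model_cfg, "name", None) or DEFAULT_MODEL_NAME
    if revision is None:
        revision = getattr(model_cfg, "revision", None) or DEFAULT_REVISION
    return model_name, revision
```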
```diff
 # TODO: Hub model should store this mapping.
 if model_name == "weecology/deepforest-bird":
     self.config.retinanet.score_thresh = 0.3
-    self.label_dict = {"Bird": 0}
-    self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
+    self.set_labels({"Bird": 0})
 else:
     self.set_labels(self.model.label_dict)

 return
```
Yeah, let's go ahead and get this moved to the hub config file and clean out this code before merging. I just sent you an invite to the Hugging Face team.
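On the loading side, that could look something like the following; this is a sketch that assumes the hub repo's config.json stores the label mapping and score threshold, which is the proposal rather than the currently published schema:

```python
import json
from huggingface_hub import hf_hub_download

def load_hub_model_config(repo_id: str, revision: str = "main") -> dict:
    """Fetch config.json from a hub repo and return it as a dict.

    Assumes fields like "label_dict" and "score_thresh" exist; those keys
    are illustrative, not what the weecology repos publish today.
    """
    path = hf_hub_download(repo_id, "config.json", revision=revision)
    with open(path) as f:
        return json.load(f)

# e.g.
# cfg = load_hub_model_config("weecology/deepforest-bird")
# m.set_labels(cfg["label_dict"])
```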
```python
# TODO: Remove this call unless needed, or plan to replace
# with load_model.
self.model = model
self.create_model()
```
Reminder to check and resolve this TODO
```diff
 warnings.warn(
-    "Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use main.deepforest(config_args={'num_classes':value})"
+    "Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of using a config file or config_args. Use main.deepforest(config_args={'num_classes':value})"
```
Should we have the same warning for `label_dict`, or is there something special about `num_classes` that means only it needs the warning?
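For reference, the usage the warning recommends is taken straight from the message itself (the value here is arbitrary):

```python
from deepforest import main

# Pass overrides through config_args rather than the num_classes argument.
m = main.deepforest(config_args={"num_classes": 2})
```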
This has come out of #1078 and some musings about how to handle checkpoints when we add new model variants to the package.
The PR refactors hub weight loading so that `from_pretrained` is handled by the model/modules directly. Specifically, it adds a wrapper to `retinanet` and `faster-rcnn` so that they can use the HF Hub API methods. User-facing functionality isn't affected, and I've added a hook when the state dictionary is loaded so that older hub checkpoints can be used without modification.
Rationale
We currently store weights for trained models on HuggingFace. Models are instantiated in two ways:
1. `deepforest.main` calls `create_model()` on construction. This returns an instance of the desired architecture, initialised to whatever the object returns by default, assuming no model was passed to the constructor directly.
2. `load_model()` is then called to pull in pre-trained weights.

There are a couple of issues here that I've run into when trying to add another model to the package:
1. `load_model()` calls `deepforest.main.from_pretrained`, which is a class method provided by the `PyTorchModelHubMixin` mixin. We attempt to set the return value to `self.model` and overwrite a few settings such as the instance label_dict, model and numeric_to_label_dict. This is an unusual way of using this function, because it builds an entirely new deepforest instance and then copies those three items from it, whereas you'd normally use `from_pretrained` directly to begin with (see the sketch after this list).
2. `create_model` doesn't need to run on construction unless we're training, because `load_model` will override it. This doesn't save much time, but it's a bit more efficient. I think it's actually called a second time by `load_model`.
3. The published weights are coupled to `DeepForest`, particularly the constructor signature, due to how `from_pretrained` works to reconstruct an instance of the object you've requested. I think it's nice to scope the weights as narrowly as possible.
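The sketch referenced in point 1, using the existing mixin on the `deepforest` class (illustrative usage, not code from the PR):

```python
from deepforest import main

# from_pretrained returns a ready-to-use instance directly, rather than
# building a second deepforest object and copying label_dict, model and
# numeric_to_label_dict back onto self.
m = main.deepforest.from_pretrained("weecology/deepforest-tree")
```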
Implementation
Given what `load_model` does (generate an `nn.Module`), I suggest we move the weight loading logic to the models directly. This has the advantage that the weights stored on HF are (mostly) decoupled from the library and minimally contain the weights + label dict. I suspect there is currently a bug in the constructor where all models will open as retinanets, because of the way the config system is loaded (and the unnecessary extra call to `create_model` when you call `load_pretrained`, as it calls the constructor again). Investigating that.
This PR proposes that models subclass `PyTorchModelHubMixin` where necessary. This has no effect on Lightning checkpointing as far as I'm aware, as those checkpoints are separate (and store a lot more information to resume training state). This does not need to be done for transformers models, as they already support the hub methods, but it would require adding the mixin to RetinaNet and FasterRCNN.
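A rough sketch of the shape such a wrapper could take; the class name, the `net.` prefix and the defaults are assumptions for illustration, not the PR's actual code:

```python
import torch
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models.detection import retinanet_resnet50_fpn


class RetinaNetHub(torch.nn.Module, PyTorchModelHubMixin):
    """Illustrative wrapper: the mixin adds save_pretrained/from_pretrained."""

    def __init__(self, num_classes: int = 1, score_thresh: float = 0.1):
        super().__init__()
        # Plain torchvision RetinaNet; only the hub plumbing is new.
        self.net = retinanet_resnet50_fpn(
            weights=None, num_classes=num_classes, score_thresh=score_thresh
        )

    def forward(self, images, targets=None):
        return self.net(images, targets)

    @staticmethod
    def remap_legacy_keys(state_dict):
        # Older hub checkpoints were saved from the previous class layout,
        # so their keys lack the "net." prefix used in this sketch.
        return {
            (k if k.startswith("net.") else f"net.{k}"): v
            for k, v in state_dict.items()
        }
```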
We should probably re-publish the weights to HF with a suitable revision (e.g. v2) and deprecate the old ones as v1, or just update the `config.json` to be a bit cleaner there.
State of PR
- The current example overrides RetinaNet; a similar approach would follow for Faster-RCNN. Changes to the codebase are otherwise quite minimal: `load_model` is basically the same, except the call to `from_pretrained` happens a bit lower down.
- There are some minor improvements to the main class (e.g. we can use `set_labels` instead of duplicating logic).
- Expect tests to fail at the moment while I check a few things, though most are passing locally.
- I've renamed the abstract class `models.Model` to `BaseModel` (as Henry suggested we might want to avoid calling `Model(Model)`).
- Since the weights are still the same in current hub models, we can use them with only a change to their keys.
- To minimize changes we could keep the call to `create_model` in the constructor. I don't think we have a clean way of expressing whether we're training/fine-tuning/predicting, otherwise I would just say that `load_model` should run by default. It doesn't make any difference in speed; they're both constructing an architecture. Perhaps just have a config option to load empty/default weights for training? This also makes user-facing code simpler, because you don't need to remember to load a model; it comes from the config automatically.
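One possible shape for that option; the `load_pretrained` flag is hypothetical and not an existing config key:

```python
# Sketch: one config flag decides whether construction pulls hub weights
# or just builds an untrained default architecture for training.
def build_model(m, cfg):
    if cfg.model.get("load_pretrained", True):  # hypothetical key
        m.load_model(model_name=cfg.model.name, revision=cfg.model.revision)
    else:
        m.create_model()
```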