Refactor weight loading from deepforest.main to models #1083
henrykironde merged 32 commits into weecology:main
Conversation
I've restored the call, and almost all tests pass now. Two that are odd are the save/reload of weights via Lightning and saving/loading model state via deepforest. The loaded state dictionaries are identical (I added an assertion for that) and the predictions are almost identical (55 vs 56), so it looks like there's a rounding error somewhere, or perhaps a very subtle config difference. Not sure what to make of this yet but I'm working on it. I removed the test for creating a retinanet backbone as the function is no longer there; we don't expose any way for users to adjust this (it's always a resnet50_fpn with COCO_V1 weights). Still TODO: support Faster-RCNN in the same way, and update docs where necessary. Alongside the other PR I would like to do a training run to verify everything works as expected.
@ethanwhite feel free to take a look. The gist of the PR remains the same. My assumption is the save/load difference should be a simple fix once we figure out what state is/isn't being recovered (the model is clearly working, so it's not as if we're loading garbage). However, as it may affect checkpointing for training, it'd be good to solve it before merging. The minor fixes to
Had a look, and it's looking good for the moment. If we are downgrading to OmegaConf, do we still need hydra-core, or do we use both?
@henrykironde we still support Hydra, but I made some of the docs a bit clearer that the objects being passed around are DictConfigs. I think of Hydra as providing a powerful command-line interface for overriding configurations (and sweeping, etc.), built on top of OmegaConf. The decorator (
ethanwhite left a comment
Nice work on this @jveitchmichaelis. Everything looks good to me. I've made a couple of minor suggestions and asked a couple of questions on things I wasn't sure about.
From the tests it looks like there's probably a small bug in the changes to the logging code. If you want to pull that into a separate PR like you mentioned I have no objection, but I'm also fine with cleaning it up here. The different number of predictions is definitely super weird and I agree that we want to understand it before merging. I agree that there's nothing obvious here that should result in this change.
```python
if model_name is None:
    model_name = self.config.model.name

if revision is None:
    revision = self.config.model.revision
```
I don't know how I feel about this, so this is really just a question: what do you think about having a fallback where, if model_name/revision is None and these values aren't specified in the config file, they default back to the current defaults? Upside: not a backwards-breaking change in one of our major functions. If the downside is just another bit of crufty code then it's probably worth it, since this will break all existing versions. If there are cascading downsides then it's probably worth going ahead and breaking it.
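A minimal sketch of that fallback, assuming a plain-dict config and hypothetical hard-coded defaults (the real code passes a DictConfig):

```python
# Hedged sketch of the fallback discussed above; the default values and the
# dict-style config access are assumptions, not the project's actual code.
FALLBACK_MODEL_NAME = "weecology/deepforest-tree"  # hypothetical default
FALLBACK_REVISION = "main"                         # hypothetical default

def resolve_model_source(model_name, revision, config):
    """Prefer explicit arguments, then the config file, then hard defaults."""
    model_cfg = config.get("model", {})
    if model_name is None:
        model_name = model_cfg.get("name") or FALLBACK_MODEL_NAME
    if revision is None:
        revision = model_cfg.get("revision") or FALLBACK_REVISION
    return model_name, revision
```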
src/deepforest/main.py
```python
# TODO: Hub model should store this mapping.
if model_name == "weecology/deepforest-bird":
    self.config.retinanet.score_thresh = 0.3
    self.label_dict = {"Bird": 0}
    self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
    self.set_labels({"Bird": 0})
else:
    self.set_labels(self.model.label_dict)

return
```
Yeah, let's go ahead and get this moved to the hub config file and clean out this code before merging. I just sent you an invite to the Hugging Face team.
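For reference, the hub-side config.json mentioned here could carry the mapping along these lines; the field names ("label_dict", "score_thresh") are assumptions for illustration, not the actual schema:

```python
import json

# Hypothetical shape of a hub config.json payload for the bird model.
raw = '{"label_dict": {"Bird": 0}, "score_thresh": 0.3}'
hub_config = json.loads(raw)

# The library could then derive both mappings from the hub file instead of
# hard-coding them per model name.
label_dict = hub_config["label_dict"]
numeric_to_label_dict = {v: k for k, v in label_dict.items()}
```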
@jveitchmichaelis I found three errors, but it looks like they are all backward-compatibility issues related to PyTorch.

PyTorch compatibility issue: I believe the change in the data frame is also tensor rounding of values created by the version changes.

DataFrame shape mismatch:
[left]: (56, 8)
[right]: (55, 8)

We could test on an older version of PyTorch and see if we get the correct results. We could also focus the tests not just on exact values but on the structure, or test that the values are very close, since object detection models can have non-deterministic behaviour and floating-point differences on both GPU and CPU. For the self.log() issue, it looks like we weren't calling validation correctly. I'm adding a PR to your branch as a point of reference as you continue to debug.
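A tolerant comparison along these lines (the helper name and tolerances are mine, not the project's) could replace the exact-count check:

```python
import numpy as np

def predictions_roughly_equal(scores_a, scores_b, count_tol=1, atol=1e-3):
    """Allow a small difference in detection count (borderline scores flipping
    across the threshold) and compare the overlapping top scores with a
    floating-point tolerance, instead of requiring exact equality."""
    if abs(len(scores_a) - len(scores_b)) > count_tol:
        return False
    n = min(len(scores_a), len(scores_b))
    # Compare the n highest scores on each side after sorting.
    return bool(np.allclose(np.sort(scores_a)[-n:], np.sort(scores_b)[-n:], atol=atol))

# Illustrative data: 56 vs 55 detections, one borderline detection dropped.
scores_56 = np.linspace(0.31, 0.95, 56)
scores_55 = scores_56[1:]
```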
Thanks @henrykironde, I think there is an underscore method we can use for older versions. The newer one seems to be a bugfix? Not sure if it matters. https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/module.py#L2097 Alternatively we could ignore it and force the use of newer weights (e.g. if someone requests weecology/deepforest-tree we hard-code an alias to a v2 version with updated tensor names), just as a backwards-compatibility check. The problem does occur within the same version of PyTorch though, for example if you load the weights, predict, and then save/reload. But I'm not sure how important this is versus checking that the behaviour is generally similar.
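A duck-typed dispatch would cover both the public register_load_state_dict_pre_hook and the older underscore-prefixed variant. This is a sketch with stand-in classes rather than real nn.Modules, so it runs without torch; the helper name is mine:

```python
def register_state_dict_pre_hook(module, hook):
    """Use the public registration method when available (newer PyTorch),
    falling back to the private one on older releases."""
    if hasattr(module, "register_load_state_dict_pre_hook"):
        return module.register_load_state_dict_pre_hook(hook)
    return module._register_load_state_dict_pre_hook(hook)

# Lightweight stand-ins so the dispatch can be exercised without torch:
class NewStyleModule:
    def register_load_state_dict_pre_hook(self, hook):
        return "public"

class OldStyleModule:
    def _register_load_state_dict_pre_hook(self, hook):
        return "private"
```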
@henrykironde for that logging issue: we already account for this, but I didn't notice when I added the additional call. I've just added some logic to allow for it. The code is now a bit stricter about creating models with a label_dict (and enforcing consistency between num_classes and len(labels)). I'll add a fallback for the hook registration based on PyTorch version; not sure there's a good way around that, as they renamed it (it used to be an internal method).
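The num_classes/len(labels) consistency check could look roughly like this; a standalone sketch, not the actual DeepForest method:

```python
def set_labels(num_classes, label_dict):
    """Sketch of the stricter label handling described above. Rejects a
    label dict whose size disagrees with the model's class count, and keeps
    the numeric -> label mapping in sync with it."""
    if len(label_dict) != num_classes:
        raise ValueError(
            f"label_dict has {len(label_dict)} entries, "
            f"but the model was built for {num_classes} classes"
        )
    return {v: k for k, v in label_dict.items()}
```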
That rounding error also seems to go away if you call

Currently fixing some new bugs that seem to have been introduced with the last tweaks. A little confused that everything up to e800a70 passed here but doesn't seem to pass locally any more (or after the rebase).
Made a few changes here:
Last commit fixes #1095, committing it here because it's somewhat related to the overall refactor and the tests were actually failing on my Mac and I mistakenly thought it was MPS being weird. |
I think this is good for review + squash. I fixed some of the label_dict bugs and added some additional test cases to verify things like single-label -> multi-label models and vice versa. If you'd like any of this as a separate PR I'm happy to split some of it out, but it turned out there were quite a few things that needed to be touched. Updated the changelog at the top, FYI.
Changelog:
`from_pretrained` is handled by the model modules directly. Specifically, this PR adds a wrapper to `retinanet` and `faster-rcnn` in order to allow them to use the HF Hub API methods. User-facing functionality isn't affected, and I've added a hook when the state dictionary is loaded so that older hub checkpoints can be used without modification.

Rationale:
This has come out of #1078 and some musings about how to handle checkpoints when we add new model variants to the package.
We currently store weights for trained models on HuggingFace. Models are instantiated two ways:
- `deepforest.main` calls `create_model()` on construction. This returns an instance of the desired architecture, initialised to whatever the object returns by default, assuming no model was passed to the constructor directly.
- We then call `load_model()` to pull in pre-trained weights.

There are a couple of issues here that I've run into when trying to add another model to the package:
- `load_model()` calls `deepforest.main.from_pretrained`, which is a class method provided by the `PyTorchModelHubMixin` mixin. We attempt to set the return value to `self.model` and overwrite a few settings such as the instance label_dict, model and numeric_to_label_dict. This is an unusual way of using this function, because it builds an entirely new deepforest instance and then copies those three items from it, whereas you'd normally use `from_pretrained` to construct the object you want to begin with.
- We don't need to call `create_model` on construction unless we're training, because `load_model` will override it. This doesn't save much time, but it's a bit more efficient. I think it's actually called a second time by load_model.
- The weights are coupled to `DeepForest`, particularly the constructor signature, due to how from_pretrained works to reconstruct an instance of the object you've requested. I think it's nice to scope the weights as narrowly as possible.

Implementation:
Given what `load_model` does (generate an `nn.Module`), I suggest we move the weight-loading logic to the models directly. This has the advantage that the weights stored on HF are (mostly) decoupled from the library and minimally contain the weights + label dict. I suspect there is currently a bug in the constructor where all models will open as retinanets, because of the way the config system is loaded (and the unnecessary extra call to `create_model` when you call `load_pretrained`, as it calls the constructor again). Investigating that.
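For contrast, the usual mixin pattern builds the requested object directly via from_pretrained. A mock sketch of that shape (the class is a stand-in and no downloading happens here; the repo id is illustrative):

```python
class HubModel:
    """Mock of a PyTorchModelHubMixin-style class, showing the shape of
    the API only; the real mixin fetches config + weights from the Hub."""

    def __init__(self, label_dict):
        self.label_dict = label_dict

    @classmethod
    def from_pretrained(cls, repo_id, revision=None):
        # The real mixin downloads config.json and weights for `repo_id`;
        # here we just construct an instance with a canned label dict.
        return cls(label_dict={"Tree": 0})

model = HubModel.from_pretrained("weecology/deepforest-tree", revision="main")
```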
create_modelwhen you callload_pretrainedas it calls the constructor again). Investigating that.This PR proposes that models subclass
PyTorchModelHubMixinwhere necessary. This has no effect on Lightning checkpointing as far as I'm aware, as those are separate (and store a lot more information to resume training state). This does not need to be done for transformers models as they already support the hub methods, but it would require adding the mixin to RetinaNet and FasterRCNN.We should probably re-publish weights to HF with a suitable revision (e.g. v2) and deprecate the old ones as v1, or just update the
config.jsonto be a bit cleaner there.State of PR
The current example overrides RetinaNet; a similar approach would be taken for Faster-RCNN. Changes to the codebase are otherwise quite minimal: `load_model` is basically the same except the call to `from_pretrained` happens a bit lower down. There are some minor improvements to the main class (e.g. we can use set_labels instead of duplicating logic).
Expect tests to fail at the moment while I check a few things, though most are passing locally.
I've renamed the abstract class models.Model to 'BaseModel' (as Henry suggested we might want to avoid calling Model(Model)).
Since the weights are still the same in current hub models, we can use them with only a change to their keys.
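A key-renaming shim along these lines (the prefix map is illustrative, not the actual old/new layout) is enough to adapt the old checkpoint keys when loading:

```python
def rename_state_dict_keys(state_dict, prefix_map):
    """Rewrite old checkpoint keys to the new module layout. `prefix_map`
    maps old key prefixes to new ones; unmatched keys pass through."""
    renamed = {}
    for key, value in state_dict.items():
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed

# Illustrative only: the real prefixes depend on how the wrapper nests the model.
old_weights = {"backbone.body.conv1.weight": 1, "head.cls.weight": 2}
new_weights = rename_state_dict_keys(old_weights, {"backbone.": "model.backbone."})
```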
To minimize changes we could keep the call to `create_model` in the constructor. I don't think we have a clean way of expressing whether we're training/fine-tuning/predicting, otherwise I would just say that `load_model` should run by default. It doesn't make any difference in speed; they're both constructing an architecture. Perhaps just have a config option to load empty/default weights for training? This also makes user-facing code simpler, because you don't need to remember to load a model; it comes from the config automatically.
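That config option might look roughly like this; the `pretrained` flag name and the callable wiring are inventions for the sketch, not the actual API:

```python
def build_model(config, create_model, load_model):
    """Choose between fresh weights (training) and hub weights (inference)
    based on a single config flag. `create_model`/`load_model` are passed in
    as callables to keep the sketch self-contained."""
    if config.get("pretrained", True):
        # Default path: pull pre-trained weights named in the config.
        return load_model(config.get("model_name", "weecology/deepforest-tree"))
    # Training path: build the architecture with empty/default weights.
    return create_model()

# Toy callables standing in for the real factory functions:
loaded = build_model({"pretrained": True}, lambda: "fresh", lambda name: f"hub:{name}")
fresh = build_model({"pretrained": False}, lambda: "fresh", lambda name: f"hub:{name}")
```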