
Refactor weight loading from deepforest.main to models #1083

Merged
henrykironde merged 32 commits into weecology:main from jveitchmichaelis:model_weight_refactor
Jul 21, 2025
Conversation


@jveitchmichaelis jveitchmichaelis commented Jun 25, 2025

Changelog:

  • refactors hub weight loading so that from_pretrained is handled by the model modules directly. Specifically, it adds a wrapper to RetinaNet and Faster R-CNN so they can use the HF Hub API methods. User-facing functionality isn't affected, and I've added a hook on state-dictionary load so that older hub checkpoints can be used without modification.
  • fixes various integration bugs with the DeformableDetr model and "harmonizes" things with RetinaNet.
  • removes the legacy FasterRCNN interface, as we do not currently provide models for it.
  • checkpoint models are loaded by default and docs are updated to describe how to train in various situations.
  • additional test cases
  • fixes: Type conversion error when training DeformableDETR on GPU #1095 + tests

Rationale

This has come out of #1078 and some musings about how to handle checkpoints when we add new model variants to the package.

We currently store weights for trained models on HuggingFace. Models are instantiated two ways:

  • deepforest.main calls create_model() on construction. This returns an instance of the desired architecture with default initialisation, assuming no model was passed to the constructor directly.
  • Alternatively/additionally we then call load_model() to pull in pre-trained weights.

There are a couple of issues here that I've run into when trying to add another model to the package:

  • load_model() calls deepforest.main.from_pretrained, which is a class method provided by the PyTorchModelHubMixin mixin. We attempt to set the return value to self.model and overwrite a few settings such as the instance label_dict, model and numeric_to_label_dict. This is an unusual way of using this function, because it builds an entirely new deepforest instance and then copies those three items from it, whereas you'd normally do something like this to begin with:

    m = deepforest.from_pretrained(model_name, revision)
  • We don't need to call create_model on construction unless we're training, because load_model will override it. This doesn't save much time, but it's a bit more efficient. I think it's actually called again by load_model.
  • Some parameters are hardcoded (e.g. the bird model class names) in the library, when they would be better off in the hub checkpoint.
  • Weights are coupled to DeepForest, particularly the constructor signature due to how from_pretrained works to reconstruct an instance of the object you've requested. I think it's nice to scope the weights as narrowly as possible.
  • We should consider what hyperparameters are model-related and what are pipeline-related. For example, the weights on the hub should have a config included that describe the class mapping, and sensible defaults for things like NMS/confidence threshold. We have the Hydra/OmegaConf config to declare default processing (tiling, etc).
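As a sketch of that model-level vs pipeline-level split (all keys and values here are hypothetical, not the actual published config), the hub checkpoint could carry model-scoped settings while the Hydra/OmegaConf config keeps pipeline defaults:

```yaml
# Hypothetical hub checkpoint config (model-scoped):
label_dict:
  Tree: 0
score_thresh: 0.3      # sensible per-model default
nms_thresh: 0.05

# Hypothetical pipeline config (Hydra/OmegaConf, library-scoped):
model:
  name: weecology/deepforest-tree
  revision: main
patch_size: 400        # tiling is a pipeline concern, not a model one
patch_overlap: 0.05
```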

Implementation

Given what load_model does (generate a nn.Module), I suggest we move the weight loading logic to the models directly.

This has the advantage that the weights stored on HF are (mostly) decoupled from the library and minimally contain the weights + label dict. I suspect there is a bug in the constructor where all models open as retinanets, because of the way the config system is loaded (and the unnecessary extra call to create_model when you call load_pretrained, as it calls the constructor again). Investigating that.

This PR proposes that models subclass PyTorchModelHubMixin where necessary. This has no effect on Lightning checkpointing as far as I'm aware, as those are separate (and store a lot more information to resume training state). This does not need to be done for transformers models as they already support the hub methods, but it would require adding the mixin to RetinaNet and FasterRCNN.
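A minimal sketch of that mixin pattern, using a toy module standing in for RetinaNet (class and attribute names here are illustrative, not the PR's actual code):

```python
import torch
from huggingface_hub import PyTorchModelHubMixin

class ToyDetectorHub(torch.nn.Module, PyTorchModelHubMixin):
    """Toy stand-in for a detection model wrapped for the HF Hub API."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.head = torch.nn.Linear(4, num_classes)

    def forward(self, x):
        return self.head(x)

# The mixin supplies save_pretrained()/from_pretrained(), and the __init__
# kwargs (num_classes) are serialised to the hub config, so the object can
# be reconstructed without going through the surrounding library.
```

ToyDetectorHub.from_pretrained("org/repo") would then rebuild the module straight from a hub repo, keeping the weights scoped to the model rather than to deepforest.main.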

We should probably re-publish weights to HF with a suitable revision (e.g. v2) and deprecate the old ones as v1, or just update the config.json there to be a bit cleaner.

State of PR

  • The current example overrides RetinaNet. A similar approach would be taken for Faster R-CNN. Changes to the codebase are otherwise quite minimal; load_model is basically the same, except the call to from_pretrained happens a bit lower down.

  • There are some minor improvements to the main class (e.g. we can use set_labels instead of duplicating logic).

  • Expect tests to fail at the moment while I check a few things, though most are passing locally.

  • I've renamed the abstract class models.Model to BaseModel (as Henry suggested, we might want to avoid calling Model(Model)).

  • Since the weights are still the same in current hub models, we can use them with only a change to their keys.

  • To minimize changes we could keep the call to create_model in the constructor. I don't think we have a clean way of expressing whether we're training/fine-tuning/predicting, otherwise I would just say that load_model should run by default. It doesn't make any difference in speed - they're both constructing an architecture. Perhaps just have a config option to load empty/default weights for training? This also makes user-facing code simpler because you don't need to remember to load a model, it comes from the config automatically.
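The key rename for older hub checkpoints mentioned above could be a simple state-dict transform along these lines (the prefixes here are assumptions for illustration, not the actual checkpoint layout):

```python
def remap_legacy_keys(state_dict, old_prefix="model.", new_prefix="model.model."):
    """Rewrite tensor names from a v1 hub checkpoint so they match the
    wrapped module's layout. Prefixes are illustrative assumptions."""
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            remapped[new_prefix + key[len(old_prefix):]] = value
        else:
            remapped[key] = value
    return remapped
```

Registered as a load_state_dict pre-hook, a function like this lets the same published weights load into the new wrapper without re-uploading anything.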

@jveitchmichaelis jveitchmichaelis changed the title WIP: Refactor weight loading from deepforest:main to models WIP: Refactor weight loading from deepforest.main to models Jun 25, 2025

jveitchmichaelis commented Jun 26, 2025

I've restored the call to create_model in __init__ for now.

Almost all tests pass now. Two odd ones are save/reload of weights via Lightning and saving/loading model state via deepforest. The loaded state dictionaries are identical (I added an assertion for that) and the predictions are almost identical (55 vs 56 boxes); it looks like there's a rounding error somewhere, or perhaps a very subtle config difference. Not sure what to make of this yet, but working on it.

m = deepforest(
  (iou_metric): IntersectionOverUnion()
  (mAP_metric): MeanAveragePrecision()
  (empty_frame_accuracy): B..., 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Resize(min_size=(800,), max_size=1333, mode='bilinear')
    )
  )
)
tmpdir = local('/tmp/pytest-of-runner/pytest-0/test_save_and_reload_checkpoin0')

    def test_save_and_reload_checkpoint(m, tmpdir):
        img_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png")
        m.config.train.fast_dev_run = True
        m.create_trainer()
        # save the prediction dataframe after training and
        # compare with prediction after reload checkpoint
        m.trainer.fit(m)
        pred_after_train = m.predict_image(path=img_path)
        m.save_model("{}/checkpoint.pl".format(tmpdir))
    
        # reload the checkpoint to model object
        after = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir))
        pred_after_reload = after.predict_image(path=img_path)
    
        assert not pred_after_train.empty
        assert not pred_after_reload.empty
        assert m.config == after.config
        assert state_dicts_equal(m.model, after.model)
>       pd.testing.assert_frame_equal(pred_after_train, pred_after_reload)
E       AssertionError: DataFrame are different
E       
E       DataFrame shape mismatch
E       [left]:  (56, 8)
E       [right]: (55, 8)

I removed the test for creating a retinanet backbone as the function is no longer there. We don't expose any way for users to adjust this (it's always a resnet50_fpn with COCO_V1 weights).

Still TODO is to support Faster-RCNN in the same way, and also to update docs where necessary.

Alongside the other PR I would like to do a training run to verify everything works as expected.


jveitchmichaelis commented Jul 1, 2025

@ethanwhite feel free to take a look. The gist of the PR remains the same.

My assumption is the save/load difference should be a simple fix once we figure out what state is/isn't being recovered (but the model is clearly working so it's not like we're loading garbage).

However as it may affect checkpointing for training it'd be good to solve it before merging.

The minor fixes to main/elsewhere (like logging batches and refactoring how label dicts are processed) I would probably make a separate PR(s).

@henrykironde
Contributor

Had a look, and it's looking good for the moment. If we are downgrading to OmegaConf, do we still need hydra-core, or do we use both?


jveitchmichaelis commented Jul 1, 2025

@henrykironde we still support Hydra, but I made some of the docs a bit clearer that the objects that are passed around are DictConfigs.

I think of Hydra as providing a powerful command-line interface for overriding configurations (and sweeping, etc.), built on top of OmegaConf. The decorator (@hydra.main) yields an OmegaConf DictConfig object, constructed from whatever was passed via the CLI. So the only place we should initialize Hydra (I think) is in those CLI scripts; otherwise it can interfere with other applications. This way someone can use DeepForest in another project without conflicts, like we saw briefly with BOEM.

@jveitchmichaelis jveitchmichaelis changed the title WIP: Refactor weight loading from deepforest.main to models Refactor weight loading from deepforest.main to models Jul 3, 2025

@ethanwhite ethanwhite left a comment


Nice work on this @jveitchmichaelis. Everything looks good to me. I've made a couple of minor suggestions and asked a couple of questions on things I wasn't sure about.

From the tests it looks like there's probably a small bug in the changes to the logging code. If you want to pull that into a separate PR like you mentioned I have no objection, but I'm also fine with cleaning it up here. The different number of predictions is definitely super weird and I agree that we want to understand it before merging. I agree that there's nothing obvious here that should result in this change.

Comment on lines +135 to +139
if model_name is None:
    model_name = self.config.model.name

if revision is None:
    revision = self.config.model.revision

I don't know how I feel about this, so this is really just a question - what do you think about having a fallback where, if model_name/revision is None and these values aren't specified in the config file, they default back to the current defaults? Upside: not a backwards-breaking change in one of our major functions. If the downside is just another bit of crufty code then it's probably worth it, since this will break all existing versions. If there are cascading downsides then it's probably worth going ahead and breaking it.
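The fallback order being suggested could be as simple as this (the default names below are assumptions for illustration):

```python
DEFAULT_MODEL_NAME = "weecology/deepforest-tree"  # assumed package default
DEFAULT_REVISION = "main"                         # assumed default revision

def resolve(arg, config_value, default):
    """Explicit argument beats config value beats package default."""
    if arg is not None:
        return arg
    if config_value is not None:
        return config_value
    return default
```

With this, resolve(None, None, DEFAULT_MODEL_NAME) falls back to the current default, so existing callers of load_model() would keep working unchanged.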

Comment on lines +149 to +156
# TODO: Hub model should store this mapping.
if model_name == "weecology/deepforest-bird":
    self.config.retinanet.score_thresh = 0.3
    self.label_dict = {"Bird": 0}
    self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
    self.set_labels({"Bird": 0})
else:
    self.set_labels(self.model.label_dict)

return

Yeah, let's go ahead and get this moved to the hub config file and clean out this code before merging. I just sent you an invite to the hugging face team.

@henrykironde
Contributor

@jveitchmichaelis I found three errors, but it looks like they are all related to PyTorch backward compatibility.

PyTorch Compatibility Issue

  1. Problem: AttributeError: 'RetinaNetHub' object has no attribute 'register_load_state_dict_pre_hook'
    Root cause: the code uses register_load_state_dict_pre_hook, which was introduced in PyTorch 2.3.0, but my environment has PyTorch 2.2.2.

I believe the DataFrame difference is also a tensor rounding issue created by the version change.

DataFrame shape mismatch
[left]:  (56, 8)
[right]: (55, 8)

We could test on an older version of PyTorch and see if we get the correct results. We could also focus the tests on structure rather than exact values, or check that the values are close, since object detection models can have non-deterministic behavior and floating-point differences on both GPU and CPU.
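One way to test structure rather than exact values is to match boxes by IoU instead of asserting frame equality. A sketch, independent of the deepforest API:

```python
def iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    if inter <= 0:
        return 0.0
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def match_fraction(pred_a, pred_b, thresh=0.9):
    """Fraction of boxes in pred_a with an IoU >= thresh partner in pred_b."""
    if not pred_a:
        return 1.0
    hits = sum(1 for a in pred_a if any(iou(a, b) >= thresh for b in pred_b))
    return hits / len(pred_a)
```

Asserting, say, match_fraction(before, after) > 0.95 would tolerate one dropped box (the 56 vs 55 case) while still catching real regressions.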

For the self.log() issue: apparently we weren't calling validation correctly. We bypassed PyTorch Lightning's proper loop initialization, so when validation_step() tried to call self.log(), the trainer's result collection wasn't set up.


val_dataloader = m.val_dataloader()
batch = next(iter(val_dataloader))
val_loss = m.validation_step(batch, 0)

to

m.create_trainer(fast_dev_run=True)
m.trainer.validate(m) 

adding a PR to your branch as a point of reference as you continue to debug.


jveitchmichaelis commented Jul 7, 2025

Thanks @henrykironde, I think there is an underscore method we can use for older versions. The newer one seems to be a bugfix? Not sure if it matters. https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/module.py#L2097
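A version-tolerant registration could just probe for the public method and fall back to the underscored one. A sketch using a dummy object (note the two PyTorch variants may pass slightly different arguments to the hook, so the hook itself should be written defensively):

```python
def register_pre_hook(module, hook):
    """Register a load_state_dict pre-hook on either PyTorch API.

    The public register_load_state_dict_pre_hook arrived in PyTorch 2.3;
    older releases only expose the underscored variant.
    """
    register = getattr(module, "register_load_state_dict_pre_hook", None)
    if register is None:
        register = module._register_load_state_dict_pre_hook
    return register(hook)
```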

Alternatively we could ignore it and force the use of newer weights (e.g. if someone requests weecology/deepforest-tree we hard-code an alias to a v2 version with updated tensor names), just as a backwards-compatibility check.

The problem does occur within the same version of PyTorch though, for example if you load the weights, predict, and then save/reload. But yeah, I'm not sure how important this is versus checking that the behaviour is generally similar.


jveitchmichaelis commented Jul 11, 2025

@henrykironde for that logging issue - we already account for this, but I didn't notice when I added the additional call. I've just added val/loss as an additional metric (to be consistent with logging train/loss, which is the summed loss). I wrapped everything in the existing try/except block.

I just added some logic to allow for num_classes to be modified when a checkpoint is loaded. This isn't the cleanest solution I think, but it's not urgent. This is mostly useful for fine-tuning, when someone takes our single-class model and re-trains it. The fact this didn't error before suggests I need to go through the tests to make sure everything is actually captured.

The code is now a bit stricter about creating models with a label_dict (and enforcing consistency between num_classes and len(labels)).
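The stricter label handling amounts to a consistency check along these lines (a sketch; the real implementation lives in the PR):

```python
def validate_labels(label_dict, num_classes):
    """Ensure a label mapping and a class count agree.

    Expects labels mapped to contiguous integer ids 0..num_classes-1;
    returns the inverse numeric -> label lookup on success.
    """
    if len(label_dict) != num_classes:
        raise ValueError(
            f"num_classes={num_classes} but label_dict has "
            f"{len(label_dict)} entries")
    if sorted(label_dict.values()) != list(range(num_classes)):
        raise ValueError(
            f"label ids must be 0..{num_classes - 1}, "
            f"got {sorted(label_dict.values())}")
    return {v: k for k, v in label_dict.items()}
```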

I'll add a fallback for the hook registration based on PyTorch version, not sure there's a good way around that as they renamed it (it used to be an internal _ method).


jveitchmichaelis commented Jul 11, 2025

That rounding error also seems to go away if you call create_model and then load_model (which is what we do at the moment). We can take another look at that later if we try to make initialization a bit more efficient.

Currently fixing some new bugs that seem to have been introduced with the last tweaks. A little confused that everything up to e800a70 passed here but doesn't seem to pass locally any more (or after the rebase).


jveitchmichaelis commented Jul 15, 2025

Made a few changes here:

  • create_model will now only create an empty model if the model.name is None. Open to suggestions on the best way to express this. Otherwise it will load the weights in the config.
  • Fixed some inconsistencies in the DETR integration and training (which still needs to be tested in anger)
  • Check that we override the score/nms thresholds from config when we load a model.
  • deepforest will load the config-specified model on instantiation. No need to call load_model explicitly now, but it shouldn't break anything if you do; it'll just run twice. Should probably add a test for that.
    • This also fixes the test failures which were indeed due to batch_norm values not being preserved.
    • It's not entirely obvious why this wasn't working before, but it probably comes down to the changes in which model is loaded by default. The load/save checkpoint features test OK, so I think it's fine.
  • Only the config is saved in the (deepforest) hyperparameter file. This is only used for Lightning checkpoints.
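The new initialisation rule described above can be summarised as a small decision function (a sketch of the logic, not the actual code):

```python
def resolve_init(model_name, revision="main"):
    """Decide how __init__ should build the model under the new rule:
    no configured name -> empty architecture (training from scratch);
    otherwise -> pull the named checkpoint, so users no longer need
    to call load_model() explicitly."""
    if model_name is None:
        return ("create_empty", None)
    return ("load_pretrained", (model_name, revision))
```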


jveitchmichaelis commented Jul 18, 2025

Last commit fixes #1095; committing it here because it's somewhat related to the overall refactor, and the tests were actually failing on my Mac when I mistakenly thought it was MPS being weird.

@ethanwhite ethanwhite requested a review from henrykironde July 18, 2025 13:23

jveitchmichaelis commented Jul 19, 2025

I think this is good for review + squash. I fixed some of the label_dict bugs and added some additional test cases to verify things like single label -> multi label models and vice versa. If you'd like any of this as a separate PR happy to split some of it out, but turned out there were quite a few things that needed to be touched.

Updated the changelog at the top fyi.


@henrykironde henrykironde left a comment


looks good to me

@henrykironde henrykironde merged commit 1d77ec7 into weecology:main Jul 21, 2025
4 checks passed

Successfully merging this pull request may close these issues.

Type conversion error when training DeformableDETR on GPU