Refactor weight loading from deepforest.main to models #1083

Open · wants to merge 6 commits into main

Conversation

@jveitchmichaelis (Collaborator) commented Jun 25, 2025

This has come out of #1078 and some musings about how to handle checkpoints when we add new model variants to the package.

The PR refactors hub weight loading so that from_pretrained is handled by the models/modules directly. Specifically, it adds a wrapper to retinanet and faster-rcnn so they can use the HF Hub API methods. User-facing functionality isn't affected, and I've added a hook that runs when the state dictionary is loaded, so older hub checkpoints can be used without modification.

Rationale

We currently store weights for trained models on HuggingFace. Models are instantiated in two ways:

  • deepforest.main calls create_model() on construction. This returns an instance of the desired architecture with its default initialisation, assuming no model was passed to the constructor directly.
  • Alternatively/additionally we then call load_model() to pull in pre-trained weights.

There are a couple of issues here that I've run into when trying to add another model to the package:

  • load_model() calls deepforest.main.from_pretrained, which is a class method provided by the PyTorchModelHubMixin mixin. We attempt to set the return value to self.model and overwrite a few settings such as the instance label_dict, model and numeric_to_label_dict. This is an unusual way of using the function, because it builds an entirely new deepforest instance and then copies those three items from it, whereas you'd normally do something like this to begin with:
m = deepforest.from_pretrained(model_name, revision)
  • We don't need to call create_model on construction unless we're training, because load_model will override it. Skipping it wouldn't save much time, but it's a bit more efficient. I think create_model is actually called a second time by load_model.
  • Some parameters are hardcoded (e.g. the bird model class names) in the library, when they would be better off in the hub checkpoint.
  • Weights are coupled to DeepForest, particularly the constructor signature due to how from_pretrained works to reconstruct an instance of the object you've requested. I think it's nice to scope the weights as narrowly as possible.
  • We should consider which hyperparameters are model-related and which are pipeline-related. For example, the weights on the hub should include a config that describes the class mapping and sensible defaults for things like NMS/confidence thresholds. We have the Hydra/OmegaConf config to declare default processing (tiling, etc.).
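To make the mismatch in the first bullet concrete, here is a minimal, self-contained sketch. FakeHubMixin and DeepForestLike are stand-ins for illustration only, not the real DeepForest or huggingface_hub classes:

```python
# Sketch only: FakeHubMixin mimics the shape of PyTorchModelHubMixin's
# from_pretrained (a classmethod that rebuilds the whole object).

class FakeHubMixin:
    @classmethod
    def from_pretrained(cls, model_name, revision="main"):
        obj = cls()
        obj.model = f"weights:{model_name}@{revision}"  # placeholder for real weights
        obj.label_dict = {"Tree": 0}
        obj.numeric_to_label_dict = {0: "Tree"}
        return obj


class DeepForestLike(FakeHubMixin):
    def __init__(self):
        self.model = None
        self.label_dict = {}
        self.numeric_to_label_dict = {}

    def load_model(self, model_name, revision="main"):
        # Current pattern: build a *second* full instance via
        # from_pretrained, then copy three attributes from it onto self.
        loaded = self.__class__.from_pretrained(model_name, revision)
        self.model = loaded.model
        self.label_dict = loaded.label_dict
        self.numeric_to_label_dict = loaded.numeric_to_label_dict


# Usual mixin usage: construct directly from the hub in one call.
m = DeepForestLike.from_pretrained("weecology/deepforest-tree", "main")

# Current DeepForest usage: construct, then mutate via load_model.
m2 = DeepForestLike()
m2.load_model("weecology/deepforest-tree")
```

Both paths end with the same state, but the second builds and discards an extra instance along the way.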

Implementation

Given what load_model does (generate a nn.Module), I suggest we move the weight loading logic to the models directly.

This has the advantage that the weights stored on HF are (mostly) decoupled from the library and minimally contain the weights + label dict. I suspect there is currently a bug in the constructor where all models will open as retinanets, because of the way the config system is loaded (and the unnecessary extra call to create_model when you call load_pretrained, as it calls the constructor again). Investigating that.

This PR proposes that models subclass PyTorchModelHubMixin where necessary. This has no effect on Lightning checkpointing as far as I'm aware, as those are separate (and store a lot more information to resume training state). This does not need to be done for transformers models as they already support the hub methods, but it would require adding the mixin to RetinaNet and FasterRCNN.
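A minimal sketch of what adding the mixin to a model wrapper might look like. The class name and the single layer are illustrative, not this PR's exact code; it assumes huggingface_hub's PyTorchModelHubMixin:

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class RetinaNetModel(nn.Module, PyTorchModelHubMixin):
    """Illustrative hub-aware wrapper; not the PR's real class."""

    def __init__(self, num_classes=1, label_dict=None):
        super().__init__()
        self.label_dict = label_dict or {"Tree": 0}
        # The real wrapper would build the torchvision RetinaNet here;
        # a single layer keeps this sketch small and runnable.
        self.head = nn.Linear(4, num_classes)


# The mixin adds the hub methods directly to the model, e.g.:
#   RetinaNetModel.from_pretrained("weecology/deepforest-tree")
#   model.save_pretrained("some/local/dir")
```

Because the mixin serializes the constructor kwargs to the checkpoint's config.json, the label dict can live with the weights rather than in the library.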

We should probably re-publish weights to HF with a suitable revision (e.g. v2) and deprecate the old ones as v1, or just update the config.json to be a bit cleaner there.

State of PR

  • The current example overrides RetinaNet; a similar approach would be taken for Faster-RCNN. Changes to the codebase are otherwise quite minimal: load_model is basically the same, except the call to from_pretrained happens a bit lower down.

  • There are some minor improvements to the main class (e.g. we can use set_labels instead of duplicating logic).

  • Expect tests to fail at the moment while I check a few things, though most are passing locally.

  • I've renamed the abstract class models.Model to BaseModel (as Henry suggested, we might want to avoid calling Model(Model)).

  • Since the weights are still the same in current hub models, we can use them with only a change to their keys.

  • To minimize changes we could keep the call to create_model in the constructor. I don't think we have a clean way of expressing whether we're training/fine-tuning/predicting, otherwise I would just say that load_model should run by default. It doesn't make any difference in speed - they're both constructing an architecture. Perhaps just have a config option to load empty/default weights for training? This also makes user-facing code simpler because you don't need to remember to load a model, it comes from the config automatically.
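One bullet above notes that current hub weights can be reused with only a change to their keys. A state-dict remapping hook along those lines might look like this (the prefix names are assumptions, not the real checkpoint keys):

```python
# Illustrative key-remapping hook for loading older hub checkpoints
# into the refactored models; prefixes here are made up.

def remap_legacy_keys(state_dict, old_prefix="model.", new_prefix=""):
    """Rename keys from an older hub checkpoint so they match the
    refactored model's parameter names. Values are left untouched."""
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        remapped[key] = value
    return remapped


legacy = {"model.backbone.weight": 1, "model.head.bias": 2}
print(remap_legacy_keys(legacy))
# {'backbone.weight': 1, 'head.bias': 2}
```

The remapped dict can then be passed to the model's load_state_dict as usual.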

@jveitchmichaelis jveitchmichaelis changed the title WIP: Refactor weight loading from deepforest:main to models WIP: Refactor weight loading from deepforest.main to models Jun 25, 2025
@jveitchmichaelis (Collaborator, Author) commented Jun 26, 2025

I've restored the call to create_model in __init__ for now.

Almost all tests pass now. Two odd failures remain: saving/reloading weights via Lightning, and saving/loading model state via deepforest. The loaded state dictionaries are identical (I added an assertion for that) and the predictions are almost identical: 55 vs 56 boxes, which looks like a rounding error somewhere, or perhaps a very subtle config difference. Not sure what to make of this yet, but working on it.

m = deepforest(
  (iou_metric): IntersectionOverUnion()
  (mAP_metric): MeanAveragePrecision()
  (empty_frame_accuracy): B..., 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Resize(min_size=(800,), max_size=1333, mode='bilinear')
    )
  )
)
tmpdir = local('/tmp/pytest-of-runner/pytest-0/test_save_and_reload_checkpoin0')

    def test_save_and_reload_checkpoint(m, tmpdir):
        img_path = get_data(path="2019_YELL_2_528000_4978000_image_crop2.png")
        m.config.train.fast_dev_run = True
        m.create_trainer()
        # save the prediction dataframe after training and
        # compare with prediction after reload checkpoint
        m.trainer.fit(m)
        pred_after_train = m.predict_image(path=img_path)
        m.save_model("{}/checkpoint.pl".format(tmpdir))
    
        # reload the checkpoint to model object
        after = main.deepforest.load_from_checkpoint("{}/checkpoint.pl".format(tmpdir))
        pred_after_reload = after.predict_image(path=img_path)
    
        assert not pred_after_train.empty
        assert not pred_after_reload.empty
        assert m.config == after.config
        assert state_dicts_equal(m.model, after.model)
>       pd.testing.assert_frame_equal(pred_after_train, pred_after_reload)
E       AssertionError: DataFrame are different
E       
E       DataFrame shape mismatch
E       [left]:  (56, 8)
E       [right]: (55, 8)

I removed the test for creating a retinanet backbone as the function is no longer there. We don't expose any way for users to adjust this (it's always a resnet50_fpn with COCO_V1 weights).

Still TODO is to support Faster-RCNN in the same way, and also to update docs where necessary.

Alongside the other PR I would like to do a training run to verify everything works as expected.

@@ -16,6 +16,9 @@ model:
name: 'weecology/deepforest-tree'
revision: 'main'

label_dict:
@jveitchmichaelis (Collaborator, Author) commented Jul 1, 2025

Added this, but not sure if we need this here any more, possibly not? There was some logic in the main constructor that needed changing.

@jveitchmichaelis (Collaborator, Author) commented Jul 1, 2025

@ethanwhite feel free to take a look. The gist of the PR remains the same.

My assumption is the save/load difference should be a simple fix once we figure out what state is/isn't being recovered (but the model is clearly working so it's not like we're loading garbage).

However as it may affect checkpointing for training it'd be good to solve it before merging.

The minor fixes to main and elsewhere (like logging batches and refactoring how label dicts are processed) I would probably move to a separate PR (or PRs).

@henrykironde (Contributor) commented

Had a look, and it's looking good for the moment. If we are downgrading to OmegaConf, do we still need hydra-core, or do we use both?

@jveitchmichaelis (Collaborator, Author) commented Jul 1, 2025

@henrykironde we still support Hydra, but I made some of the docs a bit clearer that the objects that are passed around are DictConfigs.

I think of Hydra as providing a powerful command-line interface for overriding configurations (and sweeping, etc.), built on top of OmegaConf. The decorator (@hydra.main) yields an OmegaConf DictConfig object, constructed from whatever was passed via the CLI. So the only place we should initialize Hydra (I think) is in those CLI scripts, otherwise it can interfere with other applications. This way someone can use DeepForest in another project without conflicts, like we saw briefly with BOEM.

@jveitchmichaelis jveitchmichaelis changed the title WIP: Refactor weight loading from deepforest.main to models Refactor weight loading from deepforest.main to models Jul 3, 2025
@ethanwhite (Member) left a comment

Nice work on this @jveitchmichaelis. Everything looks good to me. I've made a couple of minor suggestions and asked a couple of questions on things I wasn't sure about.

From the tests it looks like there's probably a small bug in the changes to the logging code. If you want to pull that into a separate PR like you mentioned I have no objection, but I'm also fine with cleaning it up here. The different number of predictions is definitely super weird and I agree that we want to understand it before merging. I agree that there's nothing obvious here that should result in this change.

Comment on lines +135 to +139
if model_name is None:
model_name = self.config.model.name

if revision is None:
revision = self.config.model.revision
@ethanwhite (Member):

I don't know how I feel about this, so this is really just a question: what do you think about having a fallback where, if model_name/revision is None and these values aren't specified in the config file, they default back to the current defaults? Upside: not a backwards-breaking change in one of our major functions. If the downside is just another bit of crufty code then it's probably worth it, since this will break all existing versions. If there are cascading downsides then it's probably worth going ahead and breaking it.
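A fallback chain like the one suggested could look roughly like this (the helper name and default values are illustrative, not existing DeepForest code):

```python
# Hypothetical resolution order: explicit argument, then config file,
# then package defaults, so older configs keep working.
_DEFAULT_MODEL = "weecology/deepforest-tree"
_DEFAULT_REVISION = "main"


def resolve_model(model_name, revision, config):
    """Resolve model name/revision from the call, then the config,
    then package-level defaults."""
    model_name = (model_name
                  or config.get("model", {}).get("name")
                  or _DEFAULT_MODEL)
    revision = (revision
                or config.get("model", {}).get("revision")
                or _DEFAULT_REVISION)
    return model_name, revision


print(resolve_model(None, None, {}))
# ('weecology/deepforest-tree', 'main')
```

The cruft cost is two `or` chains; the benefit is that load_model() with no arguments and no config entry behaves as it does today.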

Comment on lines +149 to +156
# TODO: Hub model should store this mapping.
if model_name == "weecology/deepforest-bird":
self.config.retinanet.score_thresh = 0.3
self.label_dict = {"Bird": 0}
self.numeric_to_label_dict = {v: k for k, v in self.label_dict.items()}
self.set_labels({"Bird": 0})
else:
self.set_labels(self.model.label_dict)

return
@ethanwhite (Member):

Yeah, let's go ahead and get this moved to the hub config file and clean out this code before merging. I just sent you an invite to the hugging face team.
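For illustration, the per-model hub config discussed here might carry something like the following. The keys and values are assumptions, not the published checkpoint's actual config:

```python
import json

# Hypothetical shape of a hub-side config.json for the bird model,
# replacing the hardcoded special case in load_model.
hub_config = {
    "label_dict": {"Bird": 0},
    "score_thresh": 0.3,
}

# PyTorchModelHubMixin serializes constructor kwargs to config.json,
# so a mapping like this would travel with the weights.
print(json.dumps(hub_config))
```

With this in place, load_model only needs the generic `self.set_labels(self.model.label_dict)` path.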

Comment on lines +102 to +105
#TODO: Remove this call unless needed, or plan to replace
# with load_model.
self.model = model
self.create_model()
@ethanwhite (Member):

Reminder to check and resolve this TODO

warnings.warn(
"Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use main.deepforest(config_args={'num_classes':value})"
"Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of using a config file or config_args. Use main.deepforest(config_args={'num_classes':value})"
@ethanwhite (Member):

Should we have the same warning for label_dict or is there something special about num_classes that means only it needs the warning?
