
Transformers/DETR integration #1078


Merged: 9 commits merged into weecology:main on Jul 4, 2025

Conversation

@jveitchmichaelis (Collaborator) commented Jun 14, 2025

This PR includes rudimentary support for transformers object detection models, specifically those supported by AutoModelForObjectDetection. I've trained a simple DeformableDetr model as a demonstration using the latest public release of MillionTrees which is hosted here (for testing): https://huggingface.co/joshvm/milliontrees-detr

The PR fixes a few issues in the process:

  • Fixes an issue with albumentations versioning; this can be rebased away once the other PR is approved
  • Adds dependencies for transformers and timm. Since I used uv to add the pyproject line, it should be compatible with the Python versions that we support.
  • The config.score_thresh parameter is moved to be a top-level config item as there's no need for it to be specific to retinanet, and other places it's referenced have been updated. This is also used for post-processing results from DETR.
  • Removed a comment in the config file that incorrectly suggests that threshold is NMS-related
  • Removed some docstrings in models that were out of date (referring to constructor arguments that no longer exist)
  • Adds a model class currently called DeformableDETR, but really it could support a variety of models. These can be separate files, but we would likely want to re-use the wrapper.
  • A small improvement to plot_results to allow an image path to be provided instead of an array, rather than erroring.

An advantage of the current approach is that it requires no changes to main at the moment; the implementation is isolated to models/detr.py. I need to check compatibility with other prediction routes (like via the trainer), but the forward function returns [boxes, labels, scores] in a DeepForest-friendly format without any modification. The wrapper class is a bit awkward, but necessary because of the way that Model and its subclasses are currently implemented (since we need to run pre- and post-processing, which is transformers-specific).

The model and processor (and their config as it relates to DeepForest) should be coupled, so I don't think we need to worry about exposing the details. I think it's OK to have both loaded with config.model.name and keep the config fixed, rather than allowing users to mess with image resizing. Other options like rescaling (1/255) and normalization should be preset so that the processor works with our dataset + loader.
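For context, the kind of post-processing the processor handles can be sketched in plain NumPy (illustrative only; the real wrapper defers to the transformers processor, and the function name and threshold here are hypothetical). DETR-style models emit per-query class logits and normalized (cx, cy, w, h) boxes, which must be converted into thresholded xyxy pixel boxes, labels, and scores:

```python
import numpy as np

def postprocess_detr(logits, boxes_cxcywh, img_w, img_h, score_thresh=0.25):
    """Sketch of DETR-style post-processing: softmax the per-query class
    logits (the last column is the 'no object' class and is dropped),
    threshold by score, and convert normalized cx,cy,w,h boxes to xyxy pixels."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (e / e.sum(axis=-1, keepdims=True))[:, :-1]  # drop no-object column
    scores = probs.max(axis=-1)
    labels = probs.argmax(axis=-1)
    keep = scores > score_thresh
    cx, cy, w, h = boxes_cxcywh[keep].T
    boxes = np.stack([(cx - w / 2) * img_w, (cy - h / 2) * img_h,
                      (cx + w / 2) * img_w, (cy + h / 2) * img_h], axis=-1)
    return boxes, labels[keep], scores[keep]

# Two queries, one real class plus no-object: query 0 is confident, query 1 is background.
logits = np.array([[5.0, 0.0], [0.0, 5.0]])
boxes_cxcywh = np.array([[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.1, 0.1]])
boxes, labels, scores = postprocess_detr(logits, boxes_cxcywh, img_w=100, img_h=100)
assert np.allclose(boxes, [[40.0, 40.0, 60.0, 60.0]])  # only query 0 survives
```

This is roughly what transformers' own post_process_object_detection does, which is why keeping the processor config fixed alongside the model makes sense.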

NMS for predict_image doesn't seem to be working though, and I'm not sure why (@bw4sz?), see the following test script:

from deepforest.main import deepforest
from deepforest.visualize import plot_results

m = deepforest(config_args={"model": {"name": "joshvm/milliontrees-detr"},
                            "score_thresh": 0.25, # This does work, at least.
                            "architecture": "detr"})

image_path = "src/deepforest/data/2018_SJER_3_252000_4107000_image_477.tif"

results = m.predict_image(path=image_path)

plot_results(results, image=image_path)

[image: plot_results output showing the predicted boxes]

Training is going to need a bigger refactor (probably) because the various Lightning steps in main make assumptions that might be tied to torchvision-type models.
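To illustrate the mismatch (a sketch with hypothetical helper names, not DeepForest's actual API): torchvision detection models called in training mode return a dict of loss components that the Lightning step sums, whereas transformers models return an output object carrying a precomputed scalar .loss:

```python
from types import SimpleNamespace

def extract_training_loss(model_output):
    """Bridge the two conventions for a shared training_step (sketch only)."""
    if isinstance(model_output, dict):
        # torchvision style: {"classification": ..., "bbox_regression": ...}
        return sum(model_output.values())
    # transformers style: outputs.loss is already the total loss
    return model_output.loss

assert extract_training_loss({"classification": 1.0, "bbox_regression": 2.0}) == 3.0
assert extract_training_loss(SimpleNamespace(loss=3.0)) == 3.0
```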

TODO:

  • Test suite. And fix the stuff that's currently breaking.
  • Tidy up docstrings etc.
  • Check prediction with other functions.
  • Think about how we can enable training as well, vs just loading models.
  • Evaluation check vs retinanet.
  • Docs for what is required to pull in a model from HF and configuration parameters.

@jveitchmichaelis requested a review from bw4sz June 14, 2025 00:08
@jveitchmichaelis (Collaborator, Author) commented:

Requesting a WIP review to check that this seems like a sensible path to go down, or if there are any high level things that need changing.

@jveitchmichaelis (Collaborator, Author) commented Jun 14, 2025

Changed this to explicitly name itself DeformableDetr. I've tested basic forward passes with dummy data and data loaded from a BoxDataset. Model loading is probably broken because load_model probably requires the checkpoint to be generated by DeepForest's training pipeline. You can get around this by loading weights with create_model, which defaults to the MS-COCO version (like Faster-RCNN and RetinaNet).

So the behavior is that create_model gives you an n-class model, ready to train, while load_model should pull down weights from HuggingFace. I've bodged this by loading the weights manually and pushing them to the hub as joshvm/deepforest-detr.

NMS is not bundled into transformers models like it is with RetinaNet. Not sure where the best place to apply this is.

I guess the next step would be to actually try a training loop within DeepForest itself.

@jveitchmichaelis (Collaborator, Author) commented Jun 15, 2025

Will revert the NMS config changes, don't think that needs to be here. Same goes for some minor typos etc in the test cases for Faster-RCNN; that can be a separate PR.

Hub loading works OK, but it's quite brittle, as the checkpoint is sensitive to any change in the constructor for main. We might want to think about whether we ought to limit the hub checkpoints to just model weights, rather than the entire module. I'm not sure it's necessary.

@henrykironde (Contributor) left a comment:

These changes look good for now. I hope we are planning on incorporating parallelization in this design?

loaded_model = self.from_pretrained(model_name, revision=revision)
self.label_dict = loaded_model.label_dict
self.model = loaded_model.model
self.numeric_to_label_dict = loaded_model.numeric_to_label_dict

# Set bird-specific settings if loading the bird model
if model_name == "weecology/deepforest-bird":
    self.config.retinanet.score_thresh = 0.3
    self.config.score_thresh = 0.3
Contributor:

I believe this configuration is already included in the config file, so we can likely skip it if that's the case.

@jveitchmichaelis (Author):

Yeah I think this is an override for the bird model, but it would be more correct to have a different config entirely and not hard code it here.

"""

def __init__(self, config):
    # Check for required properties and formats
    self.config = config
    self.nms_thresh = None  # Required for some models but not all
Contributor:

Do we really need this here? self.nms_thresh = None

@jveitchmichaelis (Author) commented Jun 15, 2025

Hmm probably not, let me check. I think I included this because initially I thought that the model was responsible for NMS but then I realised that it's a specific feature inside RetinaNet.

@jveitchmichaelis (Collaborator, Author) commented Jun 15, 2025

On NMS: According to the DETR authors, the algorithm is supposed to be "NMS-Free", but clearly the post-processed result from transformers still returns a lot of overlapping predictions (and they're not no-object).

This makes sense for training, because the ground truth/prediction matcher is 1:1 by definition (via Hungarian Matching/linear sum assignment), so you don't need to first run NMS to clean up the boxes before calculating the loss. The model should learn to not over-predict.

But when there's nothing to compare against, all the "object" predictions are retained. It might be a poorly trained model (quite possible!) but I'm not sure how non-overlapping we expect the final predictions to be.
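The 1:1 matching can be demonstrated with SciPy's linear sum assignment solver, the same primitive behind DETR's Hungarian matcher. The cost matrix below is a toy example; in DETR the real cost mixes classification probability with box L1 and GIoU terms:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows are predictions, columns are ground-truth boxes; lower cost = better match.
cost = np.array([
    [0.1, 0.9, 0.8],
    [0.7, 0.2, 0.9],
    [0.8, 0.6, 0.3],
])
pred_idx, gt_idx = linear_sum_assignment(cost)
pairs = [(int(i), int(j)) for i, j in zip(pred_idx, gt_idx)]
print(pairs)  # [(0, 0), (1, 1), (2, 2)] -- each prediction gets exactly one GT
```

Because every ground-truth box is matched to exactly one query, duplicate predictions are penalized during training without any NMS step; at inference time there is no such constraint, which is why overlaps can survive.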

See facebookresearch/detr#72

and also the DETA paper, which suggests that NMS is actually fine and you don't need bipartite matching anyway, though that's more about providing better guidance during training.

https://arxiv.org/abs/2212.06137

@jveitchmichaelis changed the title WIP: Transformers/DETR integration to Transformers/DETR integration Jun 16, 2025
@jveitchmichaelis (Collaborator, Author) commented Jun 16, 2025

Will begin to revert changes that have moved to other PRs, and rebase as we get them in.

See:

@bw4sz (Collaborator) commented Jun 19, 2025

What's the latest here with NMS? predict_image doesn't perform NMS except when there are multiple classes present:

https://github.com/weecology/DeepForest/blob/c53f0904725fbba1e08d873032a229bf303a6315/src/deepforest/predict.py#L43C1-L44C1

This is because torchvision does NMS internally in RetinaNet.
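One option is an explicit NMS pass over the DETR outputs after post-processing. In practice torchvision.ops.nms would do the job; below is a dependency-light NumPy sketch of what that step does, with made-up boxes:

```python
import numpy as np

def nms(boxes, scores, iou_thresh=0.5):
    """Greedy NMS sketch: keep the highest-scoring box, suppress any remaining
    box whose IoU with it exceeds iou_thresh, then repeat. Boxes are xyxy."""
    order = scores.argsort()[::-1]
    keep = []
    while order.size:
        i = order[0]
        keep.append(int(i))
        if order.size == 1:
            break
        rest = order[1:]
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        areas = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + areas - inter)
        order = rest[iou <= iou_thresh]
    return keep

# Two heavily overlapping boxes and one separate box.
boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
assert nms(boxes, scores, iou_thresh=0.5) == [0, 2]  # box 1 is suppressed
```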

@bw4sz (Collaborator) previously requested changes Jun 19, 2025:

Great! Excited about getting this done and ready. I think the essential pieces look good; most of what I want to see is in the tests, to make sure that the model covers:

m.predict_tile
m.trainer.fit(m)
m.trainer.validate(m) <- even if we don't have a polygon evaluation method yet? just that it goes through the process and calls torchmetrics
m.trainer.predict() <- so we can use multi-gpu inference.

These could be separate PRs too, up to you, just trying to document where in the workflow this is.

@jveitchmichaelis (Collaborator, Author) commented Jun 19, 2025

Yeah DETR is (mostly) only boxes. I'll add some integration tests for prediction and training - my hope is that it should be seamless.

We probably do need to add an explicit NMS step unless we can figure out why the trained model is over-predicting. Or maybe it needs a better confidence threshold.

@henrykironde (Contributor) left a comment:

I've added a few comments. For the tests, I think we should include proper assertions to validate behavior more clearly. Also, we should follow Python naming conventions and rename the file from DeformableDetr.py to deformable_detr.py.


def __init__(self, config, **kwargs):
    """
    Args:
Contributor:

Complete the docstring

Contributor:

Naming a class the same as its parent class is a little confusing (class Model(Model)); would it be too much to change it to something different?

@jveitchmichaelis (Author):

Agree, but we should change the other model subclasses for consistency. This also applies to the filenames (e.g. FasterRCNN).

@pytest.fixture()
def config():
    config = utilities.load_config()
    config.model.name = "joshvm/milliontrees-detr"
Contributor:

We need to move this model to weecology/milliontrees-detr

@jveitchmichaelis (Author):

Sure, let me see if this is the best checkpoint I have. I think there's likely a better one that doesn't over-predict.

import pytest
import torch
from PIL import Image
import numpy as np
Contributor:

Imports are not grouped by standard library, third-party, and local imports; add spacing between the groups.

@jveitchmichaelis (Author):

Do we have guidelines for this? I normally use isort in pre-commit.

Contributor:

import os

import cv2
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from torchvision import models, transforms
from torchvision.datasets import ImageFolder

I usually use PyCharm, which organizes and optimizes imports. But we should have:

Organized by type:

  • Standard library imports first (os)
  • Third-party imports second (all the others)
  • Alphabetical ordering within each group

@jveitchmichaelis (Collaborator, Author) commented Jun 25, 2025

@henrykironde - review summary:

  • Will add a boolean return value to check_model - that function also asserts internally, but it doesn't return anything that we could use for tests.
  • For the other test we can use predictions with random weights, they'll just have very low scores (fine for the test case)
  • I'm proposing a refactor of weight loading to better support new models: Refactor weight loading from deepforest.main to models #1083 this might also address your comment about Model(Model).
  • If we go with snake case, we should also think about moving to faster_rcnn and making similar changes. That may be a breaking change?
  • I've run an import sort on the test file - for consistency we ought to also do this for other modules. Alternatively consider automatically doing this on commit. Yapf explicitly doesn't support this, but isort can be run in combination with many formatters in a deterministic way.
  • Removed unused imports from test_retinanet.
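For the first bullet, the shape of the change might look like this (purely illustrative; the real check_model runs more assertions than shown, and the checks here are placeholders):

```python
def check_model(model) -> bool:
    """Sketch: keep the internal assertions, but also return a boolean
    so test code can assert on the result directly."""
    try:
        assert hasattr(model, "forward"), "model must define forward()"
        assert callable(model.forward)
    except AssertionError:
        return False
    return True

class DummyModel:
    def forward(self, x):
        return x

assert check_model(DummyModel()) is True
assert check_model(object()) is False
```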

@martibosch commented:
Hello @jveitchmichaelis

  • I've run an import sort on the test file - for consistency we ought to also do this for other modules. Alternatively consider automatically doing this on commit. Yapf explicitly doesn't support this, but isort can be run in combination with many formatters in a deterministic way.
  • Removed unused imports from test_retinanet.

I have also encountered similar issues in #901 and I think it would indeed be a good idea to enforce consistent styling. I think there is broad consensus that ruff is the way to go, so I have drafted #1084.

Happy to discuss this further. Best,
Martí

@jveitchmichaelis (Collaborator, Author) commented Jun 28, 2025

@bw4sz: making sure this trains fine with retinanet + MillionTrees currently (using the refactored hub weight loading in #1083), then I'll repeat with DETR. I'll add this as a "cookbook" example to the docs; I couldn't see an overview of training a model from scratch, and it helps reproducibility.

@bw4sz linked an issue Jul 3, 2025 that may be closed by this pull request
@henrykironde (Contributor) left a comment:

Thanks @jveitchmichaelis for adding transformers support to deepforest, and thanks to all the other contributors.

@henrykironde dismissed bw4sz’s stale review July 4, 2025 06:10

bw4sz gave a thumbs up to moving forward.

@jveitchmichaelis (Collaborator, Author) commented Jul 4, 2025

We need to fix the other PR first (weight reloading) as that directly affects this one, but once that's figured out we can merge this.

That one is quite close, but some things came up in review that need testing.

@henrykironde merged commit 6bf5ecb into weecology:main Jul 4, 2025
14 checks passed
@henrykironde (Contributor) commented:

I didn't see that comment coming in, but if we have to revert, let me know @jveitchmichaelis.

@jveitchmichaelis (Collaborator, Author) commented Jul 4, 2025

No problem. Just had a quick look at the diffs; I think it should be OK. I'm going to rebase the other PR and see how that goes, but that's also fairly self-contained.

Successfully merging this pull request may close these issues: Integrate transformers library.

4 participants