Transformers/DETR integration #1078
Conversation
Requesting a WIP review to check that this seems like a sensible path to go down, or if there are any high-level things that need changing.
Force-pushed from 19c0dbc to 2ae6745.
Changed this to explicitly name itself DeformableDetr. I've tested basic forward passes with dummy data and data loaded from a BoxDataset. Model loading is probably broken at the moment. The behavior is that NMS is not bundled into the model itself. I guess the next step would be to actually try a training loop within DeepForest itself.
Will revert the NMS config changes; I don't think that needs to be here. Same goes for some minor typos etc. in the test cases for Faster-RCNN; that can be a separate PR. Hub loading works OK, but it's quite brittle as the checkpoint is sensitive to any change in the constructor.
These changes look good for now. Hope we are planning on incorporating parallelization in this design?
loaded_model = self.from_pretrained(model_name, revision=revision)
self.label_dict = loaded_model.label_dict
self.model = loaded_model.model
self.numeric_to_label_dict = loaded_model.numeric_to_label_dict

# Set bird-specific settings if loading the bird model
if model_name == "weecology/deepforest-bird":
    self.config.retinanet.score_thresh = 0.3
    self.config.score_thresh = 0.3
I believe this configuration is already included in the config file, so we can likely skip it if that's the case.
Yeah I think this is an override for the bird model, but it would be more correct to have a different config entirely and not hard code it here.
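As a rough sketch of that suggestion (the file names and the OmegaConf-based merge are assumptions for illustration, not the actual DeepForest config mechanism):

```python
# Hypothetical sketch: keep bird-specific overrides in their own config file
# instead of hard-coding them when the checkpoint name matches.
from omegaconf import OmegaConf

base = OmegaConf.load("deepforest_config.yml")        # assumed base config path
bird = OmegaConf.load("deepforest_bird_config.yml")   # assumed to contain score_thresh: 0.3
config = OmegaConf.merge(base, bird)                  # bird values override the base

assert config.score_thresh == 0.3
```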
""" | ||
|
||
def __init__(self, config): | ||
|
||
# Check for required properties and formats | ||
self.config = config | ||
self.nms_thresh = None # Required for some models but not all |
Do we really need this here? self.nms_thresh = None
Hmm probably not, let me check. I think I included this because initially I thought that the model was responsible for NMS but then I realised that it's a specific feature inside RetinaNet.
On NMS: According to the DETR authors, the algorithm is supposed to be "NMS-free", but clearly the post-processed result still contains overlapping predictions. This makes sense for training, because the ground-truth/prediction matcher is 1:1 by definition (via Hungarian matching/linear sum assignment), so you don't need to first run NMS to clean up the boxes before calculating the loss; the model should learn to not over-predict. But at inference, when there's nothing to compare against, all the "object" predictions are retained. It might be a poorly trained model (quite possible!) but I'm not sure how non-overlapping we expect the final predictions to be. See also the DETA paper, which suggests that NMS is actually fine and you don't need bipartite matching anyway, though that's more to provide better guidance during training.
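To make the 1:1 matching point concrete, here's a tiny illustration using linear sum assignment. The cost is just an L1 box distance, whereas the real DETR matcher also mixes in classification and GIoU terms, so treat this as a simplified sketch:

```python
# Simplified sketch of DETR-style bipartite matching (cost = L1 box distance only;
# the actual matcher also uses class probabilities and GIoU).
import numpy as np
from scipy.optimize import linear_sum_assignment

pred_boxes = np.array([[10, 10, 50, 50],      # good prediction
                       [12, 11, 52, 49],      # near-duplicate of the first
                       [100, 100, 150, 150]], dtype=float)
gt_boxes = np.array([[11, 10, 51, 50],
                     [101, 99, 149, 151]], dtype=float)

# Pairwise cost: rows are predictions, columns are ground-truth boxes
cost = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(-1)

pred_idx, gt_idx = linear_sum_assignment(cost)
print(pred_idx, gt_idx)  # -> [0 2] [0 1]: prediction 0 matches GT 0, prediction 2 matches GT 1
# The duplicate prediction (index 1) is left unmatched and supervised as
# "no object" during training, which is why the loss doesn't need NMS.
```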
Will begin to revert changes that are moved to other PRs and rebase as we get them in.
What's the latest here with NMS? predict_image doesn't perform NMS except when there are multiple classes present, because torchvision does NMS internally in RetinaNet.
Great! Excited about getting this done and ready. I think the essential pieces look good; most of what I want to see is in the tests, to make sure that the model covers:
- m.predict_tile
- m.trainer.fit(m)
- m.trainer.validate(m) <- even if we don't have a polygon evaluation method yet? Just that it goes through the process and calls torchmetrics.
- m.trainer.predict() <- so we can use multi-GPU inference.

These could be separate PRs too, up to you, just trying to document where in the workflow this is.
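A hedged, partial sketch of what integration tests for the entry points listed above could look like. The checkpoint name comes from the fixture used elsewhere in this PR; the config attribute layout, the packaged sample data, and the use of fast_dev_run are assumptions, not the final test code:

```python
# Hypothetical integration tests for the DETR path through the DeepForest workflow.
import os

import pytest
from deepforest import main, get_data

@pytest.fixture()
def m():
    m = main.deepforest()
    m.config.model.name = "joshvm/milliontrees-detr"       # assumed config layout from this PR
    m.config.train.csv_file = get_data("example.csv")      # packaged sample annotations
    m.config.train.root_dir = os.path.dirname(get_data("example.csv"))
    m.create_trainer(fast_dev_run=True)                    # single-batch smoke test
    return m

def test_predict_tile(m):
    boxes = m.predict_tile(get_data("OSBS_029.tif"), patch_size=300, patch_overlap=0.25)
    assert {"xmin", "ymin", "xmax", "ymax", "label", "score"} <= set(boxes.columns)

def test_fit(m):
    m.trainer.fit(m)
```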
Yeah DETR is (mostly) only boxes. I'll add some integration tests for prediction and training - my hope is that it should be seamless. We probably do need to add an explicit NMS step unless we can figure out why the trained model is over-predicting. Or maybe it needs a better confidence threshold.
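If an explicit NMS step does turn out to be necessary, a minimal sketch could look like this. The 0.5 IoU threshold and the boxes/scores/labels dict layout mirror the transformers post-processing output; where exactly it would be plugged in is still open:

```python
# Hypothetical post-processing step: per-class NMS on DETR-style outputs.
import torch
from torchvision.ops import batched_nms

def apply_nms(result: dict, iou_thresh: float = 0.5) -> dict:
    """result holds 'boxes' (N x 4, xyxy), 'scores' (N,) and 'labels' (N,)."""
    keep = batched_nms(result["boxes"], result["scores"], result["labels"], iou_thresh)
    return {key: value[keep] for key, value in result.items()}

# Example with two heavily overlapping predictions of the same class:
result = {
    "boxes": torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]]),
    "scores": torch.tensor([0.9, 0.8]),
    "labels": torch.tensor([0, 0]),
}
print(apply_nms(result)["boxes"].shape[0])  # 1: the lower-scoring duplicate is suppressed
```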
I've added a few comments. For the tests, I think we should include proper assertions to validate behavior more clearly. Also, we should follow Python naming conventions and rename the file from DeformableDetr.py to deformable_detr.py.
def __init__(self, config, **kwargs):
    """
    Args:
Complete the docstring
Naming a class the same as its parent class is a little confusing: class Model(Model).
Would it be too much to change it to something different?
Agree, but we should change the other model subclasses for consistency. This also applies to the filenames (eg FasterRCNN).
@pytest.fixture()
def config():
    config = utilities.load_config()
    config.model.name = "joshvm/milliontrees-detr"
We need to move this model to weecology/milliontrees-detr
Sure, let me see if this is the best checkpoint I have. I think there's likely a better one that doesn't over-predict.
import pytest
import torch
from PIL import Image
import numpy as np
Imports are not grouped by standard library, third-party, and local imports; add spacing between the groups.
Do we have guidelines for this? I normally use isort in pre-commit.
import os
import cv2
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
I usually use PyCharm; it organizes or optimizes imports. But we should have imports organized by type:
- Standard library imports first (os)
- Third-party imports second (all the others)
- Alphabetical ordering within each group
@henrykironde - review summary:
Hello @jveitchmichaelis
I have also encountered similar issues in #901 and I think it would indeed be a good idea to enforce consistent styling. I think there is broad consensus that using ruff is the way to go, so I have drafted #1084. Happy to discuss this further. Best,
Thanks to @jveitchmichaelis for adding transformers support to DeepForest, and to all the other contributors.
bw4sz gave a thumbs up to moving forward.
We need to fix the other PR first (weight reloading) as that directly affects this one, but once that's figured out we can merge this. That one is quite close, but some things came up in review that need testing.
I didn't see that comment coming in, but if we have to revert, let me know @jveitchmichaelis.
No problem. Just had a quick look at the diffs; I think it should be OK. I'm going to rebase the other PR and see how that goes, but that's also fairly self-contained.
This PR includes rudimentary support for transformers object detection models, specifically those supported by AutoModelForObjectDetection. I've trained a simple DeformableDetr model as a demonstration using the latest public release of MillionTrees, which is hosted here (for testing): https://huggingface.co/joshvm/milliontrees-detr

The PR fixes a few issues in the process:

- uv was used to add the pyproject line; it should be compatible with the Python versions that we support.
- The config.score_thresh parameter is moved to be a top-level config item, as there's no need for it to be specific to retinanet, and other places it's referenced have been updated. This is also used for post-processing results from DETR.
- The new model class is DeformableDETR, but really it could support a variety of models. These can be separate files, but we would likely want to re-use the wrapper.

An advantage of the current approach is that it requires no changes to main; at the moment, the implementation is isolated to models/detr.py. I need to check compatibility with other prediction routes (like via the trainer), but the forward function returns [boxes, labels, scores] in a DeepForest-friendly format without any modification. The wrapper class is a bit awkward, but necessary because of the way that Model and its sub-classes are currently implemented (since we need to run pre- and post-processing, which is transformers-specific).

The model and processor (and their config as it relates to DeepForest) should be coupled, so I don't think we need to worry about exposing the details. So I think it's OK to have both loaded with config.model.name and keep the config fixed, rather than allow users to mess with image resizing. Other options like rescaling (1/255) and normalization should be preset so that the processor works with our dataset + loader.

NMS for predict_image doesn't seem to be working though, and I'm not sure why (@bw4sz?), see the following test script:

Training is going to need a bigger refactor (probably) because the various Lightning step methods in main make assumptions that might be tied to torchvision-type models.

TODO:
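For reference, a rough sketch of the wrapper idea described above: load a transformers detection model and its processor from one checkpoint name and return DeepForest-style boxes, labels, and scores. The class and variable names are illustrative, not the actual code in models/detr.py; the processor and model calls follow the standard transformers API:

```python
# Hypothetical sketch of a transformers-backed detection wrapper.
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection

class DetrWrapper(torch.nn.Module):
    def __init__(self, checkpoint: str, score_thresh: float = 0.3):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.model = AutoModelForObjectDetection.from_pretrained(checkpoint)
        self.score_thresh = score_thresh

    @torch.no_grad()
    def forward(self, images):
        # images: list of HxWx3 arrays; the processor handles resizing,
        # rescaling (1/255), and normalization internally.
        inputs = self.processor(images=images, return_tensors="pt")
        outputs = self.model(**inputs)
        sizes = [img.shape[:2] for img in images]  # (height, width) per image
        results = self.processor.post_process_object_detection(
            outputs, threshold=self.score_thresh, target_sizes=sizes)
        # Each result is a dict with absolute xyxy "boxes", "labels", and "scores".
        return [(r["boxes"], r["labels"], r["scores"]) for r in results]
```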