Transformers/DETR integration #1078
Requesting a WIP review to check that this seems like a sensible path to go down, or whether there are any high-level things that need changing.
19c0dbc to 2ae6745 (force-push)
Changed this to explicitly name itself `DeformableDetr`. I've tested basic forward passes with dummy data and data loaded from a `BoxDataset`. Model loading is probably broken. The behavior is that NMS is not bundled into the model itself. I guess the next step would be to actually try a training loop within DeepForest itself.
Will revert the NMS config changes; I don't think those need to be here. Same goes for some minor typos etc. in the test cases for Faster-RCNN; that can be a separate PR. Hub loading works OK, but it's quite brittle, as the checkpoint is sensitive to any change in the model constructor.
henrykironde
left a comment
These changes look good for now. Hope we are planning on incorporating parallelization in this design?
```python
# Set bird-specific settings if loading the bird model
if model_name == "weecology/deepforest-bird":
    self.config.retinanet.score_thresh = 0.3
    self.config.score_thresh = 0.3
```
I believe this configuration is already included in the config file, so we can likely skip it if that's the case.
Yeah, I think this is an override for the bird model, but it would be more correct to have a different config entirely and not hard-code it here.
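One way to avoid the hard-coded branch would be to keep per-model settings in data and merge them at load time. A minimal sketch with plain dicts (the `MODEL_OVERRIDES` table and `apply_overrides` helper are hypothetical; DeepForest's real config is an OmegaConf object, so this only illustrates the idea):

```python
# Hypothetical per-model override table, instead of an if-branch per model name.
MODEL_OVERRIDES = {
    "weecology/deepforest-bird": {"score_thresh": 0.3},
}

def apply_overrides(config: dict, model_name: str) -> dict:
    """Return a copy of config with any model-specific overrides applied."""
    merged = dict(config)
    merged.update(MODEL_OVERRIDES.get(model_name, {}))
    return merged

config = apply_overrides({"score_thresh": 0.1}, "weecology/deepforest-bird")
print(config["score_thresh"])  # 0.3
```

Adding a new model's overrides then only touches the table, not the loading logic.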
```python
# Check for required properties and formats
self.config = config
self.nms_thresh = None  # Required for some models but not all
```
Do we really need this here? `self.nms_thresh = None`
Hmm probably not, let me check. I think I included this because initially I thought that the model was responsible for NMS but then I realised that it's a specific feature inside RetinaNet.
On NMS: according to the DETR authors, the algorithm is supposed to be "NMS-free", but the post-processed results clearly still contain overlapping boxes.

This makes sense for training, because the ground truth/prediction matcher is 1:1 by definition (via Hungarian matching/linear sum assignment), so you don't need to first run NMS to clean up the boxes before calculating the loss; the model should learn not to over-predict. But at inference, when there's nothing to compare against, all the "object" predictions are retained. It might be a poorly trained model (quite possible!), but I'm not sure how non-overlapping we expect the final predictions to be.

See also the DETA paper, which suggests that NMS is actually fine and you don't need bipartite matching anyway, though that's more about providing better guidance during training.
Will begin to revert changes that are moved to other PRs + rebase as we get them in. See:
What's the latest here with NMS? `predict_image` doesn't perform NMS itself; torchvision does NMS internally in RetinaNet, so an extra NMS step only runs when there are multiple classes present.
bw4sz
left a comment
Great! Excited about getting this done and ready. I think the essential pieces look good; most of what I want to see is in the tests, to make sure that the model covers:
m.predict_tile
m.trainer.fit(m)
m.trainer.validate(m) <- even if we don't have a polygon evaluation method yet? just that it goes through the process and calls torchmetrics
m.trainer.predict() <- so we can use multi-gpu inference.
These could be separate PRs too, up to you, just trying to document where in the workflow this is.
Yeah DETR is (mostly) only boxes. I'll add some integration tests for prediction and training - my hope is that it should be seamless. We probably do need to add an explicit NMS step unless we can figure out why the trained model is over-predicting. Or maybe it needs a better confidence threshold.
henrykironde
left a comment
I've added a few comments. For the tests, I think we should include proper assertions to validate behavior more clearly. Also, we should follow Python naming conventions and rename the file from `DeformableDetr.py` to `deformable_detr.py`.
```python
def __init__(self, config, **kwargs):
    """
    Args:
```
Complete the docstring
Naming a class the same as its parent class is a little confusing (`class Model(Model):`). Would it be too much to change it to something different?
Agree, but we should change the other model subclasses for consistency. This also applies to the filenames (e.g. `FasterRCNN`).
```python
@pytest.fixture()
def config():
    config = utilities.load_config()
    config.model.name = "joshvm/milliontrees-detr"
```
We need to move this model to weecology/milliontrees-detr
Sure, let me see if this is the best checkpoint I have. I think there's likely a better one that doesn't over-predict.
```python
import pytest
import torch
from PIL import Image
import numpy as np
```
Imports are not grouped into standard library, third-party, and local imports; add spacing between the groups.
Do we have guidelines for this? I normally use isort in pre-commit.
```python
import os

import cv2
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
```

I usually use PyCharm; it organizes/optimizes imports for me. But we should have the imports organized by type:
- Standard library imports first (`os`)
- Third-party imports second (all the others)
- Alphabetical ordering within each group
@henrykironde - review summary:
Hello @jveitchmichaelis
I have also encountered similar issues in #901 and I think it would indeed be a good idea to enforce consistent styling. I think there is broad consensus that ruff is the way to go, so I have drafted #1084. Happy to discuss this further. Best,
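For what it's worth, ruff can enforce the isort-style grouping discussed above via its import-sorting rules. A minimal `pyproject.toml` sketch (settings are illustrative, not necessarily what #1084 uses):

```toml
[tool.ruff.lint]
# "I" enables ruff's isort-compatible import-sorting rules;
# "E"/"F" are the usual pycodestyle/pyflakes checks.
select = ["E", "F", "I"]

[tool.ruff.lint.isort]
# Groups imports as: standard library, third-party, then first-party.
known-first-party = ["deepforest"]
```

With this in place, `ruff check --fix` reorders and groups imports automatically.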
henrykironde
left a comment
Thanks @jveitchmichaelis for adding transformers support to DeepForest, and thanks to all the other contributors.
bw4sz gave a thumbs up to moving forward.
We need to fix the other PR first (weight reloading) as that directly affects this one, but once that's figured out we can merge this. That one is quite close, but some things came up in review that need testing.
I didn't see that comment coming in, but if we have to revert, let me know @jveitchmichaelis
No problem. Just had a quick look at the diffs, I think it should be OK. I'm going to rebase the other PR and see how that goes, but that's also fairly self contained.
Hello! Sorry to comment on a closed PR, but I am trying to run the snippet from the first comment and I get the following error: I also get it when loading the default model, e.g.: I have just installed it from GitHub as in: Any clues on why this happens? It may not be because of this feature, but I wonder how you managed to run the snippet. Thank you!
My guess is the manifest is out of date: it's referencing the old `config.yml` file and not `config.yaml` (OmegaConf/Hydra requires the "a"). Try updating the path in MANIFEST.in and see if the install works (I've submitted #1092).
You could also try the incantation we use in the Action: `uv sync --all-extras --dev`
Hello @jveitchmichaelis! Thank you, the latter worked; I only had to change `architecture` from `"detr"` to `"DeformableDetr"`.
This PR includes rudimentary support for transformers object detection models, specifically those supported by `AutoModelForObjectDetection`. I've trained a simple DeformableDetr model as a demonstration using the latest public release of MillionTrees, which is hosted here (for testing): https://huggingface.co/joshvm/milliontrees-detr

The PR fixes a few issues in the process:

- Used `uv` to add the pyproject line; it should be compatible with the Python versions that we support.
- The `config.score_thresh` parameter is moved to be a top-level config item, as there's no need for it to be specific to RetinaNet, and other places it's referenced have been updated. This is also used for post-processing results from DETR.
- The model is named `DeformableDETR`, but really it could support a variety of models. These can be separate files, but we would likely want to re-use the wrapper.

An advantage of the current approach is that it requires no changes to `main` at the moment; the implementation is isolated to `models/detr.py`. I need to check compatibility with other prediction routes (like via the trainer), but the `forward` function returns [boxes, labels, scores] in a DeepForest-friendly format without any modification. The wrapper class is a bit awkward, but necessary because of the way that `Model` and its sub-classes are currently implemented (since we need to run pre- and post-processing, which is transformers-specific).

The model and processor (and their config as it relates to DeepForest) should be coupled, so I don't think we need to worry about exposing the details. So I think it's OK to have both loaded with `config.model.name` and keep the config fixed, rather than allow users to mess with image resizing. Other options like rescaling (1/255) and normalization should be preset so that the processor works with our dataset + loader.

NMS for `predict_image` doesn't seem to be working though, and I'm not sure why (@bw4sz?), see the following test script:

Training is going to need a bigger refactor (probably) because the various Lightning steps in `main` make assumptions that might be tied to torchvision-type models.

TODO:
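On the over-prediction question above: one cheap mitigation is to filter the post-processed detections by the top-level `config.score_thresh` before any NMS. A sketch with dummy tensors (the `filter_by_score` helper is hypothetical, not DeepForest API; the dict layout just mimics a post-processed detection result):

```python
import torch

def filter_by_score(result: dict, score_thresh: float) -> dict:
    """Drop detections whose confidence falls below the threshold.
    `result` holds parallel tensors: boxes, labels, scores."""
    keep = result["scores"] >= score_thresh
    return {k: v[keep] for k, v in result.items()}

dummy = {
    "boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]]),
    "labels": torch.tensor([0, 0]),
    "scores": torch.tensor([0.9, 0.2]),
}
filtered = filter_by_score(dummy, score_thresh=0.3)
print(len(filtered["boxes"]))  # 1 -- the low-confidence box is removed
```

Tuning the threshold this way is independent of (and composes with) an explicit NMS pass.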