
raise error in create_trainer when there's a label mismatch #1093


Open: wants to merge 2 commits into main

dylankershaw
Contributor

This PR addresses #574.

Comment on lines +100 to +101
self.label_dict = label_dict
self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
Contributor Author

dylankershaw commented Jul 15, 2025


I moved these lines up because label_dict now needs to be initialized before create_trainer is called.
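To illustrate the ordering constraint, here is a minimal, self-contained sketch. ToyModel and WrongOrderModel are hypothetical stand-ins, not DeepForest classes; the only point is that a check inside create_trainer can read self.label_dict only if __init__ has already assigned it.

```python
class ToyModel:
    """Hypothetical stand-in for the deepforest module (illustration only)."""

    def __init__(self, label_dict, train_labels):
        # Assign the label mappings first, as this PR does...
        self.label_dict = label_dict
        self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
        self.train_labels = train_labels
        # ...so that the check inside create_trainer can see them.
        self.create_trainer()

    def create_trainer(self):
        # Raises AttributeError if self.label_dict has not been assigned yet.
        unknown = set(self.train_labels) - set(self.label_dict)
        if unknown:
            raise ValueError(f"Labels not in label_dict: {sorted(unknown)}")


class WrongOrderModel(ToyModel):
    """Same model, but create_trainer runs before label_dict is assigned."""

    def __init__(self, label_dict, train_labels):
        self.train_labels = train_labels
        self.create_trainer()  # AttributeError: no 'label_dict' yet
        self.label_dict = label_dict


model = ToyModel({"Tree": 0}, ["Tree"])
print(model.numeric_to_label_dict)  # {0: 'Tree'}
```

This also shows the reviewer's concern below: because the check lives in create_trainer, it runs every time the model is constructed.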

@jveitchmichaelis
Collaborator

jveitchmichaelis commented Jul 15, 2025

Thanks! I'd like to review these changes after we've merged #1083, because that PR also changes when and where label handling happens (and enforces other checks). Some of this logic moves into set_labels (to verify configs), and we will also run some checks when models are created.

For clarity, this PR specifically addresses a third scenario: setting up a training dataset. Overall, we want to check:

  • Is the label_dict in the config sane? (deepforest instantiation)
  • Have we tried to make a model with a mismatched label_dict? (model instantiation)
  • Have we tried to pass in a dataset that disagrees with the label_dict? (dataset instantiation) <- this PR
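The first two checks could look something like the sketch below. check_label_dict is a hypothetical helper, and its rules (one entry per class, indices exactly 0..num_classes - 1) are assumptions for illustration, not the checks #1083 actually implements.

```python
def check_label_dict(label_dict, num_classes):
    """Raise ValueError unless label_dict maps num_classes names onto 0..num_classes-1.

    Hypothetical rules for illustration; not DeepForest's own validation.
    """
    if len(label_dict) != num_classes:
        raise ValueError(
            f"label_dict has {len(label_dict)} entries, but num_classes={num_classes}"
        )
    # Keys are unique by construction; the indices must be unique and contiguous too.
    indices = sorted(label_dict.values())
    if indices != list(range(num_classes)):
        raise ValueError(f"label_dict indices must be 0..{num_classes - 1}, got {indices}")


check_label_dict({"Tree": 0}, num_classes=1)  # passes silently
```

Under these assumed rules, check_label_dict({"Tree": 0, "Bird": 2}, num_classes=2) raises ValueError because index 1 is missing.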

Here I think it suffices to check that every label in the dataset CSVs is present in the label_dict; the label_dict itself may contain additional keys.
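The subset rule fits in a few lines of standard-library Python. This is a minimal sketch: unknown_labels is a hypothetical helper, and the column names assume a DeepForest-style annotation CSV with a `label` column.

```python
import csv
import io


def unknown_labels(rows, label_dict, label_column="label"):
    """Return the sorted labels that appear in rows but not in label_dict.

    rows is any iterable of dicts, e.g. a csv.DictReader. Extra keys in
    label_dict are allowed: only labels actually used by the data are checked.
    """
    seen = {row[label_column] for row in rows}
    return sorted(seen - set(label_dict))


annotations = io.StringIO(
    "image_path,xmin,ymin,xmax,ymax,label\n"
    "a.png,0,0,10,10,Tree\n"
    "b.png,5,5,20,20,Bird\n"
)
label_dict = {"Tree": 0, "Snag": 1}  # "Snag" is unused by the data, which is fine
print(unknown_labels(csv.DictReader(annotations), label_dict))  # ['Bird']
```

Returning the offending labels (rather than a bare boolean) lets the caller raise an error that names exactly what is missing.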

I would consider moving the function call into the dataset class rather than create_trainer. That's really where this sanity check belongs, and it would then only run on demand. Otherwise, your check runs whenever a deepforest instance is created (if a training CSV is defined).

It would also be logical to perform this alongside other sanity checks, like verifying that bounding boxes are in range, that image paths exist, and so on (though we probably want these checks to be optional for bigger datasets).
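A rough sketch of that suggestion: run the label check when the dataset is constructed, next to another cheap sanity check, behind a flag so it can be skipped for very large datasets. CheckedBoxDataset and check_annotations are hypothetical names; this is not DeepForest's BoxDataset, and the image-path check is left out to keep the example self-contained.

```python
import csv
import io


class CheckedBoxDataset:
    """Hypothetical, simplified dataset that validates its annotations on load."""

    def __init__(self, csv_file, label_dict, check_annotations=True):
        self.label_dict = label_dict
        self.rows = list(csv.DictReader(csv_file))
        # Optional, so very large datasets can skip the extra pass.
        if check_annotations:
            self.check_annotations()

    def check_annotations(self):
        # Every label in the data must be a key of label_dict (extra keys are fine).
        unknown = {row["label"] for row in self.rows} - set(self.label_dict)
        if unknown:
            raise ValueError(f"Labels not in label_dict: {sorted(unknown)}")
        # Boxes must have positive width and height.
        for i, row in enumerate(self.rows):
            xmin, ymin, xmax, ymax = (float(row[k]) for k in ("xmin", "ymin", "xmax", "ymax"))
            if not (xmin < xmax and ymin < ymax):
                raise ValueError(f"Row {i}: invalid box {(xmin, ymin, xmax, ymax)}")

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return row["image_path"], self.label_dict[row["label"]]


train_csv = io.StringIO("image_path,xmin,ymin,xmax,ymax,label\na.png,0,0,10,10,Tree\n")
dataset = CheckedBoxDataset(train_csv, {"Tree": 0})
print(len(dataset), dataset[0])  # 1 ('a.png', 0)
```

Since the training and validation splits each construct their own dataset, a check placed here covers both without extra wiring.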

I think the complexity here is as good as it gets, since you have to read the whole CSV at least once, but I would check whether there are faster or parallel ways to do this in pandas.
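One way to keep that single pass cheap in pandas, sketched under assumptions: parse only the label column with `usecols`, collect distinct values with `Series.unique()`, and optionally stream the file with `chunksize`. unknown_labels_pandas is a hypothetical name; it has not been benchmarked against the current implementation, and parallel reading is not shown.

```python
import io

import pandas as pd


def unknown_labels_pandas(csv_file, label_dict, label_column="label", chunksize=None):
    """Return sorted labels in csv_file that are missing from label_dict.

    Only the label column is parsed. If chunksize is given, the CSV is streamed
    in pieces so memory use stays bounded for very large files.
    """
    read_kwargs = {"usecols": [label_column], "dtype": {label_column: str}}
    if chunksize is None:
        chunks = [pd.read_csv(csv_file, **read_kwargs)]
    else:
        chunks = pd.read_csv(csv_file, chunksize=chunksize, **read_kwargs)
    labels = set()
    for chunk in chunks:
        # unique() runs in vectorized code; Python only sees the distinct labels.
        labels.update(chunk[label_column].unique())
    return sorted(labels - set(label_dict))


data = "image_path,label\na.png,Tree\nb.png,Bird\nc.png,Tree\n"
print(unknown_labels_pandas(io.StringIO(data), {"Tree": 0}))  # ['Bird']
print(unknown_labels_pandas(io.StringIO(data), {"Tree": 0}, chunksize=2))  # ['Bird']
```

Forcing `dtype=str` keeps numeric-looking labels (e.g. "1") as strings so they compare correctly against label_dict keys.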

@jveitchmichaelis
Collaborator

jveitchmichaelis commented Jul 15, 2025

The test that breaks is test_checkpoint_label_dict, because of the ordering issue between setting the dict and creating the trainer. We're moving towards enforcing this at the config level, so we'd prefer:

m = main.deepforest(config_args={"num_classes": 1},
                    label_dict={"Object": 0})

rather than overwriting the label dict after creation (which was fine for setting up a unit test, but in real use I don't know why you'd do it). Currently, the test does:

m.create_trainer() #<- test fails here
m.label_dict = {"Object": 0}
m.numeric_to_label_dict = {0: "Object"}
m.trainer.fit(m)
m.trainer.save_checkpoint("{}/checkpoint.pl".format(tmpdir))

would be replaced with:

m.create_trainer() 
m.trainer.fit(m) #<- test would fail here (we're not expecting it to), when Lightning calls m.train_dataloader
m.trainer.save_checkpoint("{}/checkpoint.pl".format(tmpdir))

Collaborator

jveitchmichaelis left a comment


As in my comment above:

  • Can we see what this would look like if the check is performed as a sanity check within the dataset itself (BoxDataset), and not in create_trainer?
  • We should review some of the existing test cases and make sure that they reflect usage patterns we're expecting for v2+

@@ -226,6 +226,11 @@ def create_trainer(self, logger=None, callbacks=[], **kwargs):
Returns:
None
"""
utilities.validate_labels(
label_dict=self.label_dict,
csv_file=self.config.train.csv_file,
Collaborator


Note that the solution should also check val.csv_file; otherwise training can still break after an epoch completes. Moving the check to the dataset class would handle this automatically.
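If the check were kept outside the dataset class, every configured split would need to be covered explicitly. A minimal sketch: validate_split_labels is a hypothetical helper, and the split names and how the paths are read from the config are assumptions, not DeepForest's config keys.

```python
import csv


def validate_split_labels(label_dict, csv_files, label_column="label"):
    """Check the labels of every configured split (e.g. train and validation).

    csv_files maps a split name to a CSV path, or to None if that split is unset.
    Raises ValueError for the first split whose CSV uses a label missing from
    label_dict; returns None if every split passes.
    """
    known = set(label_dict)
    for split, path in csv_files.items():
        if path is None:
            continue  # split not configured, nothing to check
        with open(path, newline="") as f:
            seen = {row[label_column] for row in csv.DictReader(f)}
        unknown = sorted(seen - known)
        if unknown:
            raise ValueError(f"{split} CSV {path!r} has labels not in label_dict: {unknown}")
```

Calling it once with both paths, e.g. `validate_split_labels(label_dict, {"train": train_csv, "val": val_csv})` before fitting, surfaces a bad validation file up front instead of after the first training epoch.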
