raise error in create_trainer when there's a label mismatch #1093
base: main
Conversation
```python
self.label_dict = label_dict
self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
```
I moved these up a bit since `label_dict` now needs to be initialized before `create_trainer` is called.
Thanks! I'd like to review these changes after we've merged #1083, because that PR also adjusts when/where label handling happens (and also enforces other checks); some of this logic is moved into [...]. For clarity, this PR specifically addresses a third scenario, when we want to set up a training dataset. Overall we want to check: [...]

Here I think it suffices to check that the labels in all dataset CSVs are present in the label_dict, while label_dict can contain additional keys. I would consider moving the function call to the dataset class, and not `create_trainer`. It would also be logical to perform this alongside other sanity checks, like verifying that bounding boxes are in range, image paths exist, and so on (but we probably want this to be optional for bigger datasets).

I think the complexity here is as good as we'll get, since you have to iterate over the whole CSV at least once, but I would check whether there are faster/parallel ways to do this within pandas.
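For concreteness, here is a minimal sketch of what that check could look like, assuming the annotation CSV has a `label` column and a helper with the same name and arguments as the `utilities.validate_labels` call shown in the diff further down; the pandas-based body and the error message are illustrative, not the PR's actual implementation:

```python
import pandas as pd


def validate_labels(label_dict, csv_file):
    """Raise if csv_file contains labels that are missing from label_dict.

    label_dict may contain extra keys; the CSV labels only need to be a subset.
    """
    # Read just the label column so the single pass over the CSV stays cheap.
    csv_labels = set(pd.read_csv(csv_file, usecols=["label"])["label"].unique())
    missing = csv_labels - set(label_dict)
    if missing:
        raise ValueError(
            f"Labels {sorted(missing)} in {csv_file} are not present in "
            f"label_dict {sorted(label_dict)}."
        )
```

Reading only the `label` column keeps the pass over the CSV cheap; if the files get very large, `pd.read_csv(..., chunksize=...)` would be one way to bound memory, though it is unlikely to be faster.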
The test that breaks should instead construct the model with the label dict:

```python
m = main.deepforest(config_args={"num_classes": 1},
                    label_dict={"Object": 0})
```

rather than overwriting the label dict after creation (which was fine to set up a unit test, but in real life I don't know why you'd do it).

Currently:

```python
m.create_trainer()  # <- test fails here
m.label_dict = {"Object": 0}
m.numeric_to_label_dict = {0: "Object"}
m.trainer.fit(m)
m.trainer.save_checkpoint("{}/checkpoint.pl".format(tmpdir))
```

would be replaced with:

```python
m.create_trainer()
m.trainer.fit(m)  # <- test would fail here (we're not expecting it to), when Lightning calls m.train_dataloader
m.trainer.save_checkpoint("{}/checkpoint.pl".format(tmpdir))
```
As in the comment above:

- Can we see what this would look like if the check is performed as a sanity check within the dataset itself (`BoxDataset`), and not in `create_trainer`? A rough sketch follows this list.
- We should review some of the existing test cases and make sure that they reflect usage patterns we're expecting for v2+.
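To make the first bullet concrete, here is a rough sketch of the dataset-side variant; the `BoxDataset` constructor signature and the `label` column name are assumptions for illustration, not copied from the codebase:

```python
import pandas as pd
from torch.utils.data import Dataset


class BoxDataset(Dataset):
    """Illustrative sketch: validate labels when the dataset is constructed."""

    def __init__(self, csv_file, label_dict, transforms=None):
        self.annotations = pd.read_csv(csv_file)
        self.label_dict = label_dict
        self.transforms = transforms

        # Sanity check at construction time: every label in the CSV must be
        # known, but label_dict may contain additional classes.
        unknown = set(self.annotations["label"].unique()) - set(label_dict)
        if unknown:
            raise ValueError(
                f"{csv_file} contains labels {sorted(unknown)} that are not in "
                f"label_dict {sorted(label_dict)}"
            )

    # __len__, __getitem__, and the other sanity checks (box ranges, image
    # paths existing) would live here as before.
```

Because the training and validation dataloaders each construct their own dataset, a check placed here covers both CSVs without extra wiring, which also addresses the val.csv_file note below.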
```
@@ -226,6 +226,11 @@ def create_trainer(self, logger=None, callbacks=[], **kwargs):
        Returns:
            None
        """
        utilities.validate_labels(
            label_dict=self.label_dict,
            csv_file=self.config.train.csv_file,
```
A note that the solution should also check `val.csv_file`, otherwise this can potentially break after an epoch is complete. But moving the check to the dataset class would do this automatically.
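If the check does stay in `create_trainer`, one way to cover both CSVs would be a small loop like the one below; the `config.validation.csv_file` key is an assumption modelled on the `config.train.csv_file` access in the diff above:

```python
# Hypothetical sketch: validate every configured CSV, skipping unset ones.
for csv_file in (self.config.train.csv_file, self.config.validation.csv_file):
    if csv_file is not None:
        utilities.validate_labels(label_dict=self.label_dict, csv_file=csv_file)
```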
This PR addresses #574.