Skip to content

Commit 8a168f0

Browse files
authored
Better error message on missing task names (#369)
1 parent 7f15cba commit 8a168f0

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

luxonis_ml/data/loaders/luxonis_loader.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,25 @@ def __init__(
154154
self.source_names = self.dataset.get_source_names()
155155

156156
if color_space is None:
157-
color_space = {source: "RGB" for source in self.source_names}
157+
color_space = dict.fromkeys(self.source_names, "RGB")
158158
elif isinstance(color_space, str):
159-
color_space = {source: color_space for source in self.source_names}
159+
color_space = dict.fromkeys(self.source_names, color_space)
160160
elif not isinstance(color_space, dict):
161161
raise ValueError(
162162
"color_space must be either a string or a dictionary"
163163
)
164164
self.color_space = color_space
165165

166166
if self.filter_task_names is not None:
167+
if self.dataset.metadata.tasks:
168+
df_task_names = set(self.dataset.metadata.tasks)
169+
else:
170+
df_task_names = set(self.df["task_name"].to_list())
171+
if extras := set(self.filter_task_names) - df_task_names:
172+
raise ValueError(
173+
f"filter_task_names contains task names that "
174+
f"are not in the dataset: {extras}"
175+
)
167176
self.df = self.df.filter(
168177
pl.col("task_name").is_in(self.filter_task_names)
169178
)

0 commit comments

Comments
 (0)