Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions luxonis_ml/data/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -160,29 +161,21 @@ def make_splits(
self,
splits: Optional[
Union[
Dict[str, Sequence[PathType]],
Dict[str, float],
Tuple[float, float, float],
Mapping[str, Sequence[PathType]],
Mapping[str, Union[float, int]],
Tuple[Union[float, int], Union[float, int], Union[float, int]],
]
] = None,
*,
ratios: Optional[
Union[Dict[str, float], Tuple[float, float, float]]
] = None,
definitions: Optional[Dict[str, List[PathType]]] = None,
replace_old_splits: bool = False,
) -> None:
"""Generates splits for the dataset.

@type splits: Optional[Union[Dict[str, Sequence[PathType]], Dict[str, float], Tuple[float, float, float]]]
@param splits: A dictionary of splits or a tuple of ratios for train, val, and test splits. Can be one of:
- A dictionary of splits with keys as split names and values as lists of filepaths
- A dictionary of splits with keys as split names and values as ratios
- A 3-tuple of ratios for train, val, and test splits
@type ratios: Optional[Union[Dict[str, float], Tuple[float, float, float]]]
@param ratios: Deprecated! A dictionary of splits with keys as split names and values as ratios.
@type definitions: Optional[Dict[str, List[PathType]]]
@param definitions: Deprecated! A dictionary of splits with keys as split names and values as lists of filepaths.
@param splits: Splits can be defined in one of the following formats:
- Dict[str, List[PathType]]: Explicit file assignments for each split
- Tuple[float, float, float]: Ratios for C{"train"}, C{"val"}, and C{"test"} splits.
Must sum to 1.0
- Dict[str, float]: Ratios for arbitrary splits. Must sum to 1.0
@type replace_old_splits: bool
@param replace_old_splits: Whether to remove old splits and generate new ones. If set to False, only new files will be added to the splits. Default is False.
"""
Expand Down
204 changes: 105 additions & 99 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from luxonis_ml.typing import PathType
from luxonis_ml.utils import (
LuxonisFileSystem,
deprecated,
environ,
make_progress_bar,
)
Expand Down Expand Up @@ -1003,92 +1002,111 @@ def get_splits(self) -> Optional[Dict[str, List[str]]]:
with open(splits_path, "r") as file:
return json.load(file)

@deprecated(
"ratios",
"definitions",
suggest={"ratios": "splits", "definitions": "splits"},
)
@override
def make_splits(
self,
splits: Optional[
Union[
Mapping[str, Sequence[PathType]],
Mapping[str, float],
Tuple[float, float, float],
Mapping[str, Union[float, int]],
Tuple[Union[float, int], Union[float, int], Union[float, int]],
]
] = None,
*,
ratios: Optional[
Union[Dict[str, float], Tuple[float, float, float]]
] = None,
definitions: Optional[Dict[str, List[PathType]]] = None,
replace_old_splits: bool = False,
) -> None:
if ratios is not None and definitions is not None:
raise ValueError("Cannot provide both ratios and definitions")

if splits is None and ratios is None and definitions is None:
splits = {"train": 0.8, "val": 0.1, "test": 0.1}
splits = splits or {"train": 0.8, "val": 0.1, "test": 0.1}

if splits is not None:
if ratios is not None or definitions is not None:
raise ValueError(
"Cannot provide both splits and ratios/definitions"
)
if isinstance(splits, tuple):
ratios = splits
elif isinstance(splits, dict):
value = next(iter(splits.values()))
if isinstance(value, float):
ratios = splits # type: ignore
elif isinstance(value, list):
definitions = splits # type: ignore

if ratios is not None:
if isinstance(ratios, tuple):
if not len(ratios) == 3:
raise ValueError(
"Ratios must be a tuple of 3 floats for train, val, and test splits"
)
ratios = {
"train": ratios[0],
"val": ratios[1],
"test": ratios[2],
}
sum_ = sum(ratios.values())
if not math.isclose(sum_, 1.0):
raise ValueError(f"Ratios must sum to 1.0, got {sum_:0.4f}")

if definitions is not None:
n_files = sum(map(len, definitions.values()))
if n_files > len(self):
raise ValueError(
"Dataset size is smaller than the total number of files in the definitions. "
f"Dataset size: {len(self)}, Definitions: {n_files}."
)
if isinstance(splits, tuple):
if len(splits) != 3:
raise ValueError("Tuple splits must contain exactly 3 ratios")
splits = {"train": splits[0], "val": splits[1], "test": splits[2]}

splits_to_update: List[str] = []
new_splits: Dict[str, List[str]] = {}
old_splits: Dict[str, List[str]] = defaultdict(list)
split_type = self._validate_splits(splits)

splits_path = get_file(
self.fs,
"metadata/splits.json",
self.metadata_path,
default=self.metadata_path / "splits.json",
)
old_splits = self._load_existing_splits(splits_path)

if split_type == "ratios":
new_splits = self._make_ratio_splits(
splits, # type: ignore
old_splits,
replace_old_splits,
)
else:
new_splits = self._make_definition_splits(splits) # type: ignore

for split, uuids in new_splits.items():
old_splits[split].extend(uuids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only extend if `replace_old_splits=False`, no? Otherwise, if `split_type == "files"`, wouldn't we now get the same image twice in the splits (once from the older split and once from the newer one)?


splits_path.write_text(json.dumps(old_splits, indent=4))
with suppress(shutil.SameFileError):
self.fs.put_file(splits_path, "metadata/splits.json")

def _validate_splits(
    self,
    splits: Union[
        Mapping[str, Sequence[PathType]], Mapping[str, Union[float, int]]
    ],
) -> Literal["ratios", "files"]:
    """Validates the split specification and reports which format it uses.

    Every value of the mapping is checked, not just the first one, so
    that a malformed specification always surfaces as a C{ValueError}.

    @type splits: Union[Mapping[str, Sequence[PathType]], Mapping[str, Union[float, int]]]
    @param splits: Either a mapping of split names to ratios or a
        mapping of split names to sequences of filepaths.
    @rtype: Literal["ratios", "files"]
    @return: C{"ratios"} if all values are numbers, C{"files"} if all
        values are (non-string) sequences.
    @raise ValueError: If C{splits} is not a non-empty mapping, if the
        values are neither all numbers nor all sequences, if any ratio
        is negative, if the ratios do not sum to 1.0, or if the total
        number of listed files exceeds C{len(self)}.
    """
    invalid_format = (
        "Invalid split format. Must be either ratios or file assignments"
    )

    # A bare string or an empty mapping can describe neither format.
    if not isinstance(splits, Mapping) or not splits:
        raise ValueError(invalid_format)

    values = list(splits.values())

    if all(isinstance(value, (float, int)) for value in values):
        # Negative ratios can still sum to 1.0 but make no sense as
        # a share of the dataset, so they are rejected explicitly.
        negative = [value for value in values if value < 0]
        if negative:
            raise ValueError(
                f"Split ratios must be non-negative, got {negative}"
            )
        sum_ratios = sum(values)
        if not math.isclose(sum_ratios, 1.0):
            raise ValueError(
                f"Split ratios must sum to 1.0, got {sum_ratios:0.4f}"
            )
        return "ratios"

    # `str` and `bytes` are sequences too, but a single path is not a
    # list of files; iterating it would yield individual characters.
    if all(
        isinstance(value, Sequence) and not isinstance(value, (str, bytes))
        for value in values
    ):
        total_files = sum(len(files) for files in values)
        if total_files > len(self):
            raise ValueError(
                f"Total files in splits ({total_files}) exceeds "
                f"dataset size ({len(self)})"
            )
        return "files"

    raise ValueError(invalid_format)

def _load_existing_splits(self, splits_path: Path) -> Dict[str, List[str]]:
"""Load existing splits from file."""
if splits_path.exists():
with open(splits_path, "r") as file:
old_splits = defaultdict(list, json.load(file))
with open(splits_path) as f:
return defaultdict(list, json.load(f))
return defaultdict(list)

def _make_ratio_splits(
self,
ratios: Mapping[str, Union[float, int]],
old_splits: Dict[str, List[str]],
replace_old: bool,
) -> Dict[str, List[str]]:
"""Create splits based on ratios."""
df = self._load_df_offline(raise_when_empty=True)
defined_uuids = set(
uuid for uuids in old_splits.values() for uuid in uuids
)

if definitions is None:
ratios = ratios or {"train": 0.8, "val": 0.1, "test": 0.1}
df = self._load_df_offline(raise_when_empty=True)
if replace_old:
ids = df.select("uuid").unique().get_column("uuid").to_list()
old_splits.clear()
else:
ids = (
df.filter(~pl.col("uuid").is_in(defined_uuids))
.select("uuid")
Expand All @@ -1097,49 +1115,37 @@ def make_splits(
.to_list()
)
if not ids:
if not replace_old_splits:
raise ValueError(
"No new files to add to splits. "
"If you want to generate new splits, set `replace_old_splits=True`"
)
else:
ids = (
df.select("uuid").unique().get_column("uuid").to_list()
)
old_splits = defaultdict(list)
raise ValueError(
"No new files to add to splits. Use replace_old_splits=True to regenerate"
)

np.random.shuffle(ids)
N = len(ids)
lower_bound = 0
for split, ratio in ratios.items():
upper_bound = lower_bound + math.ceil(N * ratio)
new_splits[split] = ids[lower_bound:upper_bound]
splits_to_update.append(split)
lower_bound = upper_bound
np.random.shuffle(ids)
N = len(ids)
new_splits = {}
offset = 0

else:
index = self._get_file_index(sync_from_cloud=True)
if index is None:
raise FileNotFoundError("File index not found")
for split, filepaths in definitions.items():
splits_to_update.append(split)
if not isinstance(filepaths, list):
raise ValueError(
"Must provide splits as a list of filepaths"
)
ids = [
find_filepath_uuid(filepath, index, raise_on_missing=True)
for filepath in filepaths
]
new_splits[split] = ids
for split, ratio in ratios.items():
n_samples = math.ceil(N * ratio)
new_splits[split] = ids[offset : offset + n_samples]
offset += n_samples

for split, uuids in new_splits.items():
old_splits[split].extend(uuids)
return new_splits

splits_path.write_text(json.dumps(old_splits, indent=4))
def _make_definition_splits(
self, definitions: Mapping[str, Sequence[PathType]]
) -> Dict[str, List[str]]:
"""Create splits from explicit file assignments."""
index = self._get_file_index(
sync_from_cloud=True, raise_when_empty=True
)

with suppress(shutil.SameFileError):
self.fs.put_file(splits_path, "metadata/splits.json")
return {
split: [
find_filepath_uuid(filepath, index, raise_on_missing=True)
for filepath in filepaths
]
for split, filepaths in definitions.items()
}

@staticmethod
@override
Expand Down
8 changes: 5 additions & 3 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,18 @@ def generator(step=15):
with pytest.raises(ValueError):
dataset.make_splits((0.7, 0.1, 0.1, 0.1)) # type: ignore

with pytest.raises(ValueError):
dataset.make_splits((0.7, 0.1, 1), definitions=definitions) # type: ignore

with pytest.raises(ValueError):
dataset.make_splits({"train": 1.5})

with pytest.raises(ValueError):
dataset.make_splits(
{split: defs * 2 for split, defs in splits.items()}
)
with pytest.raises(ValueError):
dataset.make_splits("invalid_argument") # type: ignore

with pytest.raises(ValueError):
dataset.make_splits({"train": ...}) # type: ignore

dataset.add(generator(10))
dataset.make_splits({"custom_split": 1.0})
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def generator():
bucket_storage=BucketStorage.LOCAL,
).add(generator())

dataset.make_splits(ratios=(1, 0, 0))
dataset.make_splits((1, 0, 0))

augmentation_config = [
{
Expand Down
Loading