
Commit cb9dcc3

kozlov721 and JSabadin authored
LuxonisML Data Speedup (#117)
Co-authored-by: Jernej Sabadin <[email protected]>
1 parent 8f38f88 commit cb9dcc3

File tree

9 files changed, +556 -233 lines changed

examples/Data_Custom_Example.ipynb

Lines changed: 343 additions & 20 deletions
Large diffs are not rendered by default.

luxonis_ml/data/datasets/luxonis_dataset.py

Lines changed: 128 additions & 122 deletions
@@ -9,8 +9,7 @@
 from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
-import pandas as pd
-import pyarrow as pa
+import polars as pl
 import pyarrow.parquet as pq
 import rich.progress
 
@@ -138,10 +137,7 @@ def __len__(self) -> int:
         """Returns the number of instances in the dataset."""
 
         df = self._load_df_offline(self.bucket_storage != BucketStorage.LOCAL)
-        if df is not None:
-            return len(set(df["uuid"]))
-        else:
-            return 0
+        return len(df.select("uuid").unique()) if df is not None else 0
 
     def _write_datasets(self) -> None:
         with open(self.datasets_cache_file, "w") as file:
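Note on the new one-liner in __len__: df.select("uuid").unique() deduplicates inside the Polars engine and returns a one-column frame whose height equals the old len(set(df["uuid"])). A minimal sketch of the equivalence on throwaway data (illustrative only, not part of this commit):

import polars as pl

df = pl.DataFrame({"uuid": ["a", "a", "b"], "file": ["x.jpg", "x.jpg", "y.jpg"]})

# pandas-era count: pull the column into Python and deduplicate there
n_old = len(set(df["uuid"].to_list()))

# polars count: deduplicate in the query engine, then take the frame height
n_new = len(df.select("uuid").unique())

assert n_old == n_new == 2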
@@ -182,7 +178,7 @@ def _init_path(self) -> None:
                 f"{self.team_id}/datasets/{self.dataset_name}"
             )
 
-    def _load_df_offline(self, sync_mode: bool = False) -> Optional[pd.DataFrame]:
+    def _load_df_offline(self, sync_mode: bool = False) -> Optional[pl.DataFrame]:
         dfs = []
         if self.bucket_storage == BucketStorage.LOCAL or sync_mode:
             annotations_path = self.annotations_path
@@ -192,32 +188,32 @@ def _load_df_offline(self, sync_mode: bool = False) -> Optional[pd.DataFrame]:
             return None
         for file in annotations_path.iterdir():
             if file.suffix == ".parquet":
-                dfs.append(pd.read_parquet(file))
-        if len(dfs):
-            return pd.concat(dfs)
+                dfs.append(pl.read_parquet(file))
+        if dfs:
+            return pl.concat(dfs)
         else:
             return None
 
     def _find_filepath_uuid(
         self,
         filepath: Path,
-        index: Optional[pd.DataFrame],
+        index: Optional[pl.DataFrame],
         *,
         raise_on_missing: bool = False,
     ) -> Optional[str]:
         if index is None:
             return None
 
         abs_path = str(filepath.absolute())
-        if abs_path in list(index["original_filepath"]):
-            matched = index[index["original_filepath"] == abs_path]
-            if len(matched):
-                return list(matched["uuid"])[0]
+        matched = index.filter(pl.col("original_filepath") == abs_path)
+
+        if len(matched):
+            return list(matched.select("uuid"))[0][0]
         elif raise_on_missing:
             raise ValueError(f"File {abs_path} not found in index")
         return None
 
-    def _get_file_index(self) -> Optional[pd.DataFrame]:
+    def _get_file_index(self) -> Optional[pl.DataFrame]:
         index = None
         if self.bucket_storage == BucketStorage.LOCAL:
             file_index_path = self.metadata_path / "file_index.parquet"
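For readers newer to Polars, the two idioms introduced above are pl.concat over a list of pl.read_parquet frames in _load_df_offline, and expression-based filtering with pl.col in _find_filepath_uuid instead of pandas boolean indexing. A self-contained sketch with made-up paths and columns (not the dataset's real layout):

from pathlib import Path

import polars as pl

# Concatenate every parquet shard in a directory, as _load_df_offline does.
shards = [pl.read_parquet(p) for p in Path("annotations").glob("*.parquet")]
df = pl.concat(shards) if shards else None

# Filter rows with an expression instead of a boolean mask, as
# _find_filepath_uuid does; no intermediate Python lists are built.
if df is not None:
    matched = df.filter(pl.col("original_filepath") == "/abs/path/img.png")
    uuid = matched.get_column("uuid")[0] if len(matched) else None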
@@ -228,24 +224,25 @@ def _get_file_index(self) -> Optional[pd.DataFrame]:
             except Exception:
                 pass
         if file_index_path.exists():
-            index = pd.read_parquet(file_index_path)
+            index = pl.read_parquet(file_index_path).select(
+                pl.all().exclude("^__index_level_.*$")
+            )
         return index
 
     def _write_index(
         self,
-        index: Optional[pd.DataFrame],
-        new_index: Dict,
+        index: Optional[pl.DataFrame],
+        new_index: Dict[str, List[str]],
         override_path: Optional[str] = None,
     ) -> None:
         if override_path:
             file_index_path = override_path
         else:
             file_index_path = self.metadata_path / "file_index.parquet"
-        df = pd.DataFrame(new_index)
+        df = pl.DataFrame(new_index)
         if index is not None:
-            df = pd.concat([index, df])
-        table = pa.Table.from_pandas(df)
-        pq.write_table(table, file_index_path)
+            df = pl.concat([index, df])
+        pq.write_table(df.to_arrow(), file_index_path)
 
     @contextmanager
     def _log_time(self):
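Two details in _get_file_index and _write_index are easy to miss: the regex selector ^__index_level_.*$ drops the hidden index column that pandas embedded when it originally wrote file_index.parquet, and df.to_arrow() hands an Arrow table directly to pyarrow.parquet, so the existing pq.write_table call keeps working. A hedged round-trip sketch (path and columns are illustrative):

import polars as pl
import pyarrow.parquet as pq

new_index = {"uuid": ["u1"], "file": ["img.png"], "original_filepath": ["/data/img.png"]}

# Write: polars frame -> Arrow table -> parquet, no pandas in between.
df = pl.DataFrame(new_index)
pq.write_table(df.to_arrow(), "file_index.parquet")

# Read: the exclude() is a no-op for this freshly written file, but it strips
# "__index_level_0__" columns from parquet files that were written by pandas.
index = pl.read_parquet("file_index.parquet").select(
    pl.all().exclude("^__index_level_.*$")
)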
@@ -358,130 +355,141 @@ def delete_dataset(self) -> None:
         if self.bucket_storage == BucketStorage.LOCAL:
             shutil.rmtree(self.path)
 
-    def add(
-        self,
-        generator: DatasetIterator,
-        batch_size: int = 1_000_000,
-    ) -> None:
-        def _process_arrays(batch_data: List[DatasetRecord]) -> None:
-            array_paths = set(
-                ann.path for ann in batch_data if isinstance(ann, ArrayAnnotation)
+    def _process_arrays(self, batch_data: List[DatasetRecord]) -> None:
+        array_paths = set(
+            ann.path for ann in batch_data if isinstance(ann, ArrayAnnotation)
+        )
+        if array_paths:
+            task = self.progress.add_task(
+                "[magenta]Processing arrays...", total=len(batch_data)
             )
-            if array_paths:
-                task = self.progress.add_task(
-                    "[magenta]Processing arrays...", total=len(batch_data)
-                )
-                self.logger.info("Checking arrays...")
-                with self._log_time():
-                    data_utils.check_arrays(array_paths)
-                self.logger.info("Generating array UUIDs...")
-                with self._log_time():
-                    array_uuid_dict = self.fs.get_file_uuids(
-                        array_paths, local=True
-                    )  # TODO: support from bucket
-                if self.bucket_storage != BucketStorage.LOCAL:
-                    self.logger.info("Uploading arrays...")
-                    # TODO: support from bucket (likely with a self.fs.copy_dir)
-                    with self._log_time():
-                        arrays_upload_dict = self.fs.put_dir(
-                            local_paths=array_paths,
-                            remote_dir="arrays",
-                            uuid_dict=array_uuid_dict,
-                        )
-                self.logger.info("Finalizing paths...")
-                self.progress.start()
-                for ann in batch_data:
-                    if isinstance(ann, ArrayAnnotation):
-                        if self.bucket_storage != BucketStorage.LOCAL:
-                            remote_path = arrays_upload_dict[str(ann.path)]  # type: ignore
-                            remote_path = (
-                                f"{self.fs.protocol}://{self.fs.path / remote_path}"
-                            )
-                            ann.path = remote_path  # type: ignore
-                        else:
-                            ann.path = ann.path.absolute()
-                    self.progress.update(task, advance=1)
-                self.progress.stop()
-                self.progress.remove_task(task)
-
-        def _add_process_batch(batch_data: List[DatasetRecord]) -> None:
-            paths = list(set(data.file for data in batch_data))
-            self.logger.info("Generating UUIDs...")
+            self.logger.info("Checking arrays...")
+            with self._log_time():
+                data_utils.check_arrays(array_paths)
+            self.logger.info("Generating array UUIDs...")
             with self._log_time():
-                uuid_dict = self.fs.get_file_uuids(
-                    paths, local=True
+                array_uuid_dict = self.fs.get_file_uuids(
+                    array_paths, local=True
                 )  # TODO: support from bucket
             if self.bucket_storage != BucketStorage.LOCAL:
-                self.logger.info("Uploading media...")
+                self.logger.info("Uploading arrays...")
                 # TODO: support from bucket (likely with a self.fs.copy_dir)
-
                 with self._log_time():
-                    self.fs.put_dir(
-                        local_paths=paths, remote_dir="media", uuid_dict=uuid_dict
+                    arrays_upload_dict = self.fs.put_dir(
+                        local_paths=array_paths,
+                        remote_dir="arrays",
+                        uuid_dict=array_uuid_dict,
                     )
+            self.logger.info("Finalizing paths...")
+            self.progress.start()
+            for ann in batch_data:
+                if isinstance(ann, ArrayAnnotation):
+                    if self.bucket_storage != BucketStorage.LOCAL:
+                        remote_path = arrays_upload_dict[str(ann.path)]  # type: ignore
+                        remote_path = (
+                            f"{self.fs.protocol}://{self.fs.path / remote_path}"
+                        )
+                        ann.path = remote_path  # type: ignore
+                    else:
+                        ann.path = ann.path.absolute()
+                self.progress.update(task, advance=1)
+            self.progress.stop()
+            self.progress.remove_task(task)
 
-            task = self.progress.add_task(
-                "[magenta]Processing data...", total=len(batch_data)
-            )
-
-            _process_arrays(batch_data)
+    def _add_process_batch(
+        self,
+        batch_data: List[DatasetRecord],
+        pfm: ParquetFileManager,
+        index: Optional[pl.DataFrame],
+        new_index: Dict[str, List[str]],
+        processed_uuids: Set[str],
+    ) -> None:
+        paths = list(set(data.file for data in batch_data))
+        self.logger.info("Generating UUIDs...")
+        with self._log_time():
+            uuid_dict = self.fs.get_file_uuids(
+                paths, local=True
+            )  # TODO: support from bucket
+        if self.bucket_storage != BucketStorage.LOCAL:
+            self.logger.info("Uploading media...")
+            # TODO: support from bucket (likely with a self.fs.copy_dir)
 
-            self.logger.info("Saving annotations...")
             with self._log_time():
-                self.progress.start()
-                for ann in batch_data:
-                    filepath = ann.file
-                    file = filepath.name
-                    uuid = uuid_dict[str(filepath)]
-                    matched_id = self._find_filepath_uuid(filepath, index)
-                    if matched_id is not None:
-                        if matched_id != uuid:
-                            # TODO: not sure if this should be an exception or how we should really handle it
-                            raise Exception(
-                                f"{filepath} already added to the dataset! Please skip or rename the file."
-                            )
-                        # TODO: we may also want to check for duplicate uuids to get a one-to-one relationship
-                    elif uuid not in new_index["uuid"]:
-                        new_index["uuid"].append(uuid)
-                        new_index["file"].append(file)
-                        new_index["original_filepath"].append(str(filepath.absolute()))
-
-                    self.pfm.write({"uuid": uuid, **ann.to_parquet_dict()})
-                    self.progress.update(task, advance=1)
-                self.progress.stop()
-                self.progress.remove_task(task)
+                self.fs.put_dir(
+                    local_paths=paths, remote_dir="media", uuid_dict=uuid_dict
+                )
 
+        task = self.progress.add_task(
+            "[magenta]Processing data...", total=len(batch_data)
+        )
+
+        self._process_arrays(batch_data)
+
+        self.logger.info("Saving annotations...")
+        with self._log_time():
+            self.progress.start()
+            for ann in batch_data:
+                filepath = ann.file
+                file = filepath.name
+                uuid = uuid_dict[str(filepath)]
+                matched_id = self._find_filepath_uuid(filepath, index)
+                if matched_id is not None:
+                    if matched_id != uuid:
+                        # TODO: not sure if this should be an exception or how we should really handle it
+                        raise Exception(
+                            f"{filepath} already added to the dataset! Please skip or rename the file."
+                        )
+                    # TODO: we may also want to check for duplicate uuids to get a one-to-one relationship
+                elif uuid not in processed_uuids:
+                    new_index["uuid"].append(uuid)
+                    new_index["file"].append(file)
+                    new_index["original_filepath"].append(str(filepath.absolute()))
+                    processed_uuids.add(uuid)
+
+                pfm.write({"uuid": uuid, **ann.to_parquet_dict()})
+                self.progress.update(task, advance=1)
+            self.progress.stop()
+            self.progress.remove_task(task)
+
+    def add(self, generator: DatasetIterator, batch_size: int = 1_000_000) -> None:
         if self.bucket_storage == BucketStorage.LOCAL:
-            self.pfm = ParquetFileManager(str(self.annotations_path))
+            annotations_dir = self.annotations_path
         else:
             self._make_temp_dir()
             annotations_dir = self.tmp_dir / "annotations"
             annotations_dir.mkdir(exist_ok=True, parents=True)
-            self.pfm = ParquetFileManager(str(annotations_dir))
 
         index = self._get_file_index()
         new_index = {"uuid": [], "file": [], "original_filepath": []}
+        processed_uuids = set()
 
         batch_data: list[DatasetRecord] = []
 
         classes_per_task: Dict[str, Set[str]] = defaultdict(set)
         num_kpts_per_task: Dict[str, int] = {}
 
-        for i, data in enumerate(generator, start=1):
-            record = data if isinstance(data, DatasetRecord) else DatasetRecord(**data)
-            if record.annotation is not None:
-                classes_per_task[record.annotation.task].add(record.annotation.class_)
-                if record.annotation.type_ == "keypoints":
-                    num_kpts_per_task[record.annotation.task] = len(
-                        record.annotation.keypoints
+        with ParquetFileManager(annotations_dir) as pfm:
+            for i, data in enumerate(generator, start=1):
+                record = (
+                    data if isinstance(data, DatasetRecord) else DatasetRecord(**data)
+                )
+                if record.annotation is not None:
+                    classes_per_task[record.annotation.task].add(
+                        record.annotation.class_
                     )
+                    if record.annotation.type_ == "keypoints":
+                        num_kpts_per_task[record.annotation.task] = len(
+                            record.annotation.keypoints
+                        )
 
-            batch_data.append(record)
-            if i % batch_size == 0:
-                _add_process_batch(batch_data)
-                batch_data = []
+                batch_data.append(record)
+                if i % batch_size == 0:
+                    self._add_process_batch(
+                        batch_data, pfm, index, new_index, processed_uuids
+                    )
+                    batch_data = []
 
-        _add_process_batch(batch_data)
+            self._add_process_batch(batch_data, pfm, index, new_index, processed_uuids)
 
         _, curr_classes = self.get_classes()
         for task, classes in classes_per_task.items():
@@ -490,8 +498,6 @@ def _add_process_batch(batch_data: List[DatasetRecord]) -> None:
             self.logger.info(f"Detected new classes for task {task}: {new_classes}")
             self.set_classes(list(classes | old_classes), task, _remove_tmp_dir=False)
 
-        self.pfm.close()
-
         if self.bucket_storage == BucketStorage.LOCAL:
             self._write_index(index, new_index)
         else:
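Beyond the pandas-to-polars switch, two changes in the add() path matter for the speedup and for safety: duplicate checks now hit the processed_uuids set (O(1) membership) instead of scanning the ever-growing new_index["uuid"] list (O(n) per record), and ParquetFileManager is used as a context manager, so the writer is closed even if the generator raises and the explicit self.pfm.close() can go away. A generic sketch of the membership-check difference on toy data (standard library only, not the luxonis_ml API):

import time
import uuid

uuids = [uuid.uuid4().hex for _ in range(50_000)]

# List membership: every check scans the list, so the whole loop is quadratic.
seen_list = []
t0 = time.perf_counter()
for u in uuids:
    if u not in seen_list:
        seen_list.append(u)
print("list:", time.perf_counter() - t0)

# Set membership: every check is a hash lookup, so the loop stays linear.
seen_set = set()
t0 = time.perf_counter()
for u in uuids:
    if u not in seen_set:
        seen_set.add(u)
print("set:", time.perf_counter() - t0)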
@@ -516,7 +522,7 @@ def make_splits(
 
         df = self._load_df_offline()
         assert df is not None
-        ids = list(set(df["uuid"]))
+        ids = df.select("uuid").unique().get_column("uuid").to_list()
         np.random.shuffle(ids)
         N = len(ids)
         b1 = round(N * ratios[0])
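The make_splits change keeps UUID deduplication inside Polars and only hands a plain Python list to numpy for shuffling. A short sketch of the same chain with made-up data and ratios (the real ratios come from the method's arguments):

import numpy as np
import polars as pl

df = pl.DataFrame({"uuid": ["a", "b", "b", "c"], "class": ["cat", "dog", "dog", "cat"]})
ratios = (0.5, 0.25, 0.25)  # assumed train/val/test split, for illustration only

ids = df.select("uuid").unique().get_column("uuid").to_list()
np.random.shuffle(ids)

N = len(ids)
b1 = round(N * ratios[0])
b2 = b1 + round(N * ratios[1])
train_ids, val_ids, test_ids = ids[:b1], ids[b1:b2], ids[b2:]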
