from typing import Dict, List, Optional, Set, Tuple

import numpy as np
-import pandas as pd
-import pyarrow as pa
+import polars as pl
import pyarrow.parquet as pq
import rich.progress

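Reviewer note (not part of the diff): pandas and the top-level pyarrow import are dropped; polars takes over the in-memory work, while pyarrow.parquet stays only for the final index write. A minimal round-trip sketch of that split, using a hypothetical file name:

```python
import polars as pl
import pyarrow.parquet as pq

# polars for in-memory frames, pyarrow.parquet only for the write, as in the new code.
df = pl.DataFrame({"uuid": ["u1", "u2"], "file": ["a.jpg", "b.jpg"]})
pq.write_table(df.to_arrow(), "example_index.parquet")  # hypothetical path
assert pl.read_parquet("example_index.parquet").shape == (2, 2)
```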
@@ -138,10 +137,7 @@ def __len__(self) -> int:
        """Returns the number of instances in the dataset."""

        df = self._load_df_offline(self.bucket_storage != BucketStorage.LOCAL)
-        if df is not None:
-            return len(set(df["uuid"]))
-        else:
-            return 0
+        return len(df.select("uuid").unique()) if df is not None else 0

    def _write_datasets(self) -> None:
        with open(self.datasets_cache_file, "w") as file:
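Reviewer note (not part of the diff): the polars one-liner counts distinct UUIDs by deduplicating the single-column frame. A small sketch with placeholder data; `Series.n_unique()` would be an equivalent, slightly more direct spelling:

```python
import polars as pl

df = pl.DataFrame({"uuid": ["u1", "u1", "u2"]})
assert len(df.select("uuid").unique()) == 2  # form used in the diff
assert df["uuid"].n_unique() == 2            # equivalent polars shorthand
```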
@@ -182,7 +178,7 @@ def _init_path(self) -> None:
                f"{self.team_id}/datasets/{self.dataset_name}"
            )

-    def _load_df_offline(self, sync_mode: bool = False) -> Optional[pd.DataFrame]:
+    def _load_df_offline(self, sync_mode: bool = False) -> Optional[pl.DataFrame]:
        dfs = []
        if self.bucket_storage == BucketStorage.LOCAL or sync_mode:
            annotations_path = self.annotations_path
@@ -192,32 +188,32 @@ def _load_df_offline(self, sync_mode: bool = False) -> Optional[pd.DataFrame]:
            return None
        for file in annotations_path.iterdir():
            if file.suffix == ".parquet":
-                dfs.append(pd.read_parquet(file))
-        if len(dfs):
-            return pd.concat(dfs)
+                dfs.append(pl.read_parquet(file))
+        if dfs:
+            return pl.concat(dfs)
        else:
            return None

    def _find_filepath_uuid(
        self,
        filepath: Path,
-        index: Optional[pd.DataFrame],
+        index: Optional[pl.DataFrame],
        *,
        raise_on_missing: bool = False,
    ) -> Optional[str]:
        if index is None:
            return None

        abs_path = str(filepath.absolute())
-        if abs_path in list(index["original_filepath"]):
-            matched = index[index["original_filepath"] == abs_path]
-            if len(matched):
-                return list(matched["uuid"])[0]
+        matched = index.filter(pl.col("original_filepath") == abs_path)
+
+        if len(matched):
+            return list(matched.select("uuid"))[0][0]
        elif raise_on_missing:
            raise ValueError(f"File {abs_path} not found in index")
        return None

-    def _get_file_index(self) -> Optional[pd.DataFrame]:
+    def _get_file_index(self) -> Optional[pl.DataFrame]:
        index = None
        if self.bucket_storage == BucketStorage.LOCAL:
            file_index_path = self.metadata_path / "file_index.parquet"
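Reviewer note (not part of the diff): `index.filter(pl.col(...) == value)` replaces the pandas boolean-mask lookup, and the membership pre-check becomes unnecessary because an empty filter result simply has length zero. A sketch with placeholder data; plain column indexing is an equivalent way to pull out the first matched UUID:

```python
import polars as pl

index = pl.DataFrame(
    {"uuid": ["u1", "u2"], "original_filepath": ["/data/a.jpg", "/data/b.jpg"]}
)
matched = index.filter(pl.col("original_filepath") == "/data/b.jpg")
if len(matched):
    # Equivalent to list(matched.select("uuid"))[0][0] used in the diff.
    assert matched["uuid"][0] == "u2"
```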
@@ -228,24 +224,25 @@ def _get_file_index(self) -> Optional[pd.DataFrame]:
            except Exception:
                pass
        if file_index_path.exists():
-            index = pd.read_parquet(file_index_path)
+            index = pl.read_parquet(file_index_path).select(
+                pl.all().exclude("^__index_level_.*$")
+            )
        return index

    def _write_index(
        self,
-        index: Optional[pd.DataFrame],
-        new_index: Dict,
+        index: Optional[pl.DataFrame],
+        new_index: Dict[str, List[str]],
        override_path: Optional[str] = None,
    ) -> None:
        if override_path:
            file_index_path = override_path
        else:
            file_index_path = self.metadata_path / "file_index.parquet"
-        df = pd.DataFrame(new_index)
+        df = pl.DataFrame(new_index)
        if index is not None:
-            df = pd.concat([index, df])
-        table = pa.Table.from_pandas(df)
-        pq.write_table(table, file_index_path)
+            df = pl.concat([index, df])
+        pq.write_table(df.to_arrow(), file_index_path)

    @contextmanager
    def _log_time(self):
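Reviewer note (not part of the diff): the `^__index_level_.*$` exclusion strips the implicit index column that pandas wrote into existing parquet files, which polars would otherwise surface as an ordinary column. On the write side, the pyarrow round-trip could also be shortened with polars' own writer, assuming no pyarrow-specific options are needed:

```python
import polars as pl

df = pl.DataFrame(
    {"uuid": ["u1"], "file": ["a.jpg"], "original_filepath": ["/data/a.jpg"]}
)

# Equivalent to pq.write_table(df.to_arrow(), path), assuming default parquet settings suffice.
df.write_parquet("file_index.parquet")

# Reading back while dropping any pandas-era index columns, as the diff does.
index = pl.read_parquet("file_index.parquet").select(
    pl.all().exclude("^__index_level_.*$")
)
```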
@@ -358,130 +355,141 @@ def delete_dataset(self) -> None:
        if self.bucket_storage == BucketStorage.LOCAL:
            shutil.rmtree(self.path)

-    def add(
-        self,
-        generator: DatasetIterator,
-        batch_size: int = 1_000_000,
-    ) -> None:
-        def _process_arrays(batch_data: List[DatasetRecord]) -> None:
-            array_paths = set(
-                ann.path for ann in batch_data if isinstance(ann, ArrayAnnotation)
+    def _process_arrays(self, batch_data: List[DatasetRecord]) -> None:
+        array_paths = set(
+            ann.path for ann in batch_data if isinstance(ann, ArrayAnnotation)
+        )
+        if array_paths:
+            task = self.progress.add_task(
+                "[magenta]Processing arrays...", total=len(batch_data)
            )
-            if array_paths:
-                task = self.progress.add_task(
-                    "[magenta]Processing arrays...", total=len(batch_data)
-                )
-                self.logger.info("Checking arrays...")
-                with self._log_time():
-                    data_utils.check_arrays(array_paths)
-                self.logger.info("Generating array UUIDs...")
-                with self._log_time():
-                    array_uuid_dict = self.fs.get_file_uuids(
-                        array_paths, local=True
-                    )  # TODO: support from bucket
-                if self.bucket_storage != BucketStorage.LOCAL:
-                    self.logger.info("Uploading arrays...")
-                    # TODO: support from bucket (likely with a self.fs.copy_dir)
-                    with self._log_time():
-                        arrays_upload_dict = self.fs.put_dir(
-                            local_paths=array_paths,
-                            remote_dir="arrays",
-                            uuid_dict=array_uuid_dict,
-                        )
-                self.logger.info("Finalizing paths...")
-                self.progress.start()
-                for ann in batch_data:
-                    if isinstance(ann, ArrayAnnotation):
-                        if self.bucket_storage != BucketStorage.LOCAL:
-                            remote_path = arrays_upload_dict[str(ann.path)]  # type: ignore
-                            remote_path = (
-                                f"{self.fs.protocol}://{self.fs.path / remote_path}"
-                            )
-                            ann.path = remote_path  # type: ignore
-                        else:
-                            ann.path = ann.path.absolute()
-                    self.progress.update(task, advance=1)
-                self.progress.stop()
-                self.progress.remove_task(task)
-
-        def _add_process_batch(batch_data: List[DatasetRecord]) -> None:
-            paths = list(set(data.file for data in batch_data))
-            self.logger.info("Generating UUIDs...")
+            self.logger.info("Checking arrays...")
+            with self._log_time():
+                data_utils.check_arrays(array_paths)
+            self.logger.info("Generating array UUIDs...")
            with self._log_time():
-                uuid_dict = self.fs.get_file_uuids(
-                    paths, local=True
+                array_uuid_dict = self.fs.get_file_uuids(
+                    array_paths, local=True
                )  # TODO: support from bucket
            if self.bucket_storage != BucketStorage.LOCAL:
-                self.logger.info("Uploading media...")
+                self.logger.info("Uploading arrays...")
                # TODO: support from bucket (likely with a self.fs.copy_dir)
-
                with self._log_time():
-                    self.fs.put_dir(
-                        local_paths=paths, remote_dir="media", uuid_dict=uuid_dict
+                    arrays_upload_dict = self.fs.put_dir(
+                        local_paths=array_paths,
+                        remote_dir="arrays",
+                        uuid_dict=array_uuid_dict,
                    )
+            self.logger.info("Finalizing paths...")
+            self.progress.start()
+            for ann in batch_data:
+                if isinstance(ann, ArrayAnnotation):
+                    if self.bucket_storage != BucketStorage.LOCAL:
+                        remote_path = arrays_upload_dict[str(ann.path)]  # type: ignore
+                        remote_path = (
+                            f"{self.fs.protocol}://{self.fs.path / remote_path}"
+                        )
+                        ann.path = remote_path  # type: ignore
+                    else:
+                        ann.path = ann.path.absolute()
+                    self.progress.update(task, advance=1)
+            self.progress.stop()
+            self.progress.remove_task(task)

-            task = self.progress.add_task(
-                "[magenta]Processing data...", total=len(batch_data)
-            )
-
-            _process_arrays(batch_data)
+    def _add_process_batch(
+        self,
+        batch_data: List[DatasetRecord],
+        pfm: ParquetFileManager,
+        index: Optional[pl.DataFrame],
+        new_index: Dict[str, List[str]],
+        processed_uuids: Set[str],
+    ) -> None:
+        paths = list(set(data.file for data in batch_data))
+        self.logger.info("Generating UUIDs...")
+        with self._log_time():
+            uuid_dict = self.fs.get_file_uuids(
+                paths, local=True
+            )  # TODO: support from bucket
+        if self.bucket_storage != BucketStorage.LOCAL:
+            self.logger.info("Uploading media...")
+            # TODO: support from bucket (likely with a self.fs.copy_dir)

-            self.logger.info("Saving annotations...")
            with self._log_time():
-                self.progress.start()
-                for ann in batch_data:
-                    filepath = ann.file
-                    file = filepath.name
-                    uuid = uuid_dict[str(filepath)]
-                    matched_id = self._find_filepath_uuid(filepath, index)
-                    if matched_id is not None:
-                        if matched_id != uuid:
-                            # TODO: not sure if this should be an exception or how we should really handle it
-                            raise Exception(
-                                f"{filepath} already added to the dataset! Please skip or rename the file."
-                            )
-                        # TODO: we may also want to check for duplicate uuids to get a one-to-one relationship
-                    elif uuid not in new_index["uuid"]:
-                        new_index["uuid"].append(uuid)
-                        new_index["file"].append(file)
-                        new_index["original_filepath"].append(str(filepath.absolute()))
-
-                    self.pfm.write({"uuid": uuid, **ann.to_parquet_dict()})
-                    self.progress.update(task, advance=1)
-                self.progress.stop()
-                self.progress.remove_task(task)
+                self.fs.put_dir(
+                    local_paths=paths, remote_dir="media", uuid_dict=uuid_dict
+                )

+        task = self.progress.add_task(
+            "[magenta]Processing data...", total=len(batch_data)
+        )
+
+        self._process_arrays(batch_data)
+
+        self.logger.info("Saving annotations...")
+        with self._log_time():
+            self.progress.start()
+            for ann in batch_data:
+                filepath = ann.file
+                file = filepath.name
+                uuid = uuid_dict[str(filepath)]
+                matched_id = self._find_filepath_uuid(filepath, index)
+                if matched_id is not None:
+                    if matched_id != uuid:
+                        # TODO: not sure if this should be an exception or how we should really handle it
+                        raise Exception(
+                            f"{filepath} already added to the dataset! Please skip or rename the file."
+                        )
+                    # TODO: we may also want to check for duplicate uuids to get a one-to-one relationship
+                elif uuid not in processed_uuids:
+                    new_index["uuid"].append(uuid)
+                    new_index["file"].append(file)
+                    new_index["original_filepath"].append(str(filepath.absolute()))
+                    processed_uuids.add(uuid)
+
+                pfm.write({"uuid": uuid, **ann.to_parquet_dict()})
+                self.progress.update(task, advance=1)
+            self.progress.stop()
+            self.progress.remove_task(task)
+
+    def add(self, generator: DatasetIterator, batch_size: int = 1_000_000) -> None:
        if self.bucket_storage == BucketStorage.LOCAL:
-            self.pfm = ParquetFileManager(str(self.annotations_path))
+            annotations_dir = self.annotations_path
        else:
            self._make_temp_dir()
            annotations_dir = self.tmp_dir / "annotations"
            annotations_dir.mkdir(exist_ok=True, parents=True)
-            self.pfm = ParquetFileManager(str(annotations_dir))

        index = self._get_file_index()
        new_index = {"uuid": [], "file": [], "original_filepath": []}
+        processed_uuids = set()

        batch_data: list[DatasetRecord] = []

        classes_per_task: Dict[str, Set[str]] = defaultdict(set)
        num_kpts_per_task: Dict[str, int] = {}

-        for i, data in enumerate(generator, start=1):
-            record = data if isinstance(data, DatasetRecord) else DatasetRecord(**data)
-            if record.annotation is not None:
-                classes_per_task[record.annotation.task].add(record.annotation.class_)
-                if record.annotation.type_ == "keypoints":
-                    num_kpts_per_task[record.annotation.task] = len(
-                        record.annotation.keypoints
+        with ParquetFileManager(annotations_dir) as pfm:
+            for i, data in enumerate(generator, start=1):
+                record = (
+                    data if isinstance(data, DatasetRecord) else DatasetRecord(**data)
+                )
+                if record.annotation is not None:
+                    classes_per_task[record.annotation.task].add(
+                        record.annotation.class_
                    )
+                    if record.annotation.type_ == "keypoints":
+                        num_kpts_per_task[record.annotation.task] = len(
+                            record.annotation.keypoints
+                        )

-            batch_data.append(record)
-            if i % batch_size == 0:
-                _add_process_batch(batch_data)
-                batch_data = []
+                batch_data.append(record)
+                if i % batch_size == 0:
+                    self._add_process_batch(
+                        batch_data, pfm, index, new_index, processed_uuids
+                    )
+                    batch_data = []

-        _add_process_batch(batch_data)
+            self._add_process_batch(batch_data, pfm, index, new_index, processed_uuids)

        _, curr_classes = self.get_classes()
        for task, classes in classes_per_task.items():
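Reviewer note (not part of the diff): besides promoting the nested closures to methods, this hunk opens the ParquetFileManager as a context manager (so the explicit `self.pfm.close()` below goes away) and tracks already-seen UUIDs in a `processed_uuids` set instead of scanning the `new_index["uuid"]` list on every record. A minimal sketch of that de-duplication pattern with hypothetical data:

```python
from typing import Dict, List, Set

new_index: Dict[str, List[str]] = {"uuid": [], "file": [], "original_filepath": []}
processed_uuids: Set[str] = set()

# Hypothetical (uuid, filepath) pairs standing in for one batch of records.
for uuid, filepath in [("u1", "/data/a.jpg"), ("u1", "/data/a.jpg"), ("u2", "/data/b.jpg")]:
    if uuid not in processed_uuids:  # O(1) set lookup instead of a list scan
        new_index["uuid"].append(uuid)
        new_index["file"].append(filepath.rsplit("/", 1)[-1])
        new_index["original_filepath"].append(filepath)
        processed_uuids.add(uuid)

assert new_index["uuid"] == ["u1", "u2"]
```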
@@ -490,8 +498,6 @@ def _add_process_batch(batch_data: List[DatasetRecord]) -> None:
            self.logger.info(f"Detected new classes for task {task}: {new_classes}")
            self.set_classes(list(classes | old_classes), task, _remove_tmp_dir=False)

-        self.pfm.close()
-
        if self.bucket_storage == BucketStorage.LOCAL:
            self._write_index(index, new_index)
        else:
@@ -516,7 +522,7 @@ def make_splits(

        df = self._load_df_offline()
        assert df is not None
-        ids = list(set(df["uuid"]))
+        ids = df.select("uuid").unique().get_column("uuid").to_list()
        np.random.shuffle(ids)
        N = len(ids)
        b1 = round(N * ratios[0])
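Reviewer note (not part of the diff): in the splits path the distinct UUIDs are materialized as a Python list so `np.random.shuffle` can permute them in place. A sketch with placeholder data; `df["uuid"].unique().to_list()` would be an equivalent shorthand:

```python
import numpy as np
import polars as pl

df = pl.DataFrame({"uuid": ["u1", "u2", "u2", "u3"]})
ids = df.select("uuid").unique().get_column("uuid").to_list()  # form used in the diff
np.random.shuffle(ids)
assert sorted(ids) == ["u1", "u2", "u3"]
```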