@@ -57,15 +57,16 @@ def _class_id_for(self, split: str, class_name: str) -> int:
5757 return cmap [class_name ]
5858
5959 def _write_classes_csv (self , split : str , split_dir : Path ) -> None :
60- cmap = self .split_class_maps .get (split )
61- if cmap is None or len (cmap ) == 0 :
62- return
60+ self ._ensure_background (split )
6361
64- # Write in ascending id order so the CSV row order matches list indexing
65- items_by_id = sorted (cmap .items (), key = lambda kv : kv [1 ])
62+ cmap = self .split_class_maps [split ]
6663
64+ # Ensure directory exists
6765 csv_path = split_dir / "_classes.csv"
6866 csv_path .parent .mkdir (parents = True , exist_ok = True )
67+
68+ items_by_id = sorted (cmap .items (), key = lambda kv : kv [1 ])
69+
6970 with csv_path .open ("w" , newline = "" , encoding = "utf-8" ) as f :
7071 w = csv .writer (f )
7172 w .writerow ([self .ID_COL , self .CLASS_COL ])
@@ -84,6 +85,10 @@ def export(self, prepared_ldf: PreparedLDF) -> None:
8485
8586 copied_pairs : set [tuple [Path , str ]] = set ()
8687
88+ for split in ("train" , "val" , "test" ):
89+ split_dir = self ._get_data_path (self .output_path , split , self .part )
90+ split_dir .mkdir (parents = True , exist_ok = True )
91+
8792 for key , entry in grouped :
8893 file_name , group_id = cast (tuple [str , Any ], key )
8994 file_path = Path (str (file_name ))
0 commit comments