File tree Expand file tree Collapse file tree 1 file changed +12
-6
lines changed Expand file tree Collapse file tree 1 file changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -128,17 +128,23 @@ def __getitem__(self, idx: int) -> LuxonisLoaderOutput:
128128
129129 indices = [idx ]
130130 if self .augmentations .is_batched :
131- other_indices = [i for i in range (len (self )) if i != idx ]
132131 if self .augmentations .aug_batch_size > len (self ):
133132 warnings .warn (
134133 f"Augmentations batch_size ({ self .augmentations .aug_batch_size } ) is larger than dataset size ({ len (self )} ), samples will include repetitions."
135134 )
136- random_fun = random .choices
135+ other_indices = [i for i in range (len (self )) if i != idx ]
136+ picked_indices = random .choices (
137+ other_indices , k = self .augmentations .aug_batch_size - 1
138+ )
137139 else :
138- random_fun = random .sample
139- picked_indices = random_fun (
140- other_indices , k = self .augmentations .aug_batch_size - 1
141- )
140+ picked_indices = set ()
141+ max_val = len (self )
142+ while len (picked_indices ) < self .augmentations .aug_batch_size - 1 :
143+ rand_idx = random .randint (0 , max_val - 1 )
144+ if rand_idx != idx and rand_idx not in picked_indices :
145+ picked_indices .add (rand_idx )
146+ picked_indices = list (picked_indices )
147+
142148 indices .extend (picked_indices )
143149
144150 out_dict : Dict [str , Tuple [np .ndarray , LabelType ]] = {}
You can’t perform that action at this time.
0 commit comments