Skip to content

Commit 4753888

Browse files
klemen1999kozlov721
authored andcommitted
Speedup in picking batched aug indices (#183)
1 parent 88af159 commit 4753888

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

luxonis_ml/data/loaders/luxonis_loader.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,23 @@ def __getitem__(self, idx: int) -> LuxonisLoaderOutput:
128128

129129
indices = [idx]
130130
if self.augmentations.is_batched:
131-
other_indices = [i for i in range(len(self)) if i != idx]
132131
if self.augmentations.aug_batch_size > len(self):
133132
warnings.warn(
134133
f"Augmentations batch_size ({self.augmentations.aug_batch_size}) is larger than dataset size ({len(self)}), samples will include repetitions."
135134
)
136-
random_fun = random.choices
135+
other_indices = [i for i in range(len(self)) if i != idx]
136+
picked_indices = random.choices(
137+
other_indices, k=self.augmentations.aug_batch_size - 1
138+
)
137139
else:
138-
random_fun = random.sample
139-
picked_indices = random_fun(
140-
other_indices, k=self.augmentations.aug_batch_size - 1
141-
)
140+
picked_indices = set()
141+
max_val = len(self)
142+
while len(picked_indices) < self.augmentations.aug_batch_size - 1:
143+
rand_idx = random.randint(0, max_val - 1)
144+
if rand_idx != idx and rand_idx not in picked_indices:
145+
picked_indices.add(rand_idx)
146+
picked_indices = list(picked_indices)
147+
142148
indices.extend(picked_indices)
143149

144150
out_dict: Dict[str, Tuple[np.ndarray, LabelType]] = {}

0 commit comments

Comments
 (0)