Hey everyone,
We're working with a dataset where items belong to different clusters.
Example:
item_1, cluster_A
item_2, cluster_A
item_3, cluster_B
I was wondering if there's a way to create batches that have a guaranteed mix of items from these clusters? For example, to ensure every single batch is made up of 50% items from cluster A and 50% from cluster B.
With PyTorch's DataLoader and in-memory datasets, we handle this with a custom Sampler that is passed to the DataLoader (example code). The basic idea is to collect the indices for each cluster, shuffle each list, and then build batches by picking one index from each cluster's list in round-robin fashion. Because the shorter lists can be cycled, this approach upsamples the minority clusters, so batches stay balanced even when the original data is not.
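To make the idea concrete, here is a minimal pure-Python sketch of that kind of sampler. It is not the code linked above: the class name `BalancedBatchSampler` and the parameters `labels`, `per_cluster`, and `seed` are made up for illustration, and it assumes the cluster label of every item is available up front as an in-memory list.

```python
import random
from collections import defaultdict


class BalancedBatchSampler:
    """Yield batches of dataset indices with equal counts per cluster.

    Hypothetical helper for illustration; not part of any library.
    """

    def __init__(self, labels, per_cluster, seed=None):
        self.per_cluster = per_cluster
        self.rng = random.Random(seed)
        # Group dataset indices by their cluster label.
        self.by_cluster = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_cluster[label].append(idx)
        # One epoch walks the largest cluster once; smaller ones repeat.
        largest = max(len(v) for v in self.by_cluster.values())
        self.num_batches = -(-largest // per_cluster)  # ceiling division

    def _endless(self, indices):
        # Reshuffle and repeat forever, which upsamples minority clusters.
        while True:
            shuffled = indices[:]
            self.rng.shuffle(shuffled)
            yield from shuffled

    def __iter__(self):
        streams = [self._endless(v) for v in self.by_cluster.values()]
        for _ in range(self.num_batches):
            batch = []
            # Round-robin: take one index from each cluster in turn.
            for _ in range(self.per_cluster):
                for stream in streams:
                    batch.append(next(stream))
            yield batch

    def __len__(self):
        return self.num_batches


# Toy data: five items in cluster A, one item in cluster B.
labels = ["A", "A", "B", "A", "A", "A"]
sampler = BalancedBatchSampler(labels, per_cluster=2, seed=0)
for batch in sampler:
    print(batch, [labels[i] for i in batch])
```

Since the object is just an iterable of index lists with a `__len__`, it should be usable as the `batch_sampler` argument of `torch.utils.data.DataLoader`. Note that a minority cluster smaller than `per_cluster` will contribute the same index more than once within a batch.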
Is there any way to do something like this? Any guidance would be much appreciated, and I'd be happy to help contribute if pointed in the right direction.
Thanks