Supporting samplers for creating balanced batches #717

@karinazad

Hey everyone,

We're working with a dataset where items belong to different clusters.

Example:
item_1, cluster_A
item_2, cluster_A
item_3, cluster_B

I was wondering if there's a way to create batches that have a guaranteed mix of items from these clusters? For example, to ensure every single batch is made up of 50% items from cluster A and 50% from cluster B.

With a PyTorch DataLoader over an in-memory dataset, we handle this with a custom Sampler that is passed to the DataLoader (example code). The basic idea is to collect the list of indices for each cluster, shuffle each list, and then build batches by picking one index from each cluster's list in round-robin fashion. Because the shorter lists are reused once exhausted, this approach upsamples the minority clusters and produces balanced batches even when the original data is imbalanced.
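For concreteness, here is a minimal sketch of the kind of sampler I mean. It is not our exact code: the class name `BalancedBatchSampler`, its parameters, and the "one epoch covers the largest cluster once" rule are assumptions made for illustration. It is written in plain Python so it runs without PyTorch installed.

```python
import random
from collections import defaultdict
from itertools import cycle


class BalancedBatchSampler:
    """Hypothetical sampler: yields index batches with equal items per cluster."""

    def __init__(self, labels, batch_size, seed=0):
        # Group dataset indices by their cluster label.
        self.by_cluster = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_cluster[label].append(idx)

        n_clusters = len(self.by_cluster)
        if batch_size % n_clusters != 0:
            raise ValueError("batch_size must be a multiple of the number of clusters")

        self.batch_size = batch_size
        self.rng = random.Random(seed)

        # Assumption: one epoch is long enough to visit every item of the
        # largest cluster once; smaller clusters are repeated (upsampled).
        per_cluster = batch_size // n_clusters
        largest = max(len(indices) for indices in self.by_cluster.values())
        self.num_batches = -(-largest // per_cluster)  # ceiling division

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        # Shuffle each cluster's indices, then cycle over them endlessly so
        # that minority clusters wrap around instead of running out.
        streams = []
        for indices in self.by_cluster.values():
            shuffled = list(indices)
            self.rng.shuffle(shuffled)
            streams.append(cycle(shuffled))

        for _ in range(self.num_batches):
            batch = []
            while len(batch) < self.batch_size:
                for stream in streams:  # round-robin: one index per cluster
                    batch.append(next(stream))
            yield batch


# Toy data: six items in cluster A, two in cluster B.
labels = ["A", "A", "A", "A", "A", "A", "B", "B"]
sampler = BalancedBatchSampler(labels, batch_size=4)
batches = list(sampler)
```

With an in-memory dataset, an object like this can be passed as `DataLoader(dataset, batch_sampler=sampler)`, since `batch_sampler` accepts any iterable that yields lists of indices. The sketch relies on random access by index, which is the part I am unsure how to reproduce here.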

Is there any way to do something like this? Any guidance would be much appreciated, and I'd be happy to help contribute if pointed in the right direction.

Thanks
