diff --git a/chromadb/utils/batch_utils.py b/chromadb/utils/batch_utils.py index f3881bc6a03..9fa210cbcfb 100644 --- a/chromadb/utils/batch_utils.py +++ b/chromadb/utils/batch_utils.py @@ -1,3 +1,4 @@ +import random from typing import Optional, Tuple, List from chromadb.api import BaseAPI from chromadb.api.types import ( @@ -18,19 +19,22 @@ def create_batches( _batches: List[ Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]] ] = [] - if len(ids) > api.get_max_batch_size(): - # create split batches - for i in range(0, len(ids), api.get_max_batch_size()): + max_batch_size = api.get_max_batch_size() + offset = 0 + if len(ids) > max_batch_size: + while offset < len(ids): + batch_size = random.randint(1, max_batch_size) _batches.append( ( # type: ignore - ids[i : i + api.get_max_batch_size()], - embeddings[i : i + api.get_max_batch_size()] + ids[offset : offset + batch_size], + embeddings[offset : offset + batch_size] if embeddings else None, - metadatas[i : i + api.get_max_batch_size()] if metadatas else None, - documents[i : i + api.get_max_batch_size()] if documents else None, + metadatas[offset : offset + batch_size] if metadatas else None, + documents[offset : offset + batch_size] if documents else None, ) ) + offset += batch_size else: _batches.append((ids, embeddings, metadatas, documents)) # type: ignore return _batches