diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 92235ea89..07c42d389 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -45,7 +45,12 @@ def __init__(self, dimension: int, kvstore=None, bank_id: str = None): self.chunk_by_index = {} self.kvstore = kvstore self.bank_id = bank_id - self.initialize() + + @classmethod + async def create(cls, dimension: int, kvstore=None, bank_id: str = None): + instance = cls(dimension, kvstore, bank_id) + await instance.initialize() + return instance async def initialize(self) -> None: if not self.kvstore: @@ -132,7 +137,10 @@ async def initialize(self) -> None: for bank_data in stored_banks: bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore) + bank=bank, + index=await FaissIndex.create( + ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier + ), ) self.cache[bank.identifier] = index @@ -158,7 +166,9 @@ async def register_memory_bank( # Store in cache index = BankWithIndex( bank=memory_bank, - index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION, self.kvstore), + index=await FaissIndex.create( + ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier + ), ) self.cache[memory_bank.identifier] = index