Skip to content

[ENH]: Plumb prefix for spann and hnsw segment #4753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: 06-03-_enh_plumb_prefix_path_all_the_way_to_the_bf_writer
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 67 additions & 23 deletions rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,30 @@ pub struct HnswIndexProvider {
pub write_mutex: AysncPartitionedMutex<IndexUuid>,
}

/// Handle carrying everything needed to flush a committed HNSW index
/// from local temporary storage to the blob store.
pub struct HnswIndexFlusher {
    // Provider that owns the storage client and the temporary storage path.
    pub provider: HnswIndexProvider,
    // Storage key prefix for this collection; empty for legacy collections.
    pub prefix_path: String,
    // Id of the index whose files will be uploaded.
    pub index_id: IndexUuid,
}

#[derive(Clone)]
pub struct HnswIndexRef {
pub inner: Arc<RwLock<HnswIndex>>,
pub inner: Arc<RwLock<DistributedHnswInner>>,
}

/// An HNSW index paired with the storage key prefix it is read from and
/// flushed to.
pub struct DistributedHnswInner {
    // The in-memory HNSW index.
    pub hnsw_index: HnswIndex,
    // Storage key prefix; empty for legacy collections (see `format_key`).
    pub prefix_path: String,
}

// NOTE: the copied diff interleaved the old field accessors (`.id`,
// `.dimensionality()` directly on the inner value) with the new ones going
// through `hnsw_index`; only the new accessors are kept here.
impl Debug for HnswIndexRef {
    /// Debug-format the ref by briefly locking the inner index to read its
    /// id and dimensionality. `finish_non_exhaustive` signals that not all
    /// fields are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HnswIndexRef")
            .field("id", &self.inner.read().hnsw_index.id)
            .field(
                "dimensionality",
                &self.inner.read().hnsw_index.dimensionality(),
            )
            .finish_non_exhaustive()
    }
}
Expand Down Expand Up @@ -99,10 +113,12 @@ impl Configurable<(HnswProviderConfig, Storage)> for HnswIndexProvider {
impl chroma_cache::Weighted for HnswIndexRef {
fn weight(&self) -> usize {
let index = self.inner.read();
if index.len() == 0 {
if index.hnsw_index.is_empty() {
return 1;
}
let bytes = index.len() * std::mem::size_of::<f32>() * index.dimensionality() as usize;
let bytes = index.hnsw_index.len()
* std::mem::size_of::<f32>()
* index.hnsw_index.dimensionality() as usize;
let as_mb = bytes / 1024 / 1024;
if as_mb == 0 {
1
Expand Down Expand Up @@ -135,7 +151,7 @@ impl HnswIndexProvider {
match self.cache.get(cache_key).await.ok().flatten() {
Some(index) => {
let index_with_lock = index.inner.read();
if index_with_lock.id == *index_id {
if index_with_lock.hnsw_index.id == *index_id {
// Clone is cheap because we are just cloning the Arc.
Some(index.clone())
} else {
Expand All @@ -146,9 +162,12 @@ impl HnswIndexProvider {
}
}

// TODO(rohitcp): Use HNSW_INDEX_S3_PREFIX.
fn format_key(&self, id: &IndexUuid, file: &str) -> String {
format!("hnsw/{}/{}", id, file)
fn format_key(&self, prefix_path: &str, id: &IndexUuid, file: &str) -> String {
// For legacy collections, prefix_path will be empty.
if prefix_path.is_empty() {
return format!("hnsw/{}/{}", id, file);
}
format!("{}/hnsw/{}/{}", prefix_path, id, file)
}

pub async fn fork(
Expand All @@ -158,6 +177,7 @@ impl HnswIndexProvider {
dimensionality: i32,
distance_function: DistanceFunction,
ef_search: usize,
prefix_path: &str,
) -> Result<HnswIndexRef, Box<HnswIndexProviderForkError>> {
// We take a lock here to synchronize concurrent forks of the same index.
// Otherwise, we could end up with a corrupted index since the filesystem
Expand All @@ -176,7 +196,7 @@ impl HnswIndexProvider {
}

match self
.load_hnsw_segment_into_directory(source_id, &new_storage_path)
.load_hnsw_segment_into_directory(source_id, &new_storage_path, prefix_path)
.await
{
Ok(_) => {}
Expand All @@ -203,7 +223,10 @@ impl HnswIndexProvider {
None => match HnswIndex::load(storage_path_str, &index_config, ef_search, new_id) {
Ok(index) => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
inner: Arc::new(RwLock::new(DistributedHnswInner {
hnsw_index: index,
prefix_path: prefix_path.to_string(),
})),
};
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
Expand Down Expand Up @@ -251,14 +274,15 @@ impl HnswIndexProvider {
&self,
source_id: &IndexUuid,
index_storage_path: &Path,
prefix_path: &str,
) -> Result<(), Box<HnswIndexProviderFileError>> {
// Fetch the files from storage and put them in the index storage path.
for file in FILES.iter() {
let s3_fetch_span =
tracing::trace_span!(parent: Span::current(), "Read bytes from s3", file = file);
let buf = s3_fetch_span
.in_scope(|| async {
let key = self.format_key(source_id, file);
let key = self.format_key(prefix_path, source_id, file);
tracing::info!("Loading hnsw index file: {} into directory", key);
let bytes_res = self
.storage
Expand Down Expand Up @@ -296,6 +320,7 @@ impl HnswIndexProvider {
dimensionality: i32,
distance_function: DistanceFunction,
ef_search: usize,
prefix_path: &str,
) -> Result<HnswIndexRef, Box<HnswIndexProviderOpenError>> {
// This is the double checked locking pattern. This avoids taking the
// async mutex in the common case where the index is already in the cache.
Expand All @@ -320,7 +345,7 @@ impl HnswIndexProvider {
}

match self
.load_hnsw_segment_into_directory(id, &index_storage_path)
.load_hnsw_segment_into_directory(id, &index_storage_path, prefix_path)
.await
{
Ok(_) => {}
Expand All @@ -347,7 +372,10 @@ impl HnswIndexProvider {
None => match HnswIndex::load(index_storage_path_str, &index_config, ef_search, *id) {
Ok(index) => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
inner: Arc::new(RwLock::new(DistributedHnswInner {
hnsw_index: index,
prefix_path: prefix_path.to_string(),
})),
};
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
Expand Down Expand Up @@ -379,6 +407,7 @@ impl HnswIndexProvider {
// Cases
// A query comes in and the index is in the cache -> we can query the index based on segment files id (Same as compactor case 3 where we have the index)
// A query comes in and the index is not in the cache -> we need to load the index from s3 based on the segment files id
#[allow(clippy::too_many_arguments)]
pub async fn create(
&self,
cache_key: &CacheKey,
Expand All @@ -387,6 +416,7 @@ impl HnswIndexProvider {
ef_search: usize,
dimensionality: i32,
distance_function: DistanceFunction,
prefix_path: &str,
) -> Result<HnswIndexRef, Box<HnswIndexProviderCreateError>> {
let id = IndexUuid(Uuid::new_v4());
// We take a lock here to synchronize concurrent creates of the same index.
Expand Down Expand Up @@ -425,7 +455,10 @@ impl HnswIndexProvider {
Some(index) => Ok(index.clone()),
None => {
let index = HnswIndexRef {
inner: Arc::new(RwLock::new(index)),
inner: Arc::new(RwLock::new(DistributedHnswInner {
hnsw_index: index,
prefix_path: prefix_path.to_string(),
})),
};
self.cache.insert(*cache_key, index.clone()).await;
Ok(index)
Expand All @@ -434,7 +467,7 @@ impl HnswIndexProvider {
}

pub fn commit(&self, index: HnswIndexRef) -> Result<(), Box<dyn ChromaError>> {
match index.inner.write().save() {
match index.inner.write().hnsw_index.save() {
Ok(_) => {}
Err(e) => {
return Err(Box::new(HnswIndexProviderCommitError::HnswSaveError(e)));
Expand All @@ -444,11 +477,15 @@ impl HnswIndexProvider {
Ok(())
}

pub async fn flush(&self, id: &IndexUuid) -> Result<(), Box<HnswIndexProviderFlushError>> {
pub async fn flush(
&self,
prefix_path: &str,
id: &IndexUuid,
) -> Result<(), Box<HnswIndexProviderFlushError>> {
let index_storage_path = self.temporary_storage_path.join(id.to_string());
for file in FILES.iter() {
let file_path = index_storage_path.join(file);
let key = self.format_key(id, file);
let key = self.format_key(prefix_path, id, file);
let res = self
.storage
.put_file(
Expand Down Expand Up @@ -641,6 +678,7 @@ mod tests {
let dimensionality = 128;
let distance_function = DistanceFunction::Euclidean;
let default_hnsw_params = InternalHnswConfiguration::default();
let prefix_path = "";
let created_index = provider
.create(
&collection_id,
Expand All @@ -649,10 +687,11 @@ mod tests {
default_hnsw_params.ef_search,
dimensionality,
distance_function.clone(),
prefix_path,
)
.await
.unwrap();
let created_index_id = created_index.inner.read().id;
let created_index_id = created_index.inner.read().hnsw_index.id;

let forked_index = provider
.fork(
Expand All @@ -661,10 +700,11 @@ mod tests {
dimensionality,
distance_function,
default_hnsw_params.ef_search,
prefix_path,
)
.await
.unwrap();
let forked_index_id = forked_index.inner.read().id;
let forked_index_id = forked_index.inner.read().hnsw_index.id;

assert_ne!(created_index_id, forked_index_id);
}
Expand All @@ -684,6 +724,7 @@ mod tests {
let dimensionality = 2;
let distance_function = DistanceFunction::Euclidean;
let default_hnsw_params = InternalHnswConfiguration::default();
let prefix_path = "";
let created_index = provider
.create(
&collection_id,
Expand All @@ -692,18 +733,20 @@ mod tests {
default_hnsw_params.ef_search,
dimensionality,
distance_function.clone(),
prefix_path,
)
.await
.unwrap();
created_index
.inner
.write()
.hnsw_index
.add(1, &[1.0, 3.0])
.expect("Expected to add");
let created_index_id = created_index.inner.read().id;
let created_index_id = created_index.inner.read().hnsw_index.id;
provider.commit(created_index).expect("Expected to commit");
provider
.flush(&created_index_id)
.flush(prefix_path, &created_index_id)
.await
.expect("Expected to flush");
// clear the cache.
Expand All @@ -719,10 +762,11 @@ mod tests {
dimensionality,
distance_function,
default_hnsw_params.ef_search,
prefix_path,
)
.await
.expect("Expect open to succeed");
let opened_index_id = open_index.inner.read().id;
let opened_index_id = open_index.inner.read().hnsw_index.id;

assert_eq!(opened_index_id, created_index_id);
check_purge_successful(storage_dir.clone()).await;
Expand Down
Loading
Loading