Description
Hello all, thanks for the nice library. I need help with the following problem:
CAGRA indices built with dtype="float16" save to cache successfully but fail to load from cache with the error "Unsupported dtype in file". The same code works correctly with dtype="float32".
Steps/Code to reproduce bug
Minimal reproduction (tested on cuVS 25.10.00):
import torch
from cuvs import neighbors as ann
# Create sample float16 dataset
dataset = torch.randn(1000, 128, dtype=torch.float16, device='cuda')
# Build and save CAGRA index
index_params = ann.cagra.IndexParams(metric="inner_product")
index = ann.cagra.build(index_params=index_params, dataset=dataset)
ann.cagra.save("/tmp/test_fp16.index", index, include_dataset=True)
# ✓ Saves successfully
# Try to load - THIS FAILS
loaded_index = ann.cagra.load("/tmp/test_fp16.index")
# ✗ Raises: RAFT failure - Unsupported dtype
Test results comparison:
| Operation | float16 | float32 |
|---|---|---|
| Build index | ✓ | ✓ |
| Save to cache | ✓ (0.49 MB) | ✓ (0.73 MB) |
| Load from cache | ✗ | ✓ |
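As a sanity check on the saved sizes, here is a rough back-of-the-envelope sketch. It assumes the file is dominated by the dataset plus a neighbor graph of 32-bit indices with a graph degree of 64 (an assumption on my part, not something verified against the serialization code), and it ignores any header bytes. The helper name `index_file_mib` is hypothetical.

```python
def index_file_mib(n_rows, dim, bytes_per_value, graph_degree=64):
    """Estimate the saved index size in MiB (dataset + neighbor graph only).

    graph_degree=64 and 4-byte graph entries are assumptions for this sketch;
    file headers and other metadata are ignored.
    """
    dataset_bytes = n_rows * dim * bytes_per_value
    graph_bytes = n_rows * graph_degree * 4  # assumed uint32 neighbor ids
    return (dataset_bytes + graph_bytes) / 2**20


# Same shape as the reproduction dataset: 1000 vectors of dimension 128.
fp16_mib = index_file_mib(1000, 128, bytes_per_value=2)
fp32_mib = index_file_mib(1000, 128, bytes_per_value=4)
print(f"float16: {fp16_mib:.2f} MiB, float32: {fp32_mib:.2f} MiB")
```

Under these assumptions the estimates land close to the sizes in the table, which would also explain why the float32 file is about 1.5x rather than 2x the float16 one: the graph part does not depend on the dataset dtype.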
Expected behavior
The cached CAGRA index should load successfully with float16 dtype, similar to float32 behavior. Both saving and loading should work consistently for all supported dtypes.
Environment details
- Environment location: Bare-metal (HPC cluster with NVIDIA Grace Hopper)
- Method of cuVS install: pip (via PyPI)
- cuVS version: 25.10.00
- CUDA version: 12.8
- Python version: 3.13.5
- Platform: Linux aarch64 (ARM64)
- PyTorch version: 2.7.1+cu128
Additional context
Error message:
RAFT failure at file=/__w/cuvs/cuvs/cpp/src/neighbors/cagra_c.cpp line=754:
Unsupported dtype in file /tmp/cagra_test/test_fp16.index
Stack trace:
Obtained 7 stack frames
#1 in libcuvs_c.so: raft::logic_error::logic_error(...) +0x78
#2 in libcuvs_c.so(+0x84064)
#3 in libcuvs_c.so: cuvsCagraDeserialize +0x1c [0x4001faeb474c]
#4 in cagra.cpython-312-aarch64-linux-gnu.so(+0x267f0)
#5 in all_neighbors.cpython-312-aarch64-linux-gnu.so(+0x13e9c)
#6 in resources.cpython-312-aarch64-linux-gnu.so(+0x17fd4)
#7 in python3: _PyObject_MakeTpCall +0xa0
The error originates from cuvsCagraDeserialize in cagra_c.cpp:754. I believe there may be a dtype check in deserialization that's missing float16 support.
For now, my workaround is to use float32, which works but increases the cache size by 49% (0.73 MB vs 0.49 MB) and doubles memory usage. Another option is to disable the cache entirely, but the index takes a long time to build.
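In case it helps anyone hitting the same error, here is a minimal sketch of a guard that keeps a float16 pipeline running until loading is fixed: try the cache first and rebuild if loading raises. The helper `load_or_rebuild` and its arguments are hypothetical names of mine, not part of cuVS, and I catch a broad `Exception` because I have not checked which exception type the Python bindings raise here. It does not avoid the build cost; it only prevents a hard failure.

```python
def load_or_rebuild(load_fn, build_fn):
    """Return (index, loaded_from_cache).

    load_fn and build_fn are caller-supplied zero-argument callables
    (hypothetical helper, not a cuVS API). Any exception from load_fn
    triggers a rebuild via build_fn.
    """
    try:
        return load_fn(), True
    except Exception as exc:  # broad on purpose: exact exception type unknown
        print(f"Cache load failed ({exc}); rebuilding index")
    return build_fn(), False


# Stand-in callables that simulate the failure described in this issue.
def failing_load():
    raise RuntimeError("Unsupported dtype in file")


index, from_cache = load_or_rebuild(failing_load, lambda: "rebuilt-index")
```

With the reproduction above, the callables could be `lambda: ann.cagra.load("/tmp/test_fp16.index")` and `lambda: ann.cagra.build(index_params=index_params, dataset=dataset)`.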