Potential CUDA Context Pollution from GpuIndex: Interferes with Subsequent PyTorch CUDA Kernels (Shared Memory Allocation Failure) #4638

@sunstarchan

Description

Is your feature request related to a problem? Please describe.
No, this is a potential bug report/compatibility issue.

Describe the bug
When faiss.index_cpu_to_all_gpus() is used to create a GPU index and a search() is then performed, the operation appears to "pollute" the current CUDA context in a way that prevents subsequent PyTorch-based CUDA kernels from setting their required dynamic shared memory size. Specifically, this causes a RuntimeError in libraries like gsplat (a PyTorch extension for Gaussian Splatting rendering) when they call cudaFuncSetAttribute to raise the dynamic shared memory limit for their rasterization kernels.

The error occurs even after explicitly deleting the FAISS GPU index object (del index_gpu) and calling torch.cuda.empty_cache(). Lowering the shared memory request of the affected kernel (e.g., gsplat's tile_size=8) does not resolve it either, which suggests a persistent change to the CUDA context's function attributes.

This issue only manifests when FAISS GPU operations precede the conflicting kernel; running gsplat first works fine.

Steps To Reproduce

  1. Set up a Python environment with PyTorch (tested with 2.x), faiss-gpu (tested with 1.8.0), and gsplat (1.0+).
  2. Create and train a simple IVF index on GPU.
  3. Perform a search on the GPU index.
  4. Attempt to run a gsplat rasterization (or any PyTorch kernel requesting ~7168 bytes of dynamic shared memory).

Here's a minimal reproducible example (assuming you have gsplat installed; if not, replace the gsplat call with a custom PyTorch CUDA kernel that requests shared memory):

import faiss
import numpy as np
import torch
import gsplat  # If gsplat is unavailable, see the note at the end of this snippet

# Sample data (3D points, ~10k for quick repro)
n_points = 10000
d = 3
xyz = np.random.rand(n_points, d).astype('float32')

# Step 1: FAISS GPU setup and search
nlist = int(np.sqrt(n_points))
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
index.train(xyz)
index_gpu = faiss.index_cpu_to_all_gpus(index)
index_gpu.add(xyz)
dist, _ = index_gpu.search(xyz, k=6)  # This triggers the pollution

# Cleanup attempt (doesn't help)
del index_gpu
torch.cuda.empty_cache()

# Step 2: Now try gsplat (fails with shared memory error)
device = 'cuda'
H, W = 512, 512
means3D = torch.randn(100, 3, device=device)  # Dummy Gaussians
# ... (other dummy params for gsplat.rasterization)
try:
    rendered, _, _ = gsplat.rasterization(
        means3D=means3D,
        # ... fill in other required params (viewmatrix, etc.) with dummies
        image_height=H,
        image_width=W,
        tile_size=8,  # Even with reduced tile_size, it fails
    )
    print("gsplat succeeded")
except RuntimeError as e:
    if "shared memory" in str(e):
        print(f"FAILED: {e}")
    else:
        raise

# Note: If gsplat isn't available, test with a minimal PyTorch CUDA extension that calls
# cudaFuncSetAttribute for 7168 bytes of dynamic shared memory; it fails in the same way.
# A sketch of such a probe follows below.
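
If gsplat isn't installed, the probe below is a minimal sketch of that "custom CUDA extension" test (the names smem_probe, try_set_smem, and dummy_kernel are hypothetical, not part of FAISS or gsplat). It JIT-compiles a trivial kernel with torch.utils.cpp_extension.load_inline and calls cudaFuncSetAttribute with the same 7168-byte dynamic shared memory opt-in that gsplat requests; run it in place of Step 2, after the FAISS block above.

import torch
from torch.utils.cpp_extension import load_inline

# Forward declaration for the Python binding; the definition lives in the CUDA source.
cpp_src = "int try_set_smem(int nbytes);"

cuda_src = r"""
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out) {
    extern __shared__ float smem[];   // dynamic shared memory
    smem[threadIdx.x] = float(threadIdx.x);
    out[threadIdx.x] = smem[threadIdx.x];
}

int try_set_smem(int nbytes) {
    // Ask the runtime to let dummy_kernel use `nbytes` of dynamic shared memory.
    // Returns the raw cudaError_t value (0 == cudaSuccess).
    cudaError_t err = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes);
    return static_cast<int>(err);
}
"""

probe = load_inline(
    name="smem_probe",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=["try_set_smem"],
)

err = probe.try_set_smem(7168)  # same size gsplat requests
print("cudaFuncSetAttribute:", "ok" if err == 0 else f"failed with cudaError_t={err}")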

Expected behavior: Both FAISS search and gsplat rasterization succeed without errors.
Actual behavior: gsplat fails with:

RuntimeError: Failed to set maximum shared memory size (requested 7168 bytes), try lowering tile_size.

The error persists even with tile_size=4, which reduces the request to ~448 bytes.

Environment

  • FAISS version: 1.8.0 (installed via conda install faiss-gpu -c pytorch)
  • Python version: 3.10
  • PyTorch version: 2.1.0+cu118
  • CUDA version: 12.1 (driver 535.xx)
  • GPU: NVIDIA RTX 4090 (24GB VRAM; sharedMemPerBlock: 49152 bytes)
  • OS: Ubuntu 22.04
  • Other libraries: gsplat 1.0.0

Additional context

  • This seems related to FAISS GPU kernels calling cudaFuncSetAttribute with a conflicting cudaFuncAttributeMaxDynamicSharedMemorySize, which persists in the CUDA context after the FAISS object is deleted.
  • No issues found in FAISS GitHub for "shared memory" + "gsplat" or "context pollution". Similar problems have been reported in other libs (e.g., nerfstudio/gsplat#266), but not tied to FAISS.
  • Workaround: Run FAISS in a separate multiprocessing subprocess (with the spawn start method) to isolate CUDA contexts; a sketch follows after this list. This adds overhead for large datasets.
  • Happy to provide more details, full repro repo, or test patches!
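
For reference, here is a minimal sketch of that subprocess workaround, assuming the k-NN results are small enough to pass back through a multiprocessing queue (the helper names _faiss_knn_worker and faiss_knn_isolated are hypothetical):

import multiprocessing as mp
import numpy as np

def _faiss_knn_worker(xyz, k, queue):
    # Import FAISS only in the child, so its CUDA context never touches the parent.
    import faiss
    d = xyz.shape[1]
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, int(np.sqrt(len(xyz))))
    index.train(xyz)
    index_gpu = faiss.index_cpu_to_all_gpus(index)
    index_gpu.add(xyz)
    dist, idx = index_gpu.search(xyz, k)
    queue.put((dist, idx))

def faiss_knn_isolated(xyz, k=6):
    ctx = mp.get_context("spawn")       # spawn: the child builds its own CUDA context
    queue = ctx.Queue()
    proc = ctx.Process(target=_faiss_knn_worker, args=(xyz, k, queue))
    proc.start()
    dist, idx = queue.get()             # read results before join() to avoid queue deadlock
    proc.join()
    return dist, idx

if __name__ == "__main__":
    xyz = np.random.rand(10000, 3).astype("float32")
    dist, idx = faiss_knn_isolated(xyz, k=6)
    # gsplat.rasterization(...) now runs in the parent without the shared memory error.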
