
feat(dataloaders): Custom dataloader registry support #2932


Merged
165 commits merged into main on May 13, 2025

Conversation

ori-kron-wis
Collaborator

No description provided.

@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.2 milestone Aug 7, 2024
@ori-kron-wis ori-kron-wis self-assigned this Aug 7, 2024
@ori-kron-wis ori-kron-wis linked an issue Aug 7, 2024 that may be closed by this pull request

codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 81.67614% with 129 lines in your changes missing coverage. Please review.

Project coverage is 80.16%. Comparing base (ced87df) to head (67caa96).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
src/scvi/model/base/_base_model.py 49.63% 69 Missing ⚠️
src/scvi/dataloaders/_custom_dataloders.py 91.59% 29 Missing ⚠️
src/scvi/model/base/_archesmixin.py 82.22% 8 Missing ⚠️
src/scvi/model/base/_training_mixin.py 76.47% 8 Missing ⚠️
src/scvi/model/_scanvi.py 86.66% 4 Missing ⚠️
src/scvi/model/base/_rnamixin.py 93.33% 4 Missing ⚠️
src/scvi/model/base/_vaemixin.py 77.77% 4 Missing ⚠️
src/scvi/data/_utils.py 57.14% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2932      +/-   ##
==========================================
+ Coverage   80.12%   80.16%   +0.04%     
==========================================
  Files         196      197       +1     
  Lines       17570    18156     +586     
==========================================
+ Hits        14078    14555     +477     
- Misses       3492     3601     +109     
Files with missing lines Coverage Δ
src/scvi/dataloaders/__init__.py 100.00% <100.00%> (ø)
src/scvi/dataloaders/_data_splitting.py 95.47% <ø> (ø)
src/scvi/model/_scvi.py 96.42% <100.00%> (+0.51%) ⬆️
src/scvi/model/base/_save_load.py 83.49% <100.00%> (+1.38%) ⬆️
src/scvi/train/_trainingplans.py 85.73% <100.00%> (+0.41%) ⬆️
src/scvi/data/_utils.py 85.00% <57.14%> (-1.13%) ⬇️
src/scvi/model/_scanvi.py 91.17% <86.66%> (-1.85%) ⬇️
src/scvi/model/base/_rnamixin.py 94.17% <93.33%> (-0.36%) ⬇️
src/scvi/model/base/_vaemixin.py 89.13% <77.77%> (+1.17%) ⬆️
src/scvi/model/base/_archesmixin.py 78.20% <82.22%> (+1.31%) ⬆️
... and 3 more

@marianogabitto
Contributor

Ori, this is not working for me. When I invoke in the notebook:

training_dataloader = (
    datamodule.on_before_batch_transfer(batch, None) for batch in datamodule.train_dataloader()
)

I get:

switching torch multiprocessing start method from "fork" to "spawn"

and then it errors out.

@marianogabitto
Contributor

marianogabitto commented Apr 25, 2025

Ori, all the examples I list below were run with the ".on_before_batch_transfer()" call removed, the way I posted before.

  1. When num_workers=0, I can train, but at the low speeds described below. When I set the number of workers, e.g. num_workers=4, 12, or 24, the trainer takes forever to initialize and is then even slower.

  2. Can you monitor your GPU usage with nvitop or nvtop? Here are my head-to-head comparisons.

  • TileDB from Cell Census. I believe this reads from S3, so it never actually copies data to disk.
    Training takes 120 sec/it. GPU activity is almost zero the whole time, except for moments when it spikes to 100%.

  • TileDB from an AnnData created from the query. This reads from a local disk directory.
    Training takes 11 sec/it. GPU activity is again almost zero, except for moments when it spikes to 100%.

  • The regular way of loading the AnnData into memory. Training takes 1.2 sec/it. GPU activity is at 40% the whole time.

This leads me to believe that we are not loading data into GPU memory fast enough (a minimal timing sketch follows this comment).

  3. I forgot to tell you: the TileDB representative sent me this as a reference. It differs from the way we run because they launch the processes themselves:
    https://github.com/single-cell-data/TileDB-SOMA-ML/tree/rw/cli/src/tiledbsoma_ml/cli#example-invocation
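
For reference, a minimal timing sketch (not part of the PR) of how the sec/it numbers above could be measured; the datamodule name and the batch count are illustrative:

import time

def seconds_per_iteration(dataloader, max_batches=20):
    """Average wall-clock seconds per batch over up to max_batches batches."""
    durations = []
    start = time.perf_counter()
    for i, _batch in enumerate(dataloader):
        now = time.perf_counter()
        durations.append(now - start)
        start = now
        if i + 1 >= max_batches:
            break
    return sum(durations) / len(durations) if durations else float("nan")

# Example (assumes a datamodule exposing train_dataloader(), as in the snippets above):
# print(seconds_per_iteration(datamodule.train_dataloader()))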

@ori-kron-wis
Collaborator Author

ori-kron-wis commented Apr 29, 2025

Hi @marianogabitto ,
Thanks.

  1. I made several changes, and I moved on_before_batch_transfer into the class; it is no longer part of the analysis code. So if you pulled the branch and reinstalled, running the same code as before will give errors.
    I have updated the tutorials (see there), sorry about this.

But I'm not following your code; can you share exactly what you are running so we can compare?

  2. Regarding the GPU behavior: I see the same thing.
    The data you use matters (the speed enhancement shows up with larger data, not smaller).

I don't think the GPU is underutilized; it's just that data loading is much slower with TileDB, as you said roughly 100 times slower in that sense. So while with AnnData the data loading takes about 1 s and we see almost continuous GPU use, with TileDB over S3 there is a roughly 100 s gap between the same bursts of GPU usage, so it mistakenly looks underused.

  3. num_workers controls multiprocessing loading and is a parameter of the torch DataLoader (see the generic sketch after this list). We know its benefit also depends on data size and that we do not always get what we expect from it; specifically, there is overhead in initializing and closing the workers.
    How do you use it? I will also try to check its speed in the custom dataloader context. In any case, we should run with whatever number benefits us most; it is not a magic setting that helps every time.

  4. I think the common practice is a single GPU running in a notebook. We need to make sure this works, and other scenarios will follow.
    Having said that, running with DDP from scripts versus from notebooks can be very different, and we need to test all possibilities. We might find that we need to run it as a script, like the reference you gave. Will check.

  5. I added SCANVI to the tutorials as well; some issues still exist in the prediction part for TileDB.
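
As a generic reference for the num_workers point above (item 3), a minimal torch DataLoader sketch; the toy dataset and the chosen numbers are illustrative, and this is not the scvi custom dataloader itself:

import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # Toy stand-in dataset; real workloads would use the custom dataloaders discussed here.
    dataset = TensorDataset(torch.randn(10_000, 200))
    loader = DataLoader(
        dataset,
        batch_size=128,
        num_workers=4,            # >0 spawns worker processes; startup/teardown adds overhead
        persistent_workers=True,  # keep workers alive between epochs (requires num_workers > 0)
        pin_memory=True,          # page-locked host memory for faster host-to-GPU copies
    )
    for (x,) in loader:
        pass  # a training step would consume the batch here

if __name__ == "__main__":  # guard needed when workers use the "spawn" start method
    main()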

@marianogabitto
Contributor

Ori,
I will be testing the updates in 12 hours. Sorry for the delay.
One more thing in the meantime: it would be great to expose the scvi AnnData DataLoader as an example of what is going on internally. The code below does not work because the BatchDistributedSampler does not output samples with the correct dimensions (in DDP), but if you can help me solve it, that would be great.

Code

import scvi
from scvi.dataloaders import DataSplitter

scvi.model.SCVI.setup_anndata(adata, batch_key="batch", categorical_covariate_keys=["cell_type", "donor"])
ad_manager = scvi.model.SCVI._get_most_recent_anndata_manager(adata, required=True)

model = scvi.model.SCVI(
    registry=ad_manager._registry,
    gene_likelihood="nb",
    encode_covariates=False,
)

ad_manager.adata = adata
dl = DataSplitter(
    ad_manager, train_size=0.9, pin_memory=True, num_workers=2, persistent_workers=True
)  # , prefetch_factor=2
dl.setup()

model.train(
    datamodule=dl,  # the DataSplitter serves as the datamodule here
    max_epochs=10,
    batch_size=128,
    train_size=0.9,
    early_stopping=False,
    accelerator="gpu",
    devices=-1,
    strategy="ddp_find_unused_parameters_true",
)

@ori-kron-wis
Collaborator Author

Hi @marianogabitto ,
Your code above should work in multi-GPU settings; just add distributed_sampler=True to the DataSplitter call.
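
For clarity, a sketch of the earlier DataSplitter call with that flag added; the other arguments are carried over from the snippet above:

dl = DataSplitter(
    ad_manager,
    train_size=0.9,
    distributed_sampler=True,  # use a distributed sampler so each DDP rank gets correctly shaped samples
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
)
dl.setup()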

Besides that, I made several other updates for this PR; the census/lamin custom dataloaders should now be working for scvi/scanvi/scarches/load/save/multi-GPU/covariates integration.

@ori-kron-wis ori-kron-wis merged commit c4cab3b into main May 13, 2025
17 of 18 checks passed
meeseeksmachine pushed a commit to meeseeksmachine/scvi-tools that referenced this pull request May 13, 2025
ori-kron-wis added a commit that referenced this pull request May 13, 2025
Backport PR #2932: Custom dataloader registry support (#3318)

Co-authored-by: Ori Kronfeld <[email protected]>
@ori-kron-wis ori-kron-wis changed the title Custom dataloader registry support feat(dataloaders): Custom dataloader registry support May 20, 2025
Labels
custom_dataloader PR 2932, on-merge: backport to 1.3.x
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fix custom dataloader registry
3 participants