From 12a4719f7cf0942db434d0b394f0c493107e4594 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:18:45 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/dataloaders/_data_splitting.py | 16 +++++++++------- .../contrastivevi/_contrastive_data_splitting.py | 4 ++-- tests/model/test_scvi.py | 4 +++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index e51ffdf264..85072d87f5 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -69,8 +69,9 @@ def validate_data_split( if batch_size is not None: num_of_cells = n_train % batch_size - if ((num_of_cells < 3 and num_of_cells > 0) and - not (num_of_cells == 1 and drop_last is True)): + if (num_of_cells < 3 and num_of_cells > 0) and not ( + num_of_cells == 1 and drop_last is True + ): warnings.warn( f"Last batch will have a small size of {num_of_cells} " f"samples. Consider changing settings.batch_size or batch_size in model.train " @@ -166,8 +167,9 @@ def validate_data_split_with_external_indexing( if batch_size is not None: num_of_cells = n_train % batch_size - if ((num_of_cells < 3 and num_of_cells > 0) - and not (num_of_cells == 1 and drop_last is True)): + if (num_of_cells < 3 and num_of_cells > 0) and not ( + num_of_cells == 1 and drop_last is True + ): warnings.warn( f"Last batch will have a small size of {num_of_cells} " f"samples. Consider changing settings.batch_size or batch_size in model.train " @@ -263,7 +265,7 @@ def __init__( self.validation_size, self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, - self.train_size_is_none + self.train_size_is_none, ) def setup(self, stage: str | None = None): @@ -446,7 +448,7 @@ def setup(self, stage: str | None = None): self.validation_size, self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, - self.train_size_is_none + self.train_size_is_none, ) labeled_permutation = self._labeled_indices @@ -487,7 +489,7 @@ def setup(self, stage: str | None = None): self.validation_size, self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, - self.train_size_is_none + self.train_size_is_none, ) unlabeled_permutation = self._unlabeled_indices diff --git a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py index e48d654b5f..aef13b8b6a 100644 --- a/src/scvi/external/contrastivevi/_contrastive_data_splitting.py +++ b/src/scvi/external/contrastivevi/_contrastive_data_splitting.py @@ -83,7 +83,7 @@ def __init__( self.validation_size, self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, - self.train_size_is_none + self.train_size_is_none, ) self.n_target_train, self.n_target_val = validate_data_split( self.n_target, @@ -91,7 +91,7 @@ def __init__( self.validation_size, self.data_loader_kwargs.pop("batch_size", settings.batch_size), self.drop_last, - self.train_size_is_none + self.train_size_is_none, ) else: # we need to intersect the external indexing given with the bg/target indices diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 7845b9398c..49fe18e531 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -488,7 +488,9 @@ def test_scvi_n_obs_error(n_latent: int = 5): model = SCVI(adata, n_latent=n_latent) with pytest.raises(ValueError): model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1 - model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True}) # np.ceil(n_cells * 0.9) % 128 == 1 + model.train( + 1, train_size=0.9, datasplitter_kwargs={"drop_last": True} + ) # np.ceil(n_cells * 0.9) % 128 == 1 model.train(1) assert model.is_trained is True