Skip to content

Commit 12a4719

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5c80cc8 commit 12a4719

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

src/scvi/dataloaders/_data_splitting.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ def validate_data_split(
6969

7070
if batch_size is not None:
7171
num_of_cells = n_train % batch_size
72-
if ((num_of_cells < 3 and num_of_cells > 0) and
73-
not (num_of_cells == 1 and drop_last is True)):
72+
if (num_of_cells < 3 and num_of_cells > 0) and not (
73+
num_of_cells == 1 and drop_last is True
74+
):
7475
warnings.warn(
7576
f"Last batch will have a small size of {num_of_cells} "
7677
f"samples. Consider changing settings.batch_size or batch_size in model.train "
@@ -166,8 +167,9 @@ def validate_data_split_with_external_indexing(
166167

167168
if batch_size is not None:
168169
num_of_cells = n_train % batch_size
169-
if ((num_of_cells < 3 and num_of_cells > 0)
170-
and not (num_of_cells == 1 and drop_last is True)):
170+
if (num_of_cells < 3 and num_of_cells > 0) and not (
171+
num_of_cells == 1 and drop_last is True
172+
):
171173
warnings.warn(
172174
f"Last batch will have a small size of {num_of_cells} "
173175
f"samples. Consider changing settings.batch_size or batch_size in model.train "
@@ -263,7 +265,7 @@ def __init__(
263265
self.validation_size,
264266
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
265267
self.drop_last,
266-
self.train_size_is_none
268+
self.train_size_is_none,
267269
)
268270

269271
def setup(self, stage: str | None = None):
@@ -446,7 +448,7 @@ def setup(self, stage: str | None = None):
446448
self.validation_size,
447449
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
448450
self.drop_last,
449-
self.train_size_is_none
451+
self.train_size_is_none,
450452
)
451453

452454
labeled_permutation = self._labeled_indices
@@ -487,7 +489,7 @@ def setup(self, stage: str | None = None):
487489
self.validation_size,
488490
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
489491
self.drop_last,
490-
self.train_size_is_none
492+
self.train_size_is_none,
491493
)
492494

493495
unlabeled_permutation = self._unlabeled_indices

src/scvi/external/contrastivevi/_contrastive_data_splitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,15 @@ def __init__(
8383
self.validation_size,
8484
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
8585
self.drop_last,
86-
self.train_size_is_none
86+
self.train_size_is_none,
8787
)
8888
self.n_target_train, self.n_target_val = validate_data_split(
8989
self.n_target,
9090
self.train_size,
9191
self.validation_size,
9292
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
9393
self.drop_last,
94-
self.train_size_is_none
94+
self.train_size_is_none,
9595
)
9696
else:
9797
# we need to intersect the external indexing given with the bg/target indices

tests/model/test_scvi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,9 @@ def test_scvi_n_obs_error(n_latent: int = 5):
488488
model = SCVI(adata, n_latent=n_latent)
489489
with pytest.raises(ValueError):
490490
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
491-
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True}) # np.ceil(n_cells * 0.9) % 128 == 1
491+
model.train(
492+
1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
493+
) # np.ceil(n_cells * 0.9) % 128 == 1
492494
model.train(1)
493495
assert model.is_trained is True
494496

0 commit comments

Comments
 (0)