Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 14, 2024
1 parent 5c80cc8 commit 12a4719
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
16 changes: 9 additions & 7 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def validate_data_split(

if batch_size is not None:
num_of_cells = n_train % batch_size
if ((num_of_cells < 3 and num_of_cells > 0) and
not (num_of_cells == 1 and drop_last is True)):
if (num_of_cells < 3 and num_of_cells > 0) and not (
num_of_cells == 1 and drop_last is True
):
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
Expand Down Expand Up @@ -166,8 +167,9 @@ def validate_data_split_with_external_indexing(

if batch_size is not None:
num_of_cells = n_train % batch_size
if ((num_of_cells < 3 and num_of_cells > 0)
and not (num_of_cells == 1 and drop_last is True)):
if (num_of_cells < 3 and num_of_cells > 0) and not (
num_of_cells == 1 and drop_last is True
):
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
Expand Down Expand Up @@ -263,7 +265,7 @@ def __init__(
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none
self.train_size_is_none,
)

def setup(self, stage: str | None = None):
Expand Down Expand Up @@ -446,7 +448,7 @@ def setup(self, stage: str | None = None):
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none
self.train_size_is_none,
)

labeled_permutation = self._labeled_indices
Expand Down Expand Up @@ -487,7 +489,7 @@ def setup(self, stage: str | None = None):
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none
self.train_size_is_none,
)

unlabeled_permutation = self._unlabeled_indices
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ def __init__(
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none
self.train_size_is_none,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none
self.train_size_is_none,
)
else:
# we need to intersect the external indexing given with the bg/target indices
Expand Down
4 changes: 3 additions & 1 deletion tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,9 @@ def test_scvi_n_obs_error(n_latent: int = 5):
model = SCVI(adata, n_latent=n_latent)
with pytest.raises(ValueError):
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True}) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(
1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(1)
assert model.is_trained is True

Expand Down

0 comments on commit 12a4719

Please sign in to comment.