
Commit 1bc5805
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 15, 2024
1 parent 2baf11d commit 1bc5805
Showing 4 changed files with 11 additions and 29 deletions.
16 changes: 5 additions & 11 deletions src/scvi/external/decipher/_components.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from collections.abc import Sequence

import numpy as np
import torch
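
Note on the hunk above: `typing.Sequence` has been a deprecated alias since Python 3.9, and `collections.abc.Sequence` is the canonical import, usable both as a runtime ABC and as a subscriptable annotation. A minimal sketch (the function here is hypothetical, not part of this module):

```python
from collections.abc import Sequence


def head(xs: Sequence[int]) -> int:
    # collections.abc.Sequence is subscriptable as a type hint
    # (Python >= 3.9) and still works for runtime isinstance checks.
    return xs[0]


assert head([1, 2, 3]) == 1
assert isinstance((1, 2, 3), Sequence)
```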
@@ -52,7 +52,7 @@ def __init__(

# The multiple outputs are computed as a single output layer, and then split
indices = np.concatenate(([0], np.cumsum(self.output_dims)))
-self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])]
+self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]

# Create masked layers
deep_context_dim = self.context_dim if self.deep_context_injection else 0
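
Note on the `zip(..., strict=False)` change above: Python 3.10 added the `strict` flag, and the autofixer (likely Ruff's B905 rule) asks for it to be explicit. `strict=False` keeps the old stop-at-the-shortest behavior, while `strict=True` would raise `ValueError` on a length mismatch. A self-contained sketch with illustrative dimensions:

```python
output_dims = [3, 5, 2]            # illustrative head sizes
indices = [0]
for d in output_dims:              # cumulative offsets: [0, 3, 8, 10]
    indices.append(indices[-1] + d)

# strict=False keeps the pre-3.10 behavior of stopping at the shorter
# iterable; here both arguments always have equal length, so strict=True
# would also pass and would additionally catch length bugs.
slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]
assert slices == [slice(0, 3), slice(3, 8), slice(8, 10)]
```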
@@ -63,21 +63,15 @@
batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
for i in range(1, len(hidden_dims)):
layers.append(
-torch.nn.Linear(
-    hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]
-)
+torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
)
batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

layers.append(
-torch.nn.Linear(
-    hidden_dims[-1] + deep_context_dim, self.output_total_dim
-)
+torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
)
else:
-layers.append(
-    torch.nn.Linear(input_dim + context_dim, self.output_total_dim)
-)
+layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

self.layers = torch.nn.ModuleList(layers)

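For context on the `output_slices` logic this file reformats: the module computes all output heads through a single `Linear` layer and then slices the result apart. A hedged, self-contained sketch (names and sizes are illustrative, not the scvi-tools API):

```python
import torch

output_dims = [3, 5]                           # e.g. two output heads
indices = [0, 3, 8]                            # cumulative offsets of the heads
output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]

layer = torch.nn.Linear(16, sum(output_dims))  # one layer computes all heads
out = layer(torch.randn(4, 16))                # shape (batch, 8)
heads = [out[..., sl] for sl in output_slices] # split back into per-head tensors
assert [h.shape[-1] for h in heads] == output_dims
```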
8 changes: 2 additions & 6 deletions src/scvi/external/decipher/_model.py
@@ -62,9 +62,7 @@ def setup_anndata(
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
]
-adata_manager = AnnDataManager(
-    fields=anndata_fields, setup_method_args=setup_method_args
-)
+adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@@ -113,9 +111,7 @@ def get_latent_representation(
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

-scdl = self._make_data_loader(
-    adata=adata, indices=indices, batch_size=batch_size
-)
+scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent_locs = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
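A hedged usage sketch of the two methods this file touches, following the standard scvi-tools workflow. The `Decipher` class name and import path are assumptions inferred from the file location, and the input file is hypothetical; `batch_size` and `layer` are grounded in the diffed signatures above:

```python
import scanpy as sc
from scvi.external import Decipher  # class name and path assumed from this diff

adata = sc.read_h5ad("pbmc.h5ad")              # hypothetical input file
Decipher.setup_anndata(adata, layer="counts")  # registers counts via LayerField
model = Decipher(adata)
model.train()
latent = model.get_latent_representation(adata, batch_size=128)
```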
12 changes: 3 additions & 9 deletions src/scvi/external/decipher/_module.py
@@ -82,9 +82,7 @@ def device(self):
return self._dummy_param.device

@staticmethod
-def _get_fn_args_from_batch(
-    tensor_dict: dict[str, torch.Tensor]
-) -> Iterable | dict:
+def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
x = tensor_dict[REGISTRY_KEYS.X_KEY]
return (x,), {}

@@ -125,9 +123,7 @@ def model(self, x: torch.Tensor):
self.theta + self._epsilon
)
# noinspection PyUnresolvedReferences
-x_dist = dist.NegativeBinomial(
-    total_count=self.theta + self._epsilon, logits=logit
-)
+x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
pyro.sample("x", x_dist.to_event(1), obs=x)
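
A minimal sketch of the `NegativeBinomial` parameterization in the hunk above, with illustrative shapes; `theta` plays the role of `self.theta`, and the small epsilon guards against a zero total count:

```python
import torch
import pyro.distributions as dist

theta = torch.ones(10)        # per-gene inverse dispersion (self.theta's role)
logit = torch.zeros(2, 10)    # decoder output: log-odds per cell and gene
x_dist = dist.NegativeBinomial(total_count=theta + 1e-6, logits=logit)

# to_event(1) folds the gene axis into a single event, so log_prob
# sums over genes and returns one value per cell.
x = torch.zeros(2, 10)
assert x_dist.to_event(1).log_prob(x).shape == (2,)
```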

@auto_move_data
Expand Down Expand Up @@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(x)
-log_weights.append(
-    model_trace.log_prob_sum() - guide_trace.log_prob_sum()
-)
+log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

finally:
self.beta = old_beta
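For context on `predictive_log_likelihood`: each `model_trace.log_prob_sum() - guide_trace.log_prob_sum()` is one log importance weight, log w_i = log p(x, z_i) - log q(z_i), and such weights are typically combined as log p(x) ≈ logsumexp_i(log w_i) - log n. A hedged sketch of that final step (the combine itself is outside this diff):

```python
import math

import torch

# three hypothetical log importance weights, one per guide sample
log_weights = [torch.tensor(-10.2), torch.tensor(-9.7), torch.tensor(-10.0)]

# log p(x) ~= logsumexp_i(log w_i) - log(n)
log_px = torch.logsumexp(torch.stack(log_weights), dim=0) - math.log(len(log_weights))
print(float(log_px))  # Monte Carlo estimate of the predictive log likelihood
```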
4 changes: 1 addition & 3 deletions src/scvi/external/decipher/_trainingplan.py
@@ -41,9 +41,7 @@ def __init__(
optim_kwargs.update({"lr": 5e-3})
if "weight_decay" not in optim_kwargs.keys():
optim_kwargs.update({"weight_decay": 1e-4})
-self.optim = (
-    pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
-)
+self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
# We let SVI take care of all optimization
self.automatic_optimization = False
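
A note on the optimizer default above: `pyro.optim.ClippedAdam` wraps Adam with gradient-norm clipping and optional per-step learning-rate decay, and the training plan steps it itself through SVI (hence `automatic_optimization = False`). A minimal sketch of the same construction:

```python
import pyro

# Same call as in the diff: Adam plus gradient clipping (clip_norm)
# and optional learning-rate decay (lrd), both left at their defaults here.
optim_kwargs = {"lr": 5e-3, "weight_decay": 1e-4}
optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs)
```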

