Skip to content

Commit 68c50fd

Browse files
Refactor extvelo function to improve code clarity and functionality
- Adjusted function signature for extvelo to enhance readability. - Added missing comma in the parameters dictionary for consistency. - Included additional import for data preprocessing to support velovi method. - Ensured adata is copied before processing to maintain data integrity.
1 parent 6a22e13 commit 68c50fd

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

dynamo/tools/_enum.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,5 @@ def _format(cls, value) -> str:
7373
class ModeEnum(str, ErrorFormatterABC, PrettyEnum, metaclass=ABCEnumMeta): # noqa: D101
7474
def _generate_next_value_(self, start, count, last_values):
7575
return str(self).lower()
76+
77+

dynamo/tools/_extvelo.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def extvelo(
1515
latentvelo_VAE_kwargs: dict = {},
1616
param_name_key: str = 'tmp/latentvelo_params',
1717
**kwargs,
18-
) -> AnnData:
18+
):
1919
if method == "celldancer":
2020
#tested successfully
2121
from ..external.celldancer.utilities import adata_to_df_with_embed
@@ -41,7 +41,7 @@ def extvelo(
4141
'model': 'stochastic','est_method': 'gmm','has_splicing': True,
4242
'has_labeling': False,'splicing_labeling': False,
4343
'has_protein': False,'use_smoothed': True,'NTR_vel': False,
44-
'log_unnormalized': True,'fraction_for_deg': False
44+
'log_unnormalized': True,'fraction_for_deg': False,
4545
}
4646
return cellDancer_df,adata
4747
elif method == "latentvelo":
@@ -79,10 +79,11 @@ def extvelo(
7979
return adata
8080
elif method == "velovi":
8181
#Need to be tested
82-
from velovi import VELOVI
82+
from velovi import VELOVI,preprocess_data
8383
import torch
8484
import numpy as np
85-
import scipy as sp
85+
adata=adata.copy()
86+
adata = preprocess_data(adata,spliced_layer=Ms_key,unspliced_layer=Mu_key)
8687
VELOVI.setup_anndata(adata, spliced_layer=Ms_key, unspliced_layer=Mu_key)
8788
vae = VELOVI(adata)
8889
vae.train(**kwargs)

0 commit comments

Comments
 (0)