
Commit af58715

Merge pull request #731 from Starlitnightly/master
Added `celldancer`, `deepvelo`, and `latentvelo` to `dyn.tl.extvelo` (#721)
2 parents 27eb07f + 68c50fd commit af58715


44 files changed (+7642 / -18 lines)

dynamo/external/celldancer/utilities.py

Lines changed: 19 additions & 1 deletion
@@ -443,10 +443,28 @@ def export_velocity_to_dynamo(cellDancer_df,adata):
     velocity_matrix = np.zeros(adata.shape)
     adata_ds_zeros = pd.DataFrame(velocity_matrix, columns=adata.var.index, index=adata.obs.index)
     celldancer_velocity_s_df = (adata_ds_zeros + pivoted).fillna(0)[adata.var.index]
-
     adata.layers['velocity_S'] = scipy.sparse.csr_matrix(celldancer_velocity_s_df.values)
+
     adata.var['use_for_dynamics'] = adata.var.index.isin(dancer_genes)
     adata.var['use_for_transition'] = adata.var.index.isin(dancer_genes)
+    adata.uns['dynamics']={'filter_gene_mode': 'final',
+                           't': None,
+                           'group': None,
+                           'X_data': None,
+                           'X_fit_data': None,
+                           'asspt_mRNA': 'ss',
+                           'experiment_type': 'conventional',
+                           'normalized': True,
+                           'model': 'stochastic',
+                           'est_method': 'gmm',
+                           'has_splicing': True,
+                           'has_labeling': False,
+                           'splicing_labeling': False,
+                           'has_protein': False,
+                           'use_smoothed': True,
+                           'NTR_vel': False,
+                           'log_unnormalized': True,
+                           'fraction_for_deg': False}
     return(adata.copy())

 def adata_to_raw(adata,save_path,gene_list=None):
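
Not part of the diff above, but a hedged sketch of the contract it establishes: `export_velocity_to_dynamo` writes a `velocity_S` layer plus an `adata.uns['dynamics']` record that mimics dynamo's own estimation metadata, so downstream dynamo tooling accepts the externally computed velocities. The toy AnnData and the commented `dyn.tl.cell_velocities` / `dyn.vf.VectorField` calls below are assumptions about typical usage, not code from this commit.

import numpy as np
import scipy.sparse
import anndata as ad

# Toy AnnData standing in for the object returned by export_velocity_to_dynamo.
adata = ad.AnnData(X=np.random.rand(50, 10).astype(np.float32))
adata.layers['velocity_S'] = scipy.sparse.csr_matrix(np.random.randn(50, 10))
adata.var['use_for_dynamics'] = True
adata.var['use_for_transition'] = True
adata.uns['dynamics'] = {'filter_gene_mode': 'final', 'asspt_mRNA': 'ss',
                         'experiment_type': 'conventional', 'model': 'stochastic',
                         'est_method': 'gmm', 'has_splicing': True,
                         'has_labeling': False, 'use_smoothed': True,
                         'NTR_vel': False, 'log_unnormalized': True}

# Downstream usage (assumed; requires dynamo-release and an embedding such as UMAP):
# import dynamo as dyn
# dyn.tl.cell_velocities(adata)
# dyn.vf.VectorField(adata, basis='umap')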
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .train import *

from . import tool as tl
from . import plot as pl
from . import pipeline as pipe
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """

    def __init__(
        self,
        dataset,
        batch_size,
        shuffle,
        validation_split,
        num_workers,
        collate_fn=default_collate,
    ):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            "dataset": dataset,
            "batch_size": batch_size,
            "shuffle": self.shuffle,
            "collate_fn": collate_fn,
            "num_workers": num_workers,
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert (
                split < self.n_samples
            ), "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
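
A hedged usage sketch of `BaseDataLoader` (the toy `TensorDataset`, batch size, and split fraction are illustrative, not from this commit): `validation_split` may be a fraction or an absolute sample count, and `split_validation()` returns a second loader over the held-out indices.

import torch
from torch.utils.data import TensorDataset

# 100 toy samples with 8 features each and a binary label.
data = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

loader = BaseDataLoader(
    dataset=data,
    batch_size=16,
    shuffle=True,
    validation_split=0.2,  # hold out 20% of the samples
    num_workers=0,
)
valid_loader = loader.split_validation()

print(len(loader.sampler))        # 80 training indices
print(len(valid_loader.sampler))  # 20 validation indices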
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + "\nTrainable parameters: {}".format(params)
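
A minimal sketch of a concrete subclass (the two-layer MLP and its sizes are illustrative, not part of DeepVelo), showing the parameter count that `__str__` appends:

import torch.nn as nn

class TinyVeloNet(BaseModel):
    """Illustrative model only; the real DeepVelo architectures live elsewhere in this PR."""

    def __init__(self, n_genes=100, n_hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_genes, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_genes),
        )

    def forward(self, x):
        return self.net(x)

print(TinyVeloNet())  # module layout followed by "Trainable parameters: 6532"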
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
import time
import torch
from abc import abstractmethod
from numpy import inf
from tqdm.auto import tqdm
from ..logger import TensorboardWriter


class BaseTrainer:
    """
    Base class for all trainers
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger("trainer", config["trainer"]["verbosity"])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config["n_gpu"])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")

        # configuration to monitor model performance and save best
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(
            config.log_dir, self.logger, cfg_trainer["tensorboard"]
        )

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self, callback=None, callback_freq=1):
        """
        Full training logic
        """
        not_improved_count = 0
        tik = time.time()
        if "mle" in self.config["loss"]["type"]:
            if self.config["arch"]["args"]["pred_unspliced"]:
                self.candidate_states = torch.cat(
                    [
                        self.data_loader.dataset.Sx_sz,
                        self.data_loader.dataset.Ux_sz,
                    ],
                    dim=1,
                ).to(self.device)
            else:
                self.candidate_states = self.data_loader.dataset.Sx_sz.to(self.device)

        # Create progress bar for epochs
        use_pbar = self.config["trainer"].get("use_progress_bar", True)
        if use_pbar:
            pbar = tqdm(range(self.start_epoch, self.epochs + 1),
                        desc="Training",
                        dynamic_ncols=True,
                        leave=True,
                        position=0)
        else:
            pbar = range(self.start_epoch, self.epochs + 1)

        for epoch in pbar:
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {"epoch": epoch, "time:": time.time() - tik}
            log.update(result)
            tik = time.time()

            # Update progress bar with metrics or print to logger
            if use_pbar:
                postfix_dict = {k: f'{v:.4f}' if isinstance(v, float) else v
                                for k, v in log.items() if k not in ['epoch', 'time:']}
                pbar.set_postfix(postfix_dict)
                pbar.refresh()
            else:
                # print logged informations to the screen
                for key, value in log.items():
                    self.logger.info(" {:15s}: {}".format(str(key), value))

            if callback is not None:
                if epoch % callback_freq == 0:
                    callback(epoch)

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != "off":
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (
                        self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best
                    ) or (
                        self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best
                    )
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric
                        )
                    )
                    self.mnt_mode = "off"
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    if use_pbar:
                        pbar.close()
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop)
                    )
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

        if use_pbar:
            pbar.close()

    def train_with_epoch_callback(self, callback, freq):
        self.train(callback, freq)

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine,"
                "training will be performed on CPU."
            )
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU's configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu)
            )
            n_gpu_use = n_gpu
        device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu")
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename = str(self.checkpoint_dir / "checkpoint-epoch{}.pth".format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / "model_best.pth")
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]

        # load architecture params from checkpoint.
        if checkpoint["config"]["arch"] != self.config["arch"]:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint["state_dict"])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if (
            checkpoint["config"]["optimizer"]["type"]
            != self.config["optimizer"]["type"]
        ):
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed."
            )
        else:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )
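
A hedged sketch of the subclass contract that `BaseTrainer` expects (this `ToyTrainer` is hypothetical and much simpler than the DeepVelo `Trainer` added in this PR): `_train_epoch` runs one pass over `self.data_loader` and returns a dict of metrics, which `train()` merges into its per-epoch log and uses for checkpointing and early stopping.

class ToyTrainer(BaseTrainer):
    """Hypothetical minimal subclass; illustrates the _train_epoch contract only."""

    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.data_loader = data_loader  # train() reads this attribute for mle-style losses

    def _train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        for x, target in self.data_loader:
            x, target = x.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(x), target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        # Keys returned here become the metrics monitored by train().
        return {"loss": total_loss / len(self.data_loader)}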
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# Data loader module for DeepVelo
