20 changes: 19 additions & 1 deletion dynamo/external/celldancer/utilities.py
@@ -443,10 +443,28 @@ def export_velocity_to_dynamo(cellDancer_df,adata):
velocity_matrix = np.zeros(adata.shape)
adata_ds_zeros = pd.DataFrame(velocity_matrix, columns=adata.var.index, index=adata.obs.index)
celldancer_velocity_s_df = (adata_ds_zeros + pivoted).fillna(0)[adata.var.index]

adata.layers['velocity_S'] = scipy.sparse.csr_matrix(celldancer_velocity_s_df.values)

adata.var['use_for_dynamics'] = adata.var.index.isin(dancer_genes)
adata.var['use_for_transition'] = adata.var.index.isin(dancer_genes)
adata.uns['dynamics']={'filter_gene_mode': 'final',
't': None,
'group': None,
'X_data': None,
'X_fit_data': None,
'asspt_mRNA': 'ss',
'experiment_type': 'conventional',
'normalized': True,
'model': 'stochastic',
'est_method': 'gmm',
'has_splicing': True,
'has_labeling': False,
'splicing_labeling': False,
'has_protein': False,
'use_smoothed': True,
'NTR_vel': False,
'log_unnormalized': True,
'fraction_for_deg': False}
return(adata.copy())

def adata_to_raw(adata,save_path,gene_list=None):
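The new lines populate layers['velocity_S'] and the uns['dynamics'] metadata that dynamo's downstream velocity functions expect, so a cellDancer result can be handed straight to dynamo. A minimal usage sketch, assuming a cellDancer result table and the matching AnnData object are already loaded (the downstream dynamo calls are illustrative, not part of this diff):

import dynamo as dyn
from dynamo.external.celldancer.utilities import export_velocity_to_dynamo

# cellDancer_df: per-cell, per-gene velocity table produced by cellDancer
# adata: the AnnData object the velocities were estimated from
adata = export_velocity_to_dynamo(cellDancer_df, adata)

# downstream steps that read layers['velocity_S'] and uns['dynamics']
dyn.tl.reduceDimension(adata)
dyn.tl.cell_velocities(adata, basis="umap")
dyn.pl.streamline_plot(adata, basis="umap")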
5 changes: 5 additions & 0 deletions dynamo/external/deepvelo/__init__.py
@@ -0,0 +1,5 @@
from .train import *

from . import tool as tl
from . import plot as pl
from . import pipeline as pipe
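The package top level re-exports everything from .train and mirrors dynamo's short-alias convention for its tool, plot, and pipeline modules. A small sketch of the resulting import surface (which symbols each module actually provides is not visible in this diff):

from dynamo.external import deepvelo

# training entry points come from the star import of .train;
# the submodules are reachable under the dynamo-style aliases
tl, pl, pipe = deepvelo.tl, deepvelo.pl, deepvelo.pipe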
3 changes: 3 additions & 0 deletions dynamo/external/deepvelo/base/__init__.py
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
72 changes: 72 additions & 0 deletions dynamo/external/deepvelo/base/base_data_loader.py
@@ -0,0 +1,72 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
"""
Base class for all data loaders
"""

def __init__(
self,
dataset,
batch_size,
shuffle,
validation_split,
num_workers,
collate_fn=default_collate,
):
self.validation_split = validation_split
self.shuffle = shuffle

self.batch_idx = 0
self.n_samples = len(dataset)

self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

self.init_kwargs = {
"dataset": dataset,
"batch_size": batch_size,
"shuffle": self.shuffle,
"collate_fn": collate_fn,
"num_workers": num_workers,
}
super().__init__(sampler=self.sampler, **self.init_kwargs)

def _split_sampler(self, split):
if split == 0.0:
return None, None

idx_full = np.arange(self.n_samples)

np.random.seed(0)
np.random.shuffle(idx_full)

if isinstance(split, int):
assert split > 0
assert (
split < self.n_samples
), "validation set size is configured to be larger than entire dataset."
len_valid = split
else:
len_valid = int(self.n_samples * split)

valid_idx = idx_full[0:len_valid]
train_idx = np.delete(idx_full, np.arange(0, len_valid))

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# turn off shuffle option which is mutually exclusive with sampler
self.shuffle = False
self.n_samples = len(train_idx)

return train_sampler, valid_sampler

def split_validation(self):
if self.valid_sampler is None:
return None
else:
return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
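BaseDataLoader wraps torch's DataLoader and carves a validation subset out of the training indices with SubsetRandomSampler. A runnable sketch of the intended use, assuming the toy tensors below stand in for a real expression dataset (they are not part of this PR):

import torch
from torch.utils.data import TensorDataset
from dynamo.external.deepvelo.base import BaseDataLoader

# hypothetical toy dataset: 100 cells x 8 genes, inputs and targets
dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 8))

loader = BaseDataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    validation_split=0.2,  # hold out 20% of the samples for validation
    num_workers=0,
)
valid_loader = loader.split_validation()  # DataLoader over the held-out indices

for x, y in loader:  # iterates only over the remaining training samples
    pass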
26 changes: 26 additions & 0 deletions dynamo/external/deepvelo/base/base_model.py
@@ -0,0 +1,26 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
"""
Base class for all models
"""

@abstractmethod
def forward(self, *inputs):
"""
Forward pass logic
:return: Model output
"""
raise NotImplementedError

def __str__(self):
"""
Model prints with number of trainable parameters
"""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return super().__str__() + "\nTrainable parameters: {}".format(params)
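BaseModel only marks forward as abstract and appends a trainable-parameter count to nn.Module's repr. A sketch of a concrete subclass (the TinyVeloNet name and layer sizes are hypothetical):

import torch
import torch.nn as nn
from dynamo.external.deepvelo.base import BaseModel

class TinyVeloNet(BaseModel):
    # hypothetical two-layer MLP, just to show the subclass contract
    def __init__(self, n_genes=10, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_genes, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_genes),
        )

    def forward(self, x):
        return self.net(x)

model = TinyVeloNet()
print(model)  # repr ends with "Trainable parameters: <count>"
out = model(torch.randn(4, 10))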
242 changes: 242 additions & 0 deletions dynamo/external/deepvelo/base/base_trainer.py
@@ -0,0 +1,242 @@
import time
import torch
from abc import abstractmethod
from numpy import inf
from tqdm.auto import tqdm
from ..logger import TensorboardWriter


class BaseTrainer:
"""
Base class for all trainers
"""

def __init__(self, model, criterion, metric_ftns, optimizer, config):
self.config = config
self.logger = config.get_logger("trainer", config["trainer"]["verbosity"])

# setup GPU device if available, move model into configured device
self.device, device_ids = self._prepare_device(config["n_gpu"])
self.model = model.to(self.device)
if len(device_ids) > 1:
self.model = torch.nn.DataParallel(model, device_ids=device_ids)

self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer

cfg_trainer = config["trainer"]
self.epochs = cfg_trainer["epochs"]
self.save_period = cfg_trainer["save_period"]
self.monitor = cfg_trainer.get("monitor", "off")

# configuration to monitor model performance and save best
if self.monitor == "off":
self.mnt_mode = "off"
self.mnt_best = 0
else:
self.mnt_mode, self.mnt_metric = self.monitor.split()
assert self.mnt_mode in ["min", "max"]

self.mnt_best = inf if self.mnt_mode == "min" else -inf
self.early_stop = cfg_trainer.get("early_stop", inf)

self.start_epoch = 1

self.checkpoint_dir = config.save_dir

# setup visualization writer instance
self.writer = TensorboardWriter(
config.log_dir, self.logger, cfg_trainer["tensorboard"]
)

if config.resume is not None:
self._resume_checkpoint(config.resume)

@abstractmethod
def _train_epoch(self, epoch):
"""
Training logic for an epoch
:param epoch: Current epoch number
"""
raise NotImplementedError

def train(self, callback=None, callback_freq=1):
"""
Full training logic
"""
not_improved_count = 0
tik = time.time()
if "mle" in self.config["loss"]["type"]:
if self.config["arch"]["args"]["pred_unspliced"]:
self.candidate_states = torch.cat(
[
self.data_loader.dataset.Sx_sz,
self.data_loader.dataset.Ux_sz,
],
dim=1,
).to(self.device)
else:
self.candidate_states = self.data_loader.dataset.Sx_sz.to(self.device)

# Create progress bar for epochs
use_pbar = self.config["trainer"].get("use_progress_bar", True)
if use_pbar:
pbar = tqdm(range(self.start_epoch, self.epochs + 1),
desc="Training",
dynamic_ncols=True,
leave=True,
position=0)
else:
pbar = range(self.start_epoch, self.epochs + 1)

for epoch in pbar:
result = self._train_epoch(epoch)

# save logged informations into log dict
log = {"epoch": epoch, "time:": time.time() - tik}
log.update(result)
tik = time.time()

# Update progress bar with metrics or print to logger
if use_pbar:
postfix_dict = {k: f'{v:.4f}' if isinstance(v, float) else v
for k, v in log.items() if k not in ['epoch', 'time:']}
pbar.set_postfix(postfix_dict)
pbar.refresh()
else:
# print logged informations to the screen
for key, value in log.items():
self.logger.info(" {:15s}: {}".format(str(key), value))

if callback is not None:
if epoch % callback_freq == 0:
callback(epoch)

# evaluate model performance according to configured metric, save best checkpoint as model_best
best = False
if self.mnt_mode != "off":
try:
# check whether model performance improved or not, according to specified metric(mnt_metric)
improved = (
self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best
) or (
self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best
)
except KeyError:
self.logger.warning(
"Warning: Metric '{}' is not found. "
"Model performance monitoring is disabled.".format(
self.mnt_metric
)
)
self.mnt_mode = "off"
improved = False

if improved:
self.mnt_best = log[self.mnt_metric]
not_improved_count = 0
best = True
else:
not_improved_count += 1

if not_improved_count > self.early_stop:
if use_pbar:
pbar.close()
self.logger.info(
"Validation performance didn't improve for {} epochs. "
"Training stops.".format(self.early_stop)
)
break

if epoch % self.save_period == 0:
self._save_checkpoint(epoch, save_best=best)

if use_pbar:
pbar.close()

def train_with_epoch_callback(self, callback, freq):
self.train(callback, freq)

def _prepare_device(self, n_gpu_use):
"""
setup GPU device if available, move model into configured device
"""
n_gpu = torch.cuda.device_count()
if n_gpu_use > 0 and n_gpu == 0:
self.logger.warning(
"Warning: There's no GPU available on this machine,"
"training will be performed on CPU."
)
n_gpu_use = 0
if n_gpu_use > n_gpu:
self.logger.warning(
"Warning: The number of GPU's configured to use is {}, but only {} are available "
"on this machine.".format(n_gpu_use, n_gpu)
)
n_gpu_use = n_gpu
device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu")
list_ids = list(range(n_gpu_use))
return device, list_ids

def _save_checkpoint(self, epoch, save_best=False):
"""
Saving checkpoints
:param epoch: current epoch number
:param log: logging information of the epoch
:param save_best: if True, rename the saved checkpoint to 'model_best.pth'
"""
arch = type(self.model).__name__
state = {
"arch": arch,
"epoch": epoch,
"state_dict": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"monitor_best": self.mnt_best,
"config": self.config,
}
filename = str(self.checkpoint_dir / "checkpoint-epoch{}.pth".format(epoch))
torch.save(state, filename)
self.logger.info("Saving checkpoint: {} ...".format(filename))
if save_best:
best_path = str(self.checkpoint_dir / "model_best.pth")
torch.save(state, best_path)
self.logger.info("Saving current best: model_best.pth ...")

def _resume_checkpoint(self, resume_path):
"""
Resume from saved checkpoints
:param resume_path: Checkpoint path to be resumed
"""
resume_path = str(resume_path)
self.logger.info("Loading checkpoint: {} ...".format(resume_path))
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint["epoch"] + 1
self.mnt_best = checkpoint["monitor_best"]

# load architecture params from checkpoint.
if checkpoint["config"]["arch"] != self.config["arch"]:
self.logger.warning(
"Warning: Architecture configuration given in config file is different from that of "
"checkpoint. This may yield an exception while state_dict is being loaded."
)
self.model.load_state_dict(checkpoint["state_dict"])

# load optimizer state from checkpoint only when optimizer type is not changed.
if (
checkpoint["config"]["optimizer"]["type"]
!= self.config["optimizer"]["type"]
):
self.logger.warning(
"Warning: Optimizer type given in config file is different from that of checkpoint. "
"Optimizer parameters not being resumed."
)
else:
self.optimizer.load_state_dict(checkpoint["optimizer"])

self.logger.info(
"Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
)
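Concrete trainers implement _train_epoch and return a dict of metrics; train() merges that dict into the per-epoch log and uses it to drive the monitored metric, early stopping, and checkpointing. A sketch of the expected subclass contract (the class name, metric key, and the attached data_loader are assumptions; the config/ConfigParser plumbing needed to instantiate it is not shown in this diff):

from dynamo.external.deepvelo.base import BaseTrainer

class SketchTrainer(BaseTrainer):
    # a subclass is expected to attach self.data_loader in its own __init__,
    # since train() already reads self.data_loader for the "mle" loss branch
    def _train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        for x, target in self.data_loader:
            x, target = x.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(x), target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        # the returned keys become log entries; one of them can serve as
        # the mnt_metric named in config["trainer"]["monitor"]
        return {"loss": total_loss / max(len(self.data_loader), 1)}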
1 change: 1 addition & 0 deletions dynamo/external/deepvelo/data_loader/__init__.py
@@ -0,0 +1 @@
# Data loader module for DeepVelo