feat(models): add k-NN model #649

Draft: wants to merge 47 commits into base: beta
Changes from 29 commits
Commits (47)
b4c100f
added atac_layer argument to train_model tasks and made tests for it
AlexanderAivazidis Jul 8, 2024
c1ee5cb
fix(vscode): disable automatic python env activation
AlexanderAivazidis Jul 8, 2024
f8d04f3
corrected ATAC argument type from layer name to anndata object
AlexanderAivazidis Jul 10, 2024
c83e065
feat[PyroVelocity]: added atac_data to setup_anndata method
AlexanderAivazidis Jul 10, 2024
8661a36
fix[PyroVelocity]: missing colon in Args description
AlexanderAivazidis Jul 10, 2024
c643465
feat(VelocityTrainingMixin): Added atac data to train_faster method
AlexanderAivazidis Jul 10, 2024
b0ce0f9
feat(_velocity_module): Added a MultiVelocityModule for multiome data.
AlexanderAivazidis Jul 10, 2024
2abf5ec
feat(_velocity_model): Added a MultiVelocityModelAuto class for multi…
AlexanderAivazidis Jul 11, 2024
7b9c355
feat(_transcription_dynamics): Added function for multiome dynamics.
AlexanderAivazidis Jul 11, 2024
6955c46
feat(_test_transcription_dynamics): Unit tests for transcription dyna…
AlexanderAivazidis Jul 22, 2024
db8cbb1
feat(_test_velocity_model): Unit tests for velocity model
AlexanderAivazidis Jul 22, 2024
b708935
feat(.gitignore): Added example_notebooks directory to .gitignore file.
AlexanderAivazidis Jul 22, 2024
7c7c73b
fix[_trainer_]: checking for existence of atac data
AlexanderAivazidis Jul 22, 2024
e383388
fix(_velocity): Save existence of atac data in adata
AlexanderAivazidis Jul 22, 2024
fc9be1a
fix(_velocity_model): LogNormal instead of Normal likelihood for atac…
AlexanderAivazidis Jul 22, 2024
94569f8
fix(_trainer.py): Properly processing atac data
AlexanderAivazidis Jul 26, 2024
750be4a
fix(_transcription_dynamics): Ensuring no inplace tensor operations i…
AlexanderAivazidis Jul 26, 2024
a2628d4
fix(_velocity): Handling atac data.
AlexanderAivazidis Jul 26, 2024
ec34568
fix(_velocity_module): Removed rates from multivariateNormalGuide, be…
AlexanderAivazidis Jul 26, 2024
085d7ac
fix(train): Ensure atac data is handled properly.
AlexanderAivazidis Jul 26, 2024
fbd06a8
feat(_transcription_dynamics): Added latent discrete parameter for mo…
AlexanderAivazidis Jul 26, 2024
113c3a4
feat(_velocity_model): Sampling latent discrete parameter for modelli…
AlexanderAivazidis Jul 26, 2024
40581dc
feat(_test_transcription_dynamics): Adapted test to include latent di…
AlexanderAivazidis Jul 26, 2024
c90af64
feat(_models): Initial commit for new knn_model.
AlexanderAivazidis Aug 23, 2024
211d9b5
feat(preprocess.py): Added function for computation of metacells.
AlexanderAivazidis Aug 23, 2024
73b1b87
feat(tests): Added function to produce synthetic AnnData.
AlexanderAivazidis Aug 23, 2024
725a0bd
feat(tests): Added test for compute_metacell function.
AlexanderAivazidis Aug 23, 2024
a5f31df
feat(tests): Added test for synthetic_AnnData function.
AlexanderAivazidis Aug 23, 2024
58a48e9
feat(models): Started knn model in a new folder.
AlexanderAivazidis Aug 26, 2024
75f8196
feat(knn_model): Basic vector field for knn_model.
AlexanderAivazidis Aug 29, 2024
3c840cb
feat(knn_model): Basic regulatory function, using 2 layer neural net,…
AlexanderAivazidis Aug 29, 2024
fbca112
feat(knn_model): Initial commit for pyro model for knn model.
AlexanderAivazidis Aug 29, 2024
33aae38
feat(train.py): resolved merge conflict by keeping change proposed in…
AlexanderAivazidis Sep 3, 2024
ae19b05
fix(preprocess): Added spliced/unspliced count aggregation to metacells.
AlexanderAivazidis Sep 3, 2024
567f700
fix(regulatory_function_1): needs numpy import
AlexanderAivazidis Sep 3, 2024
8586e15
fix(preprocess): need to return sparse matrix in layers after metacel…
AlexanderAivazidis Sep 3, 2024
5628463
fix(knn_model._velocity): Fixed various bugs in imports.
AlexanderAivazidis Sep 5, 2024
f34d7c5
feat(knn_model._velocity_model): Completed first draft of the model.
AlexanderAivazidis Sep 5, 2024
3963a4d
fix(_velocity_module): Fixed various imports.
AlexanderAivazidis Sep 5, 2024
a2e301e
feat(regulatory_function_1): Adapted function to deal with more than …
AlexanderAivazidis Sep 5, 2024
a69fda0
feat(knn_model): Added init file
AlexanderAivazidis Sep 5, 2024
39da15c
feat(knn_model): Completed first version of knn_model that trains wit…
AlexanderAivazidis Sep 6, 2024
e5a715b
fix(_velocity): Provide number of cells in each metacell to model.
AlexanderAivazidis Sep 13, 2024
f51a3dd
fix(_velocity_model): Various fixes to model architecture.
AlexanderAivazidis Sep 13, 2024
46cb17d
feat(_velocity_module): Provide number of cells in each metacell to m…
AlexanderAivazidis Sep 13, 2024
9fec683
fix(compute_metacell): copying over var_names
AlexanderAivazidis Sep 13, 2024
bba3a2f
feat(velocity_model): changed neural networks to be similar to cellda…
AlexanderAivazidis Sep 13, 2024
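Commits 211d9b5, ae19b05, 8586e15, 9fec683, and e5a715b above describe a preprocessing step that aggregates raw spliced/unspliced counts into metacells, keeps the aggregated layers sparse, copies over var_names, and records how many cells contribute to each metacell. The sketch below only illustrates that idea: it smooths counts over each cell's k-NN neighborhood rather than collapsing cells into a smaller set of metacells, and every name and parameter in it (aggregate_knn_counts, n_neighbors, the use of scanpy neighbors) is an assumption, not the PR's compute_metacell implementation.

```python
# Hypothetical sketch: neighborhood aggregation of raw counts.
# Not the PR's compute_metacell; names and parameters are illustrative.
import numpy as np
import scanpy as sc
from anndata import AnnData
from scipy import sparse


def aggregate_knn_counts(adata: AnnData, n_neighbors: int = 30) -> AnnData:
    # Build a k-NN graph; assumes adata has a usable representation
    # (e.g. X_pca), otherwise scanpy computes one with default settings.
    sc.pp.neighbors(adata, n_neighbors=n_neighbors)
    graph = (adata.obsp["connectivities"] > 0).astype(np.float32)
    graph.setdiag(1.0)  # include each cell in its own neighborhood

    meta = AnnData(
        X=graph @ adata.X,
        obs=adata.obs.copy(),
        var=adata.var.copy(),  # carries var_names over (cf. 9fec683)
    )
    # Sum raw counts over each neighborhood and keep the layers sparse
    # (cf. ae19b05 and 8586e15).
    for key in ("raw_spliced", "raw_unspliced"):
        if key in adata.layers:
            meta.layers[key] = sparse.csr_matrix(graph @ adata.layers[key])
    # Number of cells contributing to each aggregate (cf. e5a715b).
    meta.obs["n_cells"] = np.asarray(graph.sum(axis=1)).ravel()
    return meta
```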
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
#
/archive/
example_notebooks/*

#
.DS_Store
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -56,6 +56,7 @@
"search.followSymlinks": false,
"terminal.integrated.fontSize": 14,
"terminal.integrated.scrollback": 100000,
"python.terminal.activateEnvironment": false,
"workbench.colorTheme": "Catppuccin Mocha",
"workbench.iconTheme": "vscode-icons",
// Passing --no-cov to pytestArgs is required to respect breakpoints
3 changes: 2 additions & 1 deletion src/pyrovelocity/models/__init__.py
@@ -7,13 +7,14 @@
from pyrovelocity.models._deterministic_simulation import (
solve_transcription_splicing_model_analytical,
)
from pyrovelocity.models._transcription_dynamics import mrna_dynamics
from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics
from pyrovelocity.models._velocity import PyroVelocity


__all__ = [
deterministic_transcription_splicing_probabilistic_model,
mrna_dynamics,
atac_mrna_dynamics,
PyroVelocity,
solve_transcription_splicing_model,
solve_transcription_splicing_model_analytical,
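With the export added above, atac_mrna_dynamics is importable from the package namespace alongside mrna_dynamics. A one-line usage sketch; the diff does not show the function's signature, so only the import is illustrated:

```python
from pyrovelocity.models import atac_mrna_dynamics, mrna_dynamics
```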
163 changes: 115 additions & 48 deletions src/pyrovelocity/models/_trainer.py
@@ -271,8 +271,8 @@ def train_faster(
if scipy.sparse.issparse(self.adata.layers["raw_spliced"])
else self.adata.layers["raw_spliced"],
dtype=torch.float32,
        ).to(device)
epsilon = 1e-6

log_u_library_size = np.log(
@@ -335,60 +335,127 @@ def train_faster(

        losses = []
        patience = patient_init

        if not self.adata.uns['atac']:

            for step in range(max_epochs):
                if cell_state is None:
                    elbos = (
                        svi.step(
                            u,
                            s,
                            u_library.reshape(-1, 1),
                            s_library.reshape(-1, 1),
                            u_library_mean.reshape(-1, 1),
                            s_library_mean.reshape(-1, 1),
                            u_library_scale.reshape(-1, 1),
                            s_library_scale.reshape(-1, 1),
                            None,
                            None,
                        )
                        / normalizer
                    )
                else:
                    elbos = (
                        svi.step(
                            u,
                            s,
                            u_library.reshape(-1, 1),
                            s_library.reshape(-1, 1),
                            u_library_mean.reshape(-1, 1),
                            s_library_mean.reshape(-1, 1),
                            u_library_scale.reshape(-1, 1),
                            s_library_scale.reshape(-1, 1),
                            None,
                            cell_state.reshape(-1, 1),
                        )
                        / normalizer
                    )
                if (step == 0) or (
                    ((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
                ):
                    mlflow.log_metric("-ELBO", -elbos, step=step + 1)
                    logger.info(
                        f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
                    )
                if step > log_every:
                    if (losses[-1] - elbos) < losses[-1] * patient_improve:
                        patience -= 1
                    else:
                        patience = patient_init
                    if patience <= 0:
                        break
                losses.append(elbos)

        else:

            c = torch.tensor(
                np.array(
                    self.adata.layers["atac"].toarray(), dtype="float32"
                )
                if scipy.sparse.issparse(self.adata.layers["atac"])
                else self.adata.layers["atac"],
                dtype=torch.float32,
            ).to(device)

            for step in range(max_epochs):
                if cell_state is None:
                    elbos = (
                        svi.step(
                            c,
                            u,
                            s,
                            u_library.reshape(-1, 1),
                            s_library.reshape(-1, 1),
                            u_library_mean.reshape(-1, 1),
                            s_library_mean.reshape(-1, 1),
                            u_library_scale.reshape(-1, 1),
                            s_library_scale.reshape(-1, 1),
                            None,
                            None,
                        )
                        / normalizer
                    )
                else:
                    elbos = (
                        svi.step(
                            c,
                            u,
                            s,
                            u_library.reshape(-1, 1),
                            s_library.reshape(-1, 1),
                            u_library_mean.reshape(-1, 1),
                            s_library_mean.reshape(-1, 1),
                            u_library_scale.reshape(-1, 1),
                            s_library_scale.reshape(-1, 1),
                            None,
                            cell_state.reshape(-1, 1),
                        )
                        / normalizer
                    )
                if (step == 0) or (
                    ((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
                ):
                    mlflow.log_metric("-ELBO", -elbos, step=step + 1)
                    logger.info(
                        f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
                    )
                if step > log_every:
                    if (losses[-1] - elbos) < losses[-1] * patient_improve:
                        patience -= 1
                    else:
                        patience = patient_init
                    if patience <= 0:
                        break
                losses.append(elbos)

        mlflow.log_metric("-ELBO", -elbos, step=step + 1)
        mlflow.log_metric("real_epochs", step + 1)
        logger.info(
            f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
        )
        return losses

def train_faster_with_batch(
self,
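Review note on the train_faster change above: the ATAC and non-ATAC paths duplicate the same training loop and differ only in whether a chromatin tensor c is prepended to the svi.step arguments. The sketch below is a suggestion for an equivalent formulation that builds the argument list once; it reuses the variable names from the diff but is not code from this PR.

```python
# Sketch only: equivalent to the branch above without duplicating the loop.
# Assumes the same local variables as train_faster in this diff.
c = None
if self.adata.uns["atac"]:
    atac = self.adata.layers["atac"]
    c = torch.tensor(
        atac.toarray() if scipy.sparse.issparse(atac) else atac,
        dtype=torch.float32,
    ).to(device)

for step in range(max_epochs):
    args = [
        u,
        s,
        u_library.reshape(-1, 1),
        s_library.reshape(-1, 1),
        u_library_mean.reshape(-1, 1),
        s_library_mean.reshape(-1, 1),
        u_library_scale.reshape(-1, 1),
        s_library_scale.reshape(-1, 1),
        None,
        None if cell_state is None else cell_state.reshape(-1, 1),
    ]
    if c is not None:
        # The multiome module in this PR takes chromatin counts first.
        args.insert(0, c)
    elbos = svi.step(*args) / normalizer
    # ... same logging and early-stopping bookkeeping as in the diff ...
```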