50 changes: 16 additions & 34 deletions docs/docs/models/geneformer.md
@@ -1,11 +1,4 @@
# Geneformer
!!! note "Current checkpoints trained in BioNeMo1"

This document references performance numbers and runtime engines that are from the bionemo v1 variant of the model.
These numbers will be updated in a coming release to reflect the new bionemo v2 codebase. The model architecture and
training information will be the same, as checkpoints are converted from bionemo v1 format to v2 format. Benchmarks below
are annotated with which version of bionemo generated them. Accuracy should be the same within a small epsilon
since we have tests in place showing model equivalency between the two versions.

## Model Overview

@@ -155,32 +148,21 @@ NVIDIA believes Trustworthy AI is a shared responsibility and we have establishe

## Training diagnostics

### geneformer-10M-240530

This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 8 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 32 and global batch size of 2048. Training took a total of 1 day, 20 hours and 19 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
![Validation and training losses both decreased smoothly through training](../assets/old_images/sc_fm/geneformer-10m-240530-val-train-loss.png)

!!! note "Training curves from BioNeMo1"

Note that these curves were generated on BioNeMo1. We see the same general training curves in our initial testing of
BioNeMo2, however. In the following figure the blue line is the previous training run of the 10M model and the
red curve is an equivalent training run on BioNeMo2. As we release new checkpoints they will be trained on BioNeMo2.

![Training curve equivalence](../assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png)


### geneformer-106M-240530
### geneformer-10M
<!-- WandB Logs: https://wandb.ai/clara-discovery/Geneformer-pretraining-jsjconfigs/runs/i8LWOctg?nw=nwuserjomitchell -->
Training was performed on 8 servers with 8 A100 GPUs each for a total of 81485 steps using the CELLxGENE split with a per-gpu micro batch size 32 and global batch size of 2048. Training took a total of 4 days, 8 hours of wallclock time. As can be seen in the following images, training and validation curves both decreased fairly smoothly throughout the course of training.

This checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 16 servers with 8 A100 GPUs each for a total of 115430 steps of per-gpu micro batch size 16 and global batch size of 2048. Training took a total of 3 days, 18 hours and 55 minutes of wallclock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
![Validation and training losses both decreased smoothly through training](../assets/old_images/sc_fm/geneformer-106m-240530-val-train-loss.png)
![Training Loss Geneformer 10M](../assets/images/geneformer/geneformer_10m_training_loss.png)
![Validation Loss Geneformer 10M](../assets/images/geneformer/geneformer_10m_val_loss.png)


Additionally, validation loss decreased both faster and continued to decrease at the same improved rate throughout training in the 106M parameter model (red) as compared to the 10M parameter model (blue). It would be interesting to test even larger models to see if we continue to observe improved performance in larger models.
![106M parameter model outperformed 10M parameter model](../assets/old_images/sc_fm/geneformer-240530-val-comparison.png)

!! note "Training curves from BioNeMo1"
### geneformer-106M
<!-- WandB Logs https://wandb.ai/clara-discovery/geneformer-pretraining-106m-16node-spike -->
This checkpoint was trained for approximately 35,650 steps using the CELLxGENE split. Training was performed on 16 servers with 8 A100 GPUs each, using a per-gpu micro batch size of 16 and a global batch size of 2,048. Training took a total of 8 hours of wallclock time. As can be seen in the following images, training and validation curves both decreased fairly smoothly throughout the course of training.

As stated in the previous section, the figures are from our BioNeMo1 code base where these checkpoints were originally
trained. As we release new checkpoints they will be trained on BioNeMo2.
![Training Loss Geneformer 106M](../assets/images/geneformer/Geneformer_steven_106m_train.png)
![Validation Loss Geneformer 106M](../assets/images/geneformer/Geneformer_steven_106m_val.png)
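
As a quick sanity check on the configurations above, the global batch size in both runs is the per-gpu micro batch size multiplied by the number of data-parallel GPUs. The sketch below assumes a gradient-accumulation factor of 1, which is not stated in the training configurations above.

```python
# Sanity check on the reported batch sizes.
# The gradient-accumulation factor of 1 is an assumption, not taken from the training logs.

def global_batch_size(micro_batch: int, num_nodes: int, gpus_per_node: int, grad_accum: int = 1) -> int:
    """Global batch size = per-GPU micro batch * data-parallel GPUs * gradient accumulation."""
    return micro_batch * num_nodes * gpus_per_node * grad_accum

assert global_batch_size(micro_batch=32, num_nodes=8, gpus_per_node=8) == 2048    # geneformer-10M
assert global_batch_size(micro_batch=16, num_nodes=16, gpus_per_node=8) == 2048   # geneformer-106M
```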

## Benchmarking

@@ -192,9 +174,9 @@ The following describes the bert MLM token loss. Like in the original BERT paper

| Model Description | Token Loss (lower is better) |
| ---------------------- | ---------------------------- |
| Baseline geneformer | 2.26* |
| geneformer-10M-240530 | 2.64 |
| geneformer-106M-240530 | 2.34 |
| Baseline geneformer | 3.206* |
| geneformer-10M-240530 | 3.18 |
| geneformer-106M-240530 | 2.89 |

!!! bug "Baseline Geneformer was recently updated on huggingface making loss comparisons challenging."

@@ -222,8 +204,8 @@ Elmentaite et al. (2020), Developmental Cell. This dataset contains approximatel

For more details see the example notebook titled Geneformer-celltype-classification-example.ipynb

![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/F1-score-models.png)
![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/average-accuracy-models.png)
![F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/F1-score-models-04-18-25.png)
![Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.](../assets/images/geneformer/average-accuracy-models-04-18-25.png)
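
A rough outline of this style of benchmark, fitting a lightweight classifier on frozen per-cell features and scoring macro F1 and accuracy, is sketched below with scikit-learn. The estimator choice and feature preparation are assumptions for illustration only; the example notebook mentioned above documents the actual procedure.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

def classify_frozen_features(features: np.ndarray, cell_types: np.ndarray) -> dict:
    """Fit a simple classifier on frozen per-cell features and report macro F1 and accuracy."""
    X_train, X_test, y_train, y_test = train_test_split(
        features, cell_types, test_size=0.2, stratify=cell_types, random_state=0
    )
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    preds = clf.predict(X_test)
    return {
        "macro_f1": f1_score(y_test, preds, average="macro"),
        "accuracy": accuracy_score(y_test, preds),
    }

# Geneformer: per-cell embeddings extracted from the pretrained model.
# scores_geneformer = classify_frozen_features(cell_embeddings, cell_type_labels)

# PCA baseline: the same classifier on a PCA transformation of the raw expression matrix.
# pca_features = PCA(n_components=32).fit_transform(raw_expression)
# scores_pca = classify_frozen_features(pca_features, cell_type_labels)
```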

### Performance Benchmarks
