Predict drug perturbation effects on single-cell gene expression
Award-winning solution from the NeurIPS 2023 Single-Cell Perturbations Challenge:
- $10,000 Judges' Prize for performance and methodology
- 2nd place in post-hoc analysis
- Top 2% overall (16th/1097 teams)
pip install git+https://github.com/scapeML/scape.git
import scape
# Train model with drug cross-validation
result = scape.train(
    de_file="de_train.parquet",
    lfc_file="lfc_train.parquet",
    cv_drug="Belinostat",
    n_genes=64,
)
# Visualize performance vs baselines
scape.plot_result(result)
ScAPE is a lightweight neural network (~9.6M parameters) that predicts differential gene expression in response to drug perturbations. It is built with Keras 3 and runs on the TensorFlow, JAX, or PyTorch backend.
- Single or Multi-Task Learning: Predict p-values only or jointly with fold changes
- Multi-Backend Support: Choose between TensorFlow, JAX, or PyTorch
- Built-in Ensemble Methods: Simple blending for robust predictions
- Cross-Validation: Cell-type and drug-based validation strategies
- Efficient: Handles ~18,000 genes with median-based feature engineering
The model uses median-based feature engineering: for each drug and cell type, we compute median differential expression values across the dataset. This reduces ~18,000 genes to manageable drug/cell signatures while preserving biological signal.
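As a rough illustration of the idea, the sketch below computes per-drug and per-cell-type median signatures with pandas and concatenates them into one feature vector. The column names `sm_name` and `cell_type`, the toy values, and the helper `median_signatures` are assumptions made for this example; it is not ScAPE's actual feature pipeline.

```python
import pandas as pd

# Toy differential-expression table: one row per (cell type, drug) pair,
# one column per gene. Column names and values are illustrative assumptions.
de = pd.DataFrame(
    {
        "cell_type": ["NK cells", "NK cells", "T cells", "T cells"],
        "sm_name": ["Belinostat", "DrugB", "Belinostat", "DrugB"],
        "GENE1": [1.0, -2.0, 3.0, 0.0],
        "GENE2": [0.5, 0.5, -1.5, 2.5],
    }
)
genes = ["GENE1", "GENE2"]


def median_signatures(df, key, genes):
    """Median value of each gene across all rows sharing the same `key`."""
    return df.groupby(key)[genes].median()


drug_sig = median_signatures(de, "sm_name", genes)    # one row per drug
cell_sig = median_signatures(de, "cell_type", genes)  # one row per cell type

# Features for one (drug, cell type) pair: both signatures concatenated.
features = pd.concat(
    [
        drug_sig.loc["Belinostat"].add_prefix("drug_"),
        cell_sig.loc["NK cells"].add_prefix("cell_"),
    ]
)
print(features)
```

In a real setting the gene columns would typically be restricted to a subset (for example the `n_genes=64` used elsewhere in this README) before the signatures are built.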
Key design choices:
- Dual conditioning: Cell features are used in both encoder and decoder (similar to CVAEs)
- Non-probabilistic: After testing VAE variants, we found a simpler deterministic NN performed equally well.
- Multi-source features: Combines signed log p-values and log fold changes for richer representations
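The dual-conditioning idea can be sketched as a plain NumPy forward pass: the encoder receives drug and cell features together, and the decoder receives the latent code plus the cell features again. Layer sizes, the `dense` helper, and the random untrained weights are all assumptions for illustration; the actual ScAPE architecture is defined in the package with Keras.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(x, w, b, activation=None):
    """One fully connected layer with an optional ReLU activation."""
    y = x @ w + b
    return np.maximum(y, 0.0) if activation == "relu" else y


# Hypothetical sizes: 64 features per signature, 32 latent units, 64 outputs.
n_feat, n_latent, n_out, batch = 64, 32, 64, 4
drug_x = rng.normal(size=(batch, n_feat))  # drug signature features
cell_x = rng.normal(size=(batch, n_feat))  # cell-type signature features

# Randomly initialised, untrained weights; only the wiring matters here.
w_enc = rng.normal(scale=0.1, size=(2 * n_feat, n_latent))
b_enc = np.zeros(n_latent)
w_dec = rng.normal(scale=0.1, size=(n_latent + n_feat, n_out))
b_dec = np.zeros(n_out)

# Encoder sees drug and cell features together.
z = dense(np.concatenate([drug_x, cell_x], axis=1), w_enc, b_enc, "relu")
# Decoder is conditioned on the cell features a second time.
y_hat = dense(np.concatenate([z, cell_x], axis=1), w_dec, b_dec)

print(y_hat.shape)
```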
# Command line
python -m scape train --n-genes 64 --cv-drug Belinostat de_train.parquet lfc_train.parquet
# Python API
import scape
model = scape.model.create_default_model(
    n_genes=64,
    df_de=de_data,
    df_lfc=lfc_data,
)
results = model.train(
    val_cells=['NK cells'],
    val_drugs=['Belinostat'],
    epochs=600,
)
Configure the model to jointly predict both p-values and fold changes:
# Multi-task configuration with optimal weights
model.model.compile(
    optimizer=optimizer,
    loss={'slogpval': mrrmse, 'lfc': mrrmse},
    loss_weights={'slogpval': 0.8, 'lfc': 0.2},
)
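To make the weighting concrete, here is a minimal NumPy sketch of how the two losses combine. It assumes `mrrmse` stands for mean rowwise root mean squared error (the challenge metric) and that the total loss is the weighted sum of the per-output losses, as Keras does with `loss_weights`. The function below and the toy arrays are illustrative, not the package's own implementation.

```python
import numpy as np


def mrrmse(y_true, y_pred):
    """Mean rowwise RMSE: RMSE over genes for each row, then mean over rows."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    rowwise_rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1))
    return float(np.mean(rowwise_rmse))


# Toy targets and predictions for the two outputs (2 rows x 2 genes).
true_slogpval = np.array([[3.0, 4.0], [0.0, 0.0]])
pred_slogpval = np.zeros((2, 2))
true_lfc = np.array([[1.0, 1.0], [1.0, 1.0]])
pred_lfc = np.zeros((2, 2))

loss_slogpval = mrrmse(true_slogpval, pred_slogpval)
loss_lfc = mrrmse(true_lfc, pred_lfc)

# Weighted sum, mirroring loss_weights={'slogpval': 0.8, 'lfc': 0.2}.
total = 0.8 * loss_slogpval + 0.2 * loss_lfc
print(loss_slogpval, loss_lfc, total)
```

With these weights the signed log p-value output dominates the gradient, while the fold-change output acts as an auxiliary task.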
# Use JAX backend (recommended for performance)
KERAS_BACKEND=jax python -m scape train ...
# Use TensorFlow backend
KERAS_BACKEND=tensorflow python -m scape train ...
# Use PyTorch backend
KERAS_BACKEND=torch python -m scape train ...
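The backend can also be chosen from Python (for example in a notebook) by setting the same environment variable before Keras is first imported. The `select_backend` helper below is a hypothetical convenience wrapper, not part of ScAPE or Keras.

```python
import os

VALID_BACKENDS = {"tensorflow", "jax", "torch"}


def select_backend(name):
    """Set KERAS_BACKEND; call this before the first `import keras`."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown Keras backend: {name!r}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# Only import Keras-based code (such as `scape`) after this point;
# Keras reads the variable once, at import time.
print(os.environ["KERAS_BACKEND"])
```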
Improve robustness with simple ensemble blending:
from sklearn.model_selection import KFold
import numpy as np
# Train multiple models with K-fold
predictions = []
for train_idx, val_idx in KFold(n_splits=5).split(all_combinations):
    model = scape.model.create_default_model(...)
    model.train(...)
    predictions.append(model.predict(test_combinations))
# Blend predictions (median)
ensemble_pred = np.median([p.values for p in predictions], axis=0)
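The blending step on its own can be tried with toy data. The sketch below uses hand-made DataFrames as stand-ins for per-fold predictions (the index labels, gene names, and values are assumptions) and aligns them on index and columns before taking the element-wise median.

```python
import numpy as np
import pandas as pd

index = ["pair_0", "pair_1"]  # hypothetical test (cell type, drug) pairs
columns = ["GENE1", "GENE2"]

# Stand-ins for the per-fold prediction DataFrames.
predictions = [
    pd.DataFrame([[1.0, 0.0], [2.0, -1.0]], index=index, columns=columns),
    pd.DataFrame([[3.0, 0.5], [2.0, -3.0]], index=index, columns=columns),
    pd.DataFrame([[100.0, 1.0], [2.0, -2.0]], index=index, columns=columns),
]

# Stack to shape (n_models, n_rows, n_genes); median over the model axis.
stacked = np.stack([p.loc[index, columns].to_numpy() for p in predictions])
ensemble = pd.DataFrame(np.median(stacked, axis=0), index=index, columns=columns)
print(ensemble)
```

The median ignores the outlying value of 100.0 from the third model, which is the main reason to prefer it over a mean when individual folds occasionally diverge.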
# Custom architecture
config = {
    "encoder_hidden_layer_sizes": [128, 128],
    "decoder_hidden_layer_sizes": [128, 512],
    "outputs": {
        "slogpval": (64, "linear"),
        "lfc": (64, "linear"),  # Multi-task
    },
    "noise": 0.01,
    "dropout": 0.05,
}
model = scape.model.create_model(
    n_genes=64,
    df_de=de_data,
    df_lfc=lfc_data,
    config=config,
)
Track model improvement over baselines:
- Zero baseline: Always predicts 0 (competition baseline)
- Median baseline: Predicts drug-specific medians
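Both baselines are simple to reproduce by hand. The sketch below scores them with a mean rowwise RMSE on toy data; the column name `sm_name`, the values, and the local `mrrmse` function are assumptions for illustration and are not taken from the ScAPE code.

```python
import numpy as np
import pandas as pd


def mrrmse(y_true, y_pred):
    """Mean rowwise RMSE: RMSE over genes for each row, then mean over rows."""
    err = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.mean(np.sqrt(np.mean(err ** 2, axis=1))))


genes = ["GENE1", "GENE2"]
train = pd.DataFrame(
    {
        "sm_name": ["A", "A", "B", "B"],
        "GENE1": [1.0, 3.0, -2.0, -4.0],
        "GENE2": [0.0, 2.0, 1.0, 1.0],
    }
)
val = pd.DataFrame(
    {"sm_name": ["A", "B"], "GENE1": [2.5, -3.0], "GENE2": [1.0, 0.0]}
)
y_true = val[genes].to_numpy()

# Zero baseline: predict 0 for every gene.
zero_pred = np.zeros_like(y_true)

# Median baseline: predict each drug's per-gene median from the training rows.
drug_medians = train.groupby("sm_name")[genes].median()
median_pred = drug_medians.loc[val["sm_name"]].to_numpy()

score_zero = mrrmse(y_true, zero_pred)
score_median = mrrmse(y_true, median_pred)
print(f"zero: {score_zero:.3f}  median: {score_median:.3f}")
```

A trained model is only worth keeping if its validation score is below both of these numbers.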
- Quick Start Tutorial
- Training Pipeline
- Google Colab Demo
- Technical Report
- Dataset (Zenodo)
# Setup with pixi
pixi install
pixi shell -e dev
# Run tests (JAX backend recommended)
KERAS_BACKEND=jax pixi run -e dev test
# Lint & format
pixi run lint
pixi run format
@misc{rodriguezmier24scape,
author = {Rodriguez-Mier, Pablo and Garrido-Rodriguez, Martin},
title = {ScAPE: Single-cell Analysis of Perturbational Effects},
year = {2024},
url = {https://github.com/scapeML/scape}
}