Commit e1fea43

Initial repo with basic instructions, can load model and do inference (#1)

* add first basic inference code
* add stripedhyena
* continue to simplify and test
* save pip installs
* clean up model loading, add auto ckpt download
* cleanup evo
* continue to polish the repo
* add example scripts
* add todos to readme, add to readme
* fix links in readme
* fix subbullets

1 parent b466872

29 files changed: +3136 -1 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
*.egg-info/
__pycache__/
build/
dist/

README.md

Lines changed: 108 additions & 1 deletion
@@ -1 +1,108 @@
# Evo: DNA foundation modeling from molecular to genome scale

Tasks remaining:
- [ ] Upload checkpoints to the web and finalize auto-downloading code.
- [ ] Verify logits are the same as private repo.
- [ ] Package and upload to PyPI.
- [ ] Update with preprint info, blog info, Together API info, and HF info.

Evo is a biological foundation model capable of long-context modeling and design.
Evo uses the [StripedHyena architecture](https://github.com/togethercomputer/stripedhyena) to enable modeling of sequences at single-nucleotide, byte-level resolution with near-linear scaling of compute and memory relative to context length.
Evo has 7 billion parameters and is trained on OpenGenome, a prokaryotic whole-genome dataset containing 260 billion tokens.

Technical details about Evo can be found in our preprint and the accompanying blog.

We provide the following model checkpoints:
- `evo-1_stripedhyena_pretrained_8k`: A model pretrained with 8k context. We use this model as the base model for molecular-scale finetuning tasks.
- `evo-1_stripedhyena_pretrained_131k`: A model pretrained with 131k context, using `evo-1_stripedhyena_pretrained_8k` as the base model. We use this model to reason about and generate sequences at the genome scale.
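
For instance, genome-scale work would load the long-context checkpoint; this is a one-line variation on the Python API shown under [Usage](#usage):

```python
from evo import Evo

# Load the 131k-context checkpoint for genome-scale reasoning and generation.
evo_model = Evo('evo-1_stripedhyena_pretrained_131k')
```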

## Contents

- [Setup](#setup)
  - [Requirements](#requirements)
  - [Installation](#installation)
- [Usage](#usage)
- [Web API](#web-api)
- [HuggingFace](#huggingface-integration)
- [Citation](#citation)

## Setup

### Requirements

Evo uses [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), which may not work on all GPU architectures.
Please consult the [FlashAttention GitHub repository](https://github.com/Dao-AILab/flash-attention#installation-and-features) for the current list of supported GPUs.

Evo also uses PyTorch. Make sure the correct [PyTorch version is installed](https://pytorch.org/) on your system.
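
If you are unsure whether your GPU is supported, the sketch below is a quick heuristic check with PyTorch. It assumes FlashAttention-2's requirement of NVIDIA compute capability 8.0 or higher (Ampere, Ada, Hopper); treat it as a hint, not an official compatibility test.

```python
import torch

# Heuristic check: FlashAttention-2 generally targets NVIDIA GPUs with
# compute capability >= 8.0. Consult the FlashAttention repository for
# the authoritative list of supported hardware.
if not torch.cuda.is_available():
    print('No CUDA device detected.')
else:
    major, minor = torch.cuda.get_device_capability(0)
    print(f'{torch.cuda.get_device_name(0)}: compute capability {major}.{minor}')
    if (major, minor) >= (8, 0):
        print('Likely supported by FlashAttention-2.')
    else:
        print('Likely unsupported; see the FlashAttention repository.')
```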

### Installation

You can install Evo using `pip`:
```bash
pip install evo-model
```
or directly from the GitHub source:
```bash
git clone https://github.com/evo-design/evo.git
cd evo/
pip install .
```

## Usage

You can download Evo and use it locally through the Python API. For example:
```python
from evo import Evo
import torch

device = 'cuda:0'

evo_model = Evo('evo-1_stripedhyena_pretrained_8k')
model, tokenizer = evo_model.model, evo_model.tokenizer
model.to(device)
model.eval()

sequence = 'ACGT'
input_ids = torch.tensor(
    tokenizer.tokenize(sequence),
    dtype=torch.int,
).to(device).unsqueeze(0)
logits, _ = model(input_ids)  # (batch, length, vocab)

print('Logits: ', logits)
print('Shape (batch, length, vocab): ', logits.shape)
```
Examples of batched inference can be found in [`scripts/example_inference.py`](scripts/example_inference.py); a minimal sketch is also shown below.
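
As a rough sketch of what batched inference looks like (hedged: this reuses the `prepare_batch` helper that `evo/generation.py` imports from `evo.scoring` in this commit; its exact behavior may differ from this illustration):

```python
import torch

from evo import Evo
from evo.scoring import prepare_batch  # helper used by evo/generation.py in this commit

device = 'cuda:0'

evo_model = Evo('evo-1_stripedhyena_pretrained_8k')
model, tokenizer = evo_model.model, evo_model.tokenizer

# Tokenize multiple sequences into a single padded batch of input IDs.
sequences = ['ACGT', 'CCGGTT', 'AAAA']
input_ids = prepare_batch(sequences, tokenizer, prepend_bos=True, device=device)[0]

with torch.inference_mode():
    logits, _ = model(input_ids)  # (batch, length, vocab)
print('Shape (batch, length, vocab): ', logits.shape)
```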

We provide an example script showing how to prompt the model and sample a set of sequences given the prompt:
```bash
python scripts/generate.py \
    --model-name evo-1_stripedhyena_pretrained_8k \
    --prompt ACGT \
    --n-samples 10 \
    --n-tokens 100 \
    --temperature 1. \
    --top-k 4 \
    --device cuda:0
```

We also provide an example script for using the model to score the log-likelihoods of a set of sequences.
```bash
python scripts/score.py \
    --input-fasta examples/example_seqs.fasta \
    --output-tsv scores.tsv \
    --model-name evo-1_stripedhyena_pretrained_8k \
    --device cuda:0
```

## Web API

We are working with [Together.AI](https://www.together.ai/) on a web API that will provide logits and sampling functionality for Evo.

## HuggingFace integration

We are working on an integration with [HuggingFace](https://huggingface.co/).

## Citation

We will make a preprint publicly available soon.

evo/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .version import version as __version__

from .models import Evo

from .generation import generate
from .scoring import score_sequences, positional_entropies

evo/generation.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import numpy as np
import sys
import torch
from typing import List, Optional, Tuple, Union

from .models import load_checkpoint
from .scoring import logits_to_logprobs, prepare_batch
from .stripedhyena.src.generation import Generator
from .stripedhyena.src.model import StripedHyena
from .stripedhyena.src.tokenizer import CharLevelTokenizer


def generate(
    prompt_seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.,
    top_k: int = 1,
    top_p: float = 1.,
    skipped_tokens: Optional[Union[str, List[str], List[int]]] = None,
    batched: bool = True,
    prepend_bos: bool = True,
    cached_generation: bool = False,
    verbose: int = 1,
    device: str = 'cuda:0',
    **kwargs,
) -> Tuple[List[str], List[float]]:
    """
    Performs generation from a list of prompts.
    If all prompts are the same length, this can do batched generation.
    Also supports cached generation for efficient sampling.
    """
    model.eval()

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        # All prompts have the same length, so generate from a single batch.
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        if verbose:
            if not uniform_lengths:
                sys.stderr.write('Note: Prompts are of different lengths.\n')
            sys.stderr.write('Note: Will not do batched generation.\n')
        # Generate from each prompt individually.
        input_ids_list = [
            prepare_batch(
                [ prompt_seq ],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores = [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            cached_generation=cached_generation,
            device=device,
            print_generation=True,
            verbose=(verbose > 1),
            skipped_tokens=skipped_tokens,
            stop_at_eos=False,
        )
        if verbose > 1:
            print('input_ids.shape', input_ids.shape)
            print('output_ids.shape', output_ids.shape)
            print('logits.shape', logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch

        # Score each generated sequence by its mean token log-probability.
        logprobs = logits_to_logprobs(logits, output_ids, trim_bos=prepend_bos)
        logprobs = logprobs.float().cpu().numpy()

        generated_scores += [ np.mean(logprobs[idx]) for idx in range(batch_size) ]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return generated_seqs, generated_scores
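
For orientation, a minimal sketch of calling this `generate` helper from Python (the prompts and sampling settings below are illustrative, and the checkpoint must be available for download):

```python
from evo import Evo, generate

evo_model = Evo('evo-1_stripedhyena_pretrained_8k')
model, tokenizer = evo_model.model, evo_model.tokenizer

# Two same-length prompts, so generation can run as a single batch.
seqs, scores = generate(
    ['ACGT', 'TGCA'],
    model,
    tokenizer,
    n_tokens=100,
    temperature=1.,
    top_k=4,
    device='cuda:0',
)
```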

evo/models.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os
2+
import requests
3+
import torch
4+
from typing import List, Tuple
5+
import yaml
6+
7+
from .stripedhyena.src.utils import dotdict
8+
from .stripedhyena.src.model import StripedHyena
9+
from .stripedhyena.src.tokenizer import CharLevelTokenizer
10+
11+
12+
VALID_MODEL_NAMES = [
13+
'evo-1_stripedhyena_pretrained_8k',
14+
'evo-1_stripedhyena_pretrained_131k',
15+
]
16+
17+
18+
class Evo:
19+
def __init__(self, model_name: str, device: str = 'cuda:0'):
20+
"""
21+
Loads an Evo model checkpoint given a model name.
22+
If the checkpoint does not exist, automatically downloads the model to
23+
`~/.cache/torch/hub/checkpoints`.
24+
"""
25+
26+
if model_name not in VALID_MODEL_NAMES:
27+
raise ValueError(
28+
f'Invalid model name {model_name}. Should be one of: '
29+
f'{", ".join(VALID_MODEL_NAMES)}.'
30+
)
31+
32+
# Download checkpoint.
33+
34+
home_directory = os.path.expanduser('~')
35+
download_url = f'https://TODO/checkpoints/{model_name}.pt'
36+
cache_dir = f'{home_directory}/.cache/torch/hub/checkpoints'
37+
checkpoint_path = f'{cache_dir}/{model_name}.pt'
38+
39+
if not os.path.exists(checkpoint_path):
40+
print(f'Downloading {download_url} to {cache_dir}...')
41+
42+
if not os.path.exists(cache_dir):
43+
os.makedirs(cache_dir, exist_ok=True)
44+
45+
response = requests.get(download_url, stream=True)
46+
if response.status_code == 200:
47+
with open(checkpoint_path, 'wb') as f:
48+
f.write(response.raw.read())
49+
else:
50+
raise Exception(f'Failed to download the file. Status code: {response.status_code}')
51+
52+
# Load correct config file.
53+
54+
if model_name == 'evo-1_stripedhyena_pretrained_8k':
55+
config_path = 'evo/stripedhyena/configs/sh_inference_config_7b.yml'
56+
elif model_name == 'evo-1_stripedhyena_pretrained_131k':
57+
config_path = 'evo/stripedhyena/configs/sh_inference_config_7b_rotary_scale_16.yml'
58+
else:
59+
raise ValueError(f'Invalid model name {model_name}.')
60+
61+
# Load model.
62+
63+
self.model, self.tokenizer = load_checkpoint(
64+
checkpoint_path,
65+
model_type='stripedhyena',
66+
config_path=config_path,
67+
device=device,
68+
)
69+
self.model = self.model.to(device).eval()
70+
71+
self.device = device
72+
73+
74+
def load_checkpoint(
75+
ckpt_path: str,
76+
config_path: str = './evo/stripedhyena/configs/sh_inference_config_7b.yml',
77+
verbose: int = 0,
78+
device: str = 'cuda:0',
79+
**kwargs: dict,
80+
) -> Tuple[StripedHyena, CharLevelTokenizer]:
81+
"""
82+
Loads a checkpoint from a path and corresponding config.
83+
"""
84+
global_config = dotdict(yaml.load(open(config_path), Loader=yaml.FullLoader))
85+
86+
model = StripedHyena(global_config)
87+
tokenizer = CharLevelTokenizer(512)
88+
89+
model.load_state_dict(torch.load(ckpt_path), strict=True)
90+
91+
model.to_bfloat16_except_poles_residues()
92+
93+
return model, tokenizer
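
And a hedged sketch of loading a checkpoint directly via `load_checkpoint`, assuming the checkpoint file has already been downloaded to the default cache directory (paths are illustrative):

```python
import os

from evo.models import load_checkpoint

# Illustrative paths; adjust to wherever the checkpoint and config live.
ckpt_path = os.path.expanduser(
    '~/.cache/torch/hub/checkpoints/evo-1_stripedhyena_pretrained_8k.pt'
)
model, tokenizer = load_checkpoint(
    ckpt_path,
    config_path='evo/stripedhyena/configs/sh_inference_config_7b.yml',
    device='cuda:0',
)
```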
