Skip to content

Commit

Permalink
Merge pull request #48 from nanoporetech/prepare-training-data
Browse files Browse the repository at this point in the history
Prepare training data
  • Loading branch information
iiSeymour authored Sep 4, 2020
2 parents 67fb27f + c2bed31 commit 3a26abf
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 39 deletions.
29 changes: 12 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

[![PyPI version](https://badge.fury.io/py/ont-bonito.svg)](https://badge.fury.io/py/ont-bonito)

A convolutional basecaller inspired by QuartzNet.

## Features

- Raw signal input.
- Simple 5 state output `{BLANK, A, C, G, T}`.
- CTC training.
- Small Python codebase.

## Basecalling
A PyTorch Basecaller for Oxford Nanopore Reads.

```bash
$ pip install ont-bonito
Expand All @@ -23,27 +14,31 @@ If a reference is provided in either `.fasta` or `.mmi` format then bonito will
```bash
$ bonito basecaller dna_r9.4.1 --reference reference.mmi /data/reads > basecalls.sam
```

If you have a `turing` or `volta` GPU the `--half` flag can be used to increase performance.

## Pair Decoding

Pair decoding takes a template and complement read to produce higher quality calls.

```
```bash
$ bonito basecaller pairs.csv /data/reads > basecalls.fasta
```

The `pairs.csv` file is expected to contain pairs of read ids per line *(separated by a single space)*.


## Training your own model

To train your own model first download the training data.
To train a model using your own reads, first basecall the reads with the additional `--save-ctc` flag and use the output directory as the input directory for training.

```bash
$ bonito basecaller dna_r9.4.1 --save-ctc --reference reference.mmi /data/reads > /data/training/ctc-data/basecalls.sam
$ bonito train --amp --directory /data/training/ctc-data /data/training/model-dir
```

If you are interested in method development and don't have your own set of reads then a pre-prepared set is provided.

```bash
$ bonito download --training
$ bonito train --amp /data/model-dir
$ bonito train --amp /data/training/model-dir
```

Automatic mixed precision can be used to speed up training with the `--amp` flag *(however [apex](https://github.com/nvidia/apex#quick-start) needs to be installed manually)*.
Expand All @@ -55,7 +50,7 @@ $ export CUDA_VISIBLE_DEVICES=0,1,2,3
$ bonito train --amp --multi-gpu --batch 256 /data/model-dir
```

To evaluate the pretrained model run `bonito evaluate dna_r9.4.1 --half`.
To evaluate the pretrained model run `bonito evaluate dna_r9.4.1`.

For a model you have trained yourself, replace `dna_r9.4.1` with the model directory.

Expand Down
36 changes: 29 additions & 7 deletions bonito/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,52 @@
from datetime import timedelta
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

from bonito.util import load_model, chunk, stitch
from bonito.io import DecoderWriterPool, PreprocessReader
from bonito.util import load_model, chunk, stitch, half_supported
from bonito.io import DecoderWriterPool, PreprocessReader, CTCWriter

import torch
import numpy as np
from mappy import Aligner


def main(args):

if args.save_ctc and not args.reference:
sys.stderr.write("> a reference is needed to output ctc training data\n")
exit(1)

if args.save_ctc:
args.overlap = 900
args.chunksize = 3600

sys.stderr.write("> loading model\n")

model = load_model(
args.model_directory, args.device, weights=int(args.weights),
half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
)

if args.reference:
sys.stderr.write("> loading reference\n")
aligner = Aligner(args.reference, preset='ont-map')
if not aligner:
sys.stderr.write("> failed to load/build index\n")
sys.exit(1)
else:
aligner = None

samples = 0
num_reads = 0
max_read_size = 4e6
dtype = np.float16 if args.half else np.float32
ctc_writer = CTCWriter(model, aligner)
reader = PreprocessReader(args.reads_directory)
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, reference=args.reference)
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, aligner=aligner)

t0 = time.perf_counter()
sys.stderr.write("> calling\n")

with writer, reader, torch.no_grad():
with writer, ctc_writer, reader, torch.no_grad():

while True:

Expand All @@ -51,10 +70,12 @@ def main(args):
raw_data = torch.tensor(read.signal.astype(dtype))
chunks = chunk(raw_data, args.chunksize, args.overlap)

posteriors = model(chunks.to(args.device)).cpu().numpy()
posteriors = stitch(posteriors, args.overlap // model.stride // 2)
posteriors_ = model(chunks.to(args.device)).cpu().numpy()
posteriors = stitch(posteriors_, args.overlap // model.stride // 2)

writer.queue.put((read, posteriors[:raw_data.shape[0]]))
if args.save_ctc and len(raw_data) > args.chunksize:
ctc_writer.queue.put((chunks.numpy(), posteriors_))

duration = time.perf_counter() - t0

Expand All @@ -77,7 +98,8 @@ def argparser():
parser.add_argument("--beamsize", default=5, type=int)
parser.add_argument("--chunksize", default=0, type=int)
parser.add_argument("--overlap", default=0, type=int)
parser.add_argument("--half", action="store_true", default=False)
parser.add_argument("--half", action="store_true", default=half_supported())
parser.add_argument("--fastq", action="store_true", default=False)
parser.add_argument("--cudart", action="store_true", default=False)
parser.add_argument("--save-ctc", action="store_true", default=False)
return parser
4 changes: 2 additions & 2 deletions bonito/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from bonito.training import ChunkDataSet
from bonito.util import accuracy, poa, decode_ref
from bonito.util import init, load_data, load_model
from bonito.util import init, load_data, load_model, half_supported

from torch.utils.data import DataLoader

Expand Down Expand Up @@ -80,7 +80,7 @@ def argparser():
parser.add_argument("model_directory")
parser.add_argument("--directory", default=None)
parser.add_argument("--device", default="cuda")
parser.add_argument("--half", action="store_true", default=False)
parser.add_argument("--half", action="store_true", default=half_supported())
parser.add_argument("--seed", default=9, type=int)
parser.add_argument("--weights", default="0", type=str)
parser.add_argument("--chunks", default=500, type=int)
Expand Down
106 changes: 94 additions & 12 deletions bonito/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from glob import glob
from warnings import warn
from logging import getLogger
from os.path import realpath, splitext
from os.path import realpath, splitext, dirname
from multiprocessing import Process, Queue, Lock, cpu_count

import numpy as np
from tqdm import tqdm
from mappy import Aligner, revcomp
from mappy import revcomp

import bonito
from bonito.training import ChunkDataSet
from bonito.convert import filter_chunks
from bonito.util import get_raw_data, mean_qscore_from_qstring


Expand Down Expand Up @@ -214,25 +216,105 @@ def stop(self):
self.join()


class CTCWriter(Process):
    """
    CTC writer process that writes output numpy training data

    Jobs are `(chunks, predictions)` tuples read from `self.queue`; a `None`
    job ends collection. Each prediction is decoded with `model.decode`, the
    decoded sequence is mapped with `aligner`, and the chunk is kept when the
    first usable mapping exceeds both `min_accuracy` and `min_coverage`. The
    kept chunks and their reference targets are saved as `.npy` files once
    the queue has been drained.
    """
    def __init__(self, model, aligner, min_coverage=0.90, min_accuracy=0.90):
        """
        :param model: object providing `decode(prediction)` -> sequence string.
        :param aligner: object providing `map(sequence)` and `seq(ctg, start, end)`.
        :param min_coverage: minimum fraction of the decoded sequence that the
            mapping must span (exclusive bound).
        :param min_accuracy: minimum `mlen / blen` of the mapping (exclusive bound).
        """
        super().__init__()
        self.model = model
        self.queue = Queue()
        self.aligner = aligner
        self.min_coverage = min_coverage
        self.min_accuracy = min_accuracy

    def __enter__(self):
        """Start the writer process and return it."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Send the end-of-input sentinel and wait for the process to finish."""
        self.queue.put(None)
        self.stop()

    def run(self):
        """
        Consume jobs until the `None` sentinel, then write the training data.

        Nothing is written if no chunk passed the filters.
        """
        chunks = []
        targets = []
        target_lens = []

        while True:

            job = self.queue.get()
            if job is None:
                break
            job_chunks, predictions = job

            # convert logprobs to probs
            predictions = np.exp(predictions.astype(np.float32))

            for signal, pred in zip(job_chunks, predictions):

                sequence = self.model.decode(pred)

                if not sequence:
                    continue

                # take the first mapping whose reference span has no 'N' bases
                for mapping in self.aligner.map(sequence):
                    cov = (mapping.q_en - mapping.q_st) / len(sequence)
                    acc = mapping.mlen / mapping.blen
                    # NOTE(review): mapping coordinates look 0-based, in which
                    # case `r_st + 1` drops the first reference base - confirm
                    # against the mappy documentation before changing.
                    refseq = self.aligner.seq(mapping.ctg, mapping.r_st + 1, mapping.r_en)
                    if 'N' in refseq:
                        continue
                    if mapping.strand == -1:
                        refseq = revcomp(refseq)
                    break
                else:
                    # no usable mapping for this chunk
                    continue

                # coverage is checked against min_coverage; it was previously
                # compared to min_accuracy, leaving min_coverage unused
                if acc > self.min_accuracy and cov > self.min_coverage:
                    chunks.append(signal.squeeze())
                    # encode the reference as A=1, C=2, G=3, T=4
                    targets.append([
                        int(x) for x in refseq.translate({65: '1', 67: '2', 71: '3', 84: '4'})
                    ])
                    target_lens.append(len(refseq))

        if not chunks:
            return

        chunks = np.array(chunks, dtype=np.float32)
        chunk_lens = np.full(chunks.shape[0], chunks.shape[1], dtype=np.int16)

        # right-pad every target with zeros up to the longest target
        targets_ = np.zeros((chunks.shape[0], max(target_lens)), dtype=np.uint8)
        for idx, target in enumerate(targets):
            targets_[idx, :len(target)] = target
        target_lens = np.array(target_lens, dtype=np.uint16)

        training = ChunkDataSet(chunks, chunk_lens, targets_, target_lens)
        training = filter_chunks(training)

        # write to the current directory when stdout is a terminal, otherwise
        # to the directory of the file that stdout resolves to
        output_directory = '.' if sys.stdout.isatty() else dirname(realpath('/dev/fd/1'))
        np.save(os.path.join(output_directory, "chunks.npy"), training.chunks.squeeze(1))
        np.save(os.path.join(output_directory, "chunk_lengths.npy"), training.chunk_lengths)
        np.save(os.path.join(output_directory, "references.npy"), training.targets)
        np.save(os.path.join(output_directory, "reference_lengths.npy"), training.target_lengths)

        sys.stderr.write("> written ctc training data\n")
        sys.stderr.write("  - chunks.npy with shape (%s)\n" % ','.join(map(str, training.chunks.squeeze(1).shape)))
        sys.stderr.write("  - chunk_lengths.npy with shape (%s)\n" % ','.join(map(str, training.chunk_lengths.shape)))
        sys.stderr.write("  - references.npy with shape (%s)\n" % ','.join(map(str, training.targets.shape)))
        sys.stderr.write("  - reference_lengths.npy shape (%s)\n" % ','.join(map(str, training.target_lengths.shape)))

    def stop(self):
        """Block until the writer process has exited."""
        self.join()


class DecoderWriterPool:
"""
Simple pool of decoder writers
"""
def __init__(self, model, procs=4, reference=None, **kwargs):
def __init__(self, model, procs=4, aligner=None, **kwargs):
self.lock = Lock()
self.queue = Queue()
self.procs = procs if procs else cpu_count()
self.aligner = aligner
self.decoders = []

if reference:
sys.stderr.write("> loading reference\n")
aligner = Aligner(reference, preset='ont-map')
if not aligner:
sys.stderr.write("> failed to load/build index\n")
sys.exit(1)
write_sam_header(aligner)
else:
aligner = None
if aligner: write_sam_header(aligner)

with open(summary_file(), 'w') as summary:
write_summary_header(summary, alignment=aligner)
Expand Down
2 changes: 1 addition & 1 deletion bonito/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def argparser():
parser.add_argument("--device", default="cuda")
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--seed", default=25, type=int)
parser.add_argument("--epochs", default=400, type=int)
parser.add_argument("--epochs", default=20, type=int)
parser.add_argument("--batch", default=32, type=int)
parser.add_argument("--chunks", default=2000000, type=int)
parser.add_argument("--validation_split", default=0.97, type=float)
Expand Down
8 changes: 8 additions & 0 deletions bonito/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import parasail
import numpy as np
from scipy.signal import find_peaks
from torch.cuda import get_device_capability
from ont_fast5_api.fast5_interface import get_fast5_file

try:
Expand Down Expand Up @@ -81,6 +82,13 @@ def init(seed, device):
assert(torch.cuda.is_available())


def half_supported():
    """
    Returns whether FP16 is supported on the GPU

    Returns False when CUDA is not available, so the function is safe to use
    as an argparse default on CPU-only hosts. Otherwise returns True when the
    first element (major version) of the device capability is at least 7.
    """
    # local import keeps the block self-contained
    from torch.cuda import is_available

    # get_device_capability() raises without a usable CUDA device
    if not is_available():
        return False
    return get_device_capability()[0] >= 7


def phred(prob, scale=1.0, bias=0.0):
"""
Converts `prob` into a ascii encoded phred quality score between 0 and 40.
Expand Down

0 comments on commit 3a26abf

Please sign in to comment.