Skip to content

Commit 3a26abf

Browse files
authored
Merge pull request #48 from nanoporetech/prepare-training-data
Prepare training data
2 parents 67fb27f + c2bed31 commit 3a26abf

File tree

6 files changed

+146
-39
lines changed

6 files changed

+146
-39
lines changed

README.md

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,7 @@
22

33
[![PyPI version](https://badge.fury.io/py/ont-bonito.svg)](https://badge.fury.io/py/ont-bonito)
44

5-
A convolutional basecaller inspired by QuartzNet.
6-
7-
## Features
8-
9-
- Raw signal input.
10-
- Simple 5 state output `{BLANK, A, C, G, T}`.
11-
- CTC training.
12-
- Small Python codebase.
13-
14-
## Basecalling
5+
A PyTorch Basecaller for Oxford Nanopore Reads.
156

167
```bash
178
$ pip install ont-bonito
@@ -23,27 +14,31 @@ If a reference is provided in either `.fasta` or `.mmi` format then bonito will
2314
```bash
2415
$ bonito basecaller dna_r9.4.1 --reference reference.mmi /data/reads > basecalls.sam
2516
```
26-
27-
If you have a `turing` or `volta` GPU the `--half` flag can be uses to increase performance.
2817

2918
## Pair Decoding
3019

3120
Pair decoding takes a template and complement read to produce higher quaility calls.
3221

33-
```
22+
```bash
3423
$ bonito basecaller pairs.csv /data/reads > basecalls.fasta
3524
```
3625

3726
The `pairs.csv` file is expected to contain pairs of read ids per line *(seperated by a single space)*.
3827

39-
4028
## Training your own model
4129

42-
To train your own model first download the training data.
30+
To train a model using your own reads, first basecall the reads with the additional `--save-ctc` flag and use the output directory as the input directory for training.
31+
32+
```bash
33+
$ bonito basecaller dna_r9.4.1 --save-ctc --reference reference.mmi /data/reads > /data/training/ctc-data/basecalls.sam
34+
$ bonito train --amp --directory /data/training/ctc-data /data/training/model-dir
35+
```
36+
37+
If you are interested in method development and don't have you own set of reads then a pre-prepared set is provide.
4338

4439
```bash
4540
$ bonito download --training
46-
$ bonito train --amp /data/model-dir
41+
$ bonito train --amp /data/training/model-dir
4742
```
4843

4944
Automatic mixed precision can be used to speed up training with the `--amp` flag *(however [apex](https://github.com/nvidia/apex#quick-start) needs to be installed manually)*.
@@ -55,7 +50,7 @@ $ export CUDA_VISIBLE_DEVICES=0,1,2,3
5550
$ bonito train --amp --multi-gpu --batch 256 /data/model-dir
5651
```
5752

58-
To evaluate the pretrained model run `bonito evaluate dna_r9.4.1 --half`.
53+
To evaluate the pretrained model run `bonito evaluate dna_r9.4.1`.
5954

6055
For a model you have trainined yourself, replace `dna_r9.4.1` with the model directory.
6156

bonito/basecaller.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,52 @@
77
from datetime import timedelta
88
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
99

10-
from bonito.util import load_model, chunk, stitch
11-
from bonito.io import DecoderWriterPool, PreprocessReader
10+
from bonito.util import load_model, chunk, stitch, half_supported
11+
from bonito.io import DecoderWriterPool, PreprocessReader, CTCWriter
1212

1313
import torch
1414
import numpy as np
15+
from mappy import Aligner
1516

1617

1718
def main(args):
1819

20+
if args.save_ctc and not args.reference:
21+
sys.stderr.write("> a reference is needed to output ctc training data\n")
22+
exit(1)
23+
24+
if args.save_ctc:
25+
args.overlap = 900
26+
args.chunksize = 3600
27+
1928
sys.stderr.write("> loading model\n")
2029

2130
model = load_model(
2231
args.model_directory, args.device, weights=int(args.weights),
2332
half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
2433
)
2534

35+
if args.reference:
36+
sys.stderr.write("> loading reference\n")
37+
aligner = Aligner(args.reference, preset='ont-map')
38+
if not aligner:
39+
sys.stderr.write("> failed to load/build index\n")
40+
sys.exit(1)
41+
else:
42+
aligner = None
43+
2644
samples = 0
2745
num_reads = 0
2846
max_read_size = 4e6
2947
dtype = np.float16 if args.half else np.float32
48+
ctc_writer = CTCWriter(model, aligner)
3049
reader = PreprocessReader(args.reads_directory)
31-
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, reference=args.reference)
50+
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, aligner=aligner)
3251

3352
t0 = time.perf_counter()
3453
sys.stderr.write("> calling\n")
3554

36-
with writer, reader, torch.no_grad():
55+
with writer, ctc_writer, reader, torch.no_grad():
3756

3857
while True:
3958

@@ -51,10 +70,12 @@ def main(args):
5170
raw_data = torch.tensor(read.signal.astype(dtype))
5271
chunks = chunk(raw_data, args.chunksize, args.overlap)
5372

54-
posteriors = model(chunks.to(args.device)).cpu().numpy()
55-
posteriors = stitch(posteriors, args.overlap // model.stride // 2)
73+
posteriors_ = model(chunks.to(args.device)).cpu().numpy()
74+
posteriors = stitch(posteriors_, args.overlap // model.stride // 2)
5675

5776
writer.queue.put((read, posteriors[:raw_data.shape[0]]))
77+
if args.save_ctc and len(raw_data) > args.chunksize:
78+
ctc_writer.queue.put((chunks.numpy(), posteriors_))
5879

5980
duration = time.perf_counter() - t0
6081

@@ -77,7 +98,8 @@ def argparser():
7798
parser.add_argument("--beamsize", default=5, type=int)
7899
parser.add_argument("--chunksize", default=0, type=int)
79100
parser.add_argument("--overlap", default=0, type=int)
80-
parser.add_argument("--half", action="store_true", default=False)
101+
parser.add_argument("--half", action="store_true", default=half_supported())
81102
parser.add_argument("--fastq", action="store_true", default=False)
82103
parser.add_argument("--cudart", action="store_true", default=False)
104+
parser.add_argument("--save-ctc", action="store_true", default=False)
83105
return parser

bonito/evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from bonito.training import ChunkDataSet
1212
from bonito.util import accuracy, poa, decode_ref
13-
from bonito.util import init, load_data, load_model
13+
from bonito.util import init, load_data, load_model, half_supported
1414

1515
from torch.utils.data import DataLoader
1616

@@ -80,7 +80,7 @@ def argparser():
8080
parser.add_argument("model_directory")
8181
parser.add_argument("--directory", default=None)
8282
parser.add_argument("--device", default="cuda")
83-
parser.add_argument("--half", action="store_true", default=False)
83+
parser.add_argument("--half", action="store_true", default=half_supported())
8484
parser.add_argument("--seed", default=9, type=int)
8585
parser.add_argument("--weights", default="0", type=str)
8686
parser.add_argument("--chunks", default=500, type=int)

bonito/io.py

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from glob import glob
88
from warnings import warn
99
from logging import getLogger
10-
from os.path import realpath, splitext
10+
from os.path import realpath, splitext, dirname
1111
from multiprocessing import Process, Queue, Lock, cpu_count
1212

1313
import numpy as np
1414
from tqdm import tqdm
15-
from mappy import Aligner, revcomp
15+
from mappy import revcomp
1616

1717
import bonito
18+
from bonito.training import ChunkDataSet
19+
from bonito.convert import filter_chunks
1820
from bonito.util import get_raw_data, mean_qscore_from_qstring
1921

2022

@@ -214,25 +216,105 @@ def stop(self):
214216
self.join()
215217

216218

219+
class CTCWriter(Process):
220+
"""
221+
CTC writer process that writes output numpy training data
222+
"""
223+
def __init__(self, model, aligner, min_coverage=0.90, min_accuracy=0.90):
224+
super().__init__()
225+
self.model = model
226+
self.queue = Queue()
227+
self.aligner = aligner
228+
self.min_coverage = min_coverage
229+
self.min_accuracy = min_accuracy
230+
231+
def __enter__(self):
232+
self.start()
233+
return self
234+
235+
def __exit__(self, exc_type, exc_val, exc_tb):
236+
self.queue.put(None)
237+
self.stop()
238+
239+
def run(self):
240+
241+
chunks = []
242+
targets = []
243+
target_lens = []
244+
245+
while True:
246+
247+
job = self.queue.get()
248+
if job is None: break
249+
chunks_, predictions = job
250+
251+
# convert logprobs to probs
252+
predictions = np.exp(predictions.astype(np.float32))
253+
254+
for chunk, pred in zip(chunks_, predictions):
255+
256+
sequence = self.model.decode(pred)
257+
258+
if not sequence:
259+
continue
260+
261+
for mapping in self.aligner.map(sequence):
262+
cov = (mapping.q_en - mapping.q_st) / len(sequence)
263+
acc = mapping.mlen / mapping.blen
264+
refseq = self.aligner.seq(mapping.ctg, mapping.r_st + 1, mapping.r_en)
265+
if 'N' in refseq: continue
266+
if mapping.strand == -1: refseq = revcomp(refseq)
267+
break
268+
else:
269+
continue
270+
271+
if acc > self.min_accuracy and cov > self.min_accuracy:
272+
chunks.append(chunk.squeeze())
273+
targets.append([
274+
int(x) for x in refseq.translate({65: '1', 67: '2', 71: '3', 84: '4'})
275+
])
276+
target_lens.append(len(refseq))
277+
278+
if len(chunks) == 0: return
279+
280+
chunks = np.array(chunks, dtype=np.float32)
281+
chunk_lens = np.full(chunks.shape[0], chunks.shape[1], dtype=np.int16)
282+
283+
targets_ = np.zeros((chunks.shape[0], max(target_lens)), dtype=np.uint8)
284+
for idx, target in enumerate(targets): targets_[idx, :len(target)] = target
285+
target_lens = np.array(target_lens, dtype=np.uint16)
286+
287+
training = ChunkDataSet(chunks, chunk_lens, targets_, target_lens)
288+
training = filter_chunks(training)
289+
290+
output_directory = '.' if sys.stdout.isatty() else dirname(realpath('/dev/fd/1'))
291+
np.save(os.path.join(output_directory, "chunks.npy"), training.chunks.squeeze(1))
292+
np.save(os.path.join(output_directory, "chunk_lengths.npy"), training.chunk_lengths)
293+
np.save(os.path.join(output_directory, "references.npy"), training.targets)
294+
np.save(os.path.join(output_directory, "reference_lengths.npy"), training.target_lengths)
295+
296+
sys.stderr.write("> written ctc training data\n")
297+
sys.stderr.write(" - chunks.npy with shape (%s)\n" % ','.join(map(str, training.chunks.squeeze(1).shape)))
298+
sys.stderr.write(" - chunk_lengths.npy with shape (%s)\n" % ','.join(map(str, training.chunk_lengths.shape)))
299+
sys.stderr.write(" - references.npy with shape (%s)\n" % ','.join(map(str, training.targets.shape)))
300+
sys.stderr.write(" - reference_lengths.npy shape (%s)\n" % ','.join(map(str, training.target_lengths.shape)))
301+
302+
def stop(self):
303+
self.join()
304+
305+
217306
class DecoderWriterPool:
218307
"""
219308
Simple pool of decoder writers
220309
"""
221-
def __init__(self, model, procs=4, reference=None, **kwargs):
310+
def __init__(self, model, procs=4, aligner=None, **kwargs):
222311
self.lock = Lock()
223312
self.queue = Queue()
224313
self.procs = procs if procs else cpu_count()
314+
self.aligner = aligner
225315
self.decoders = []
226316

227-
if reference:
228-
sys.stderr.write("> loading reference\n")
229-
aligner = Aligner(reference, preset='ont-map')
230-
if not aligner:
231-
sys.stderr.write("> failed to load/build index\n")
232-
sys.exit(1)
233-
write_sam_header(aligner)
234-
else:
235-
aligner = None
317+
if aligner: write_sam_header(aligner)
236318

237319
with open(summary_file(), 'w') as summary:
238320
write_summary_header(summary, alignment=aligner)

bonito/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def argparser():
111111
parser.add_argument("--device", default="cuda")
112112
parser.add_argument("--lr", default=1e-3, type=float)
113113
parser.add_argument("--seed", default=25, type=int)
114-
parser.add_argument("--epochs", default=400, type=int)
114+
parser.add_argument("--epochs", default=20, type=int)
115115
parser.add_argument("--batch", default=32, type=int)
116116
parser.add_argument("--chunks", default=2000000, type=int)
117117
parser.add_argument("--validation_split", default=0.97, type=float)

bonito/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import parasail
1717
import numpy as np
1818
from scipy.signal import find_peaks
19+
from torch.cuda import get_device_capability
1920
from ont_fast5_api.fast5_interface import get_fast5_file
2021

2122
try:
@@ -81,6 +82,13 @@ def init(seed, device):
8182
assert(torch.cuda.is_available())
8283

8384

85+
def half_supported():
86+
"""
87+
Returns whether FP16 is support on the GPU
88+
"""
89+
return get_device_capability()[0] >= 7
90+
91+
8492
def phred(prob, scale=1.0, bias=0.0):
8593
"""
8694
Converts `prob` into a ascii encoded phred quality score between 0 and 40.

0 commit comments

Comments
 (0)