Skip to content

Commit

Permalink
Merge branch 'v4' into 'master'
Browse files Browse the repository at this point in the history
v4 Model Changes

See merge request machine-learning/bonito!115
  • Loading branch information
iiSeymour committed Nov 8, 2022
2 parents dd57409 + da7fe39 commit 804baf0
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 39 deletions.
46 changes: 24 additions & 22 deletions bonito/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,6 @@ def main(args):
init(args.seed, args.device, (not args.nondeterministic))
device = torch.device(args.device)

print("[loading data]")
try:
train_loader_kwargs, valid_loader_kwargs = load_numpy(
args.chunks, args.directory
)
except FileNotFoundError:
train_loader_kwargs, valid_loader_kwargs = load_script(
args.directory,
seed=args.seed,
chunks=args.chunks,
valid_chunks=args.valid_chunks
)

loader_kwargs = {
"batch_size": args.batch, "num_workers": 4, "pin_memory": True
}
train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

if not args.pretrained:
config = toml.load(args.config)
else:
Expand All @@ -65,16 +46,37 @@ def main(args):

argsdict = dict(training=vars(args))

os.makedirs(workdir, exist_ok=True)
toml.dump({**config, **argsdict}, open(os.path.join(workdir, 'config.toml'), 'w'))

print("[loading model]")
if args.pretrained:
print("[using pretrained model {}]".format(args.pretrained))
model = load_model(args.pretrained, device, half=False)
else:
model = load_symbol(config, 'Model')(config)

print("[loading data]")
try:
train_loader_kwargs, valid_loader_kwargs = load_numpy(
args.chunks, args.directory
)
except FileNotFoundError:
train_loader_kwargs, valid_loader_kwargs = load_script(
args.directory,
seed=args.seed,
chunks=args.chunks,
valid_chunks=args.valid_chunks,
n_pre_context_bases=getattr(model, "n_pre_context_bases", 0),
n_post_context_bases=getattr(model, "n_post_context_bases", 0),
)

loader_kwargs = {
"batch_size": args.batch, "num_workers": 4, "pin_memory": True
}
train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

os.makedirs(workdir, exist_ok=True)
toml.dump({**config, **argsdict}, open(os.path.join(workdir, 'config.toml'), 'w'))

if config.get("lr_scheduler"):
sched_config = config["lr_scheduler"]
lr_scheduler_fn = getattr(
Expand Down
41 changes: 31 additions & 10 deletions bonito/crf/model.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
"""
Bonito CTC-CRF Model.
"""
import math

import torch
import numpy as np

import koi
from koi.ctc import SequenceDist, Max, Log, semiring
from koi.ctc import logZ_cu, viterbi_alignments, logZ_cu_sparse, bwd_scores_cu_sparse, fwd_scores_cu_sparse

from bonito.nn import Module, Convolution, LinearCRFEncoder, Serial, Permute, layers, from_dict


def get_stride(m):
    """
    Return the total stride of module `m`.

    A module without children contributes its own `stride` attribute: an int
    is returned as-is, any other value is treated as an iterable of strides
    and multiplied together. A childless module with no `stride` attribute
    contributes 1. A module with children contributes the product of its
    children's strides; its own `stride` attribute, if any, is not consulted.
    """
    children = list(m.children())
    if not children:
        # Leaf module: fall back to a stride of 1 when none is declared.
        stride = getattr(m, "stride", 1)
        if isinstance(stride, int):
            return stride
        return math.prod(stride)
    # Container module: strides of nested modules compose multiplicatively.
    return math.prod(get_stride(c) for c in children)


class CTC_CRF(SequenceDist):

def __init__(self, state_len, alphabet):
def __init__(self, state_len, alphabet, n_pre_context_bases=0, n_post_context_bases=0):
super().__init__()
self.alphabet = alphabet
self.state_len = state_len
self.n_pre_context_bases = n_pre_context_bases
self.n_post_context_bases = n_post_context_bases
self.n_base = len(alphabet[1:])
self.idx = torch.cat([
torch.arange(self.n_base**(self.state_len))[:, None],
Expand Down Expand Up @@ -157,12 +163,17 @@ def rnn_encoder(n_base, state_len, insize=1, stride=5, winlen=19, activation='sw


class SeqdistModel(Module):
def __init__(self, encoder, seqdist):
# Build the model from an encoder and a sequence distribution.
# n_pre_post_context_bases: optional pair (n_pre, n_post). When it is None the
# defaults are n_pre = seqdist.state_len - 1 and n_post = 1.
def __init__(self, encoder, seqdist, n_pre_post_context_bases=None):
super().__init__()
self.seqdist = seqdist
self.encoder = encoder
# Total stride of the encoder as computed by get_stride.
self.stride = get_stride(encoder)
self.alphabet = seqdist.alphabet
if n_pre_post_context_bases is None:
self.n_pre_context_bases = self.seqdist.state_len - 1
self.n_post_context_bases = 1
else:
# Unpacks exactly two values; any other length raises ValueError.
self.n_pre_context_bases, self.n_post_context_bases = n_pre_post_context_bases

# Forward pass: return the encoder's output for x.
def forward(self, x):
return self.encoder(x)
Expand All @@ -178,6 +189,15 @@ def decode(self, x):
# Compute the loss through seqdist.ctc_loss. Scores are cast to float32 first;
# any extra keyword arguments are forwarded unchanged.
def loss(self, scores, targets, target_lengths, **kwargs):
return self.seqdist.ctc_loss(scores.to(torch.float32), targets, target_lengths, **kwargs)

# Replace self.encoder with the result of koi.lstm.update_graph.
# Required keyword arguments: "batchsize", "chunksize" and "quantize"; a missing
# key raises KeyError.
# NOTE(review): chunksize is floor-divided by self.stride here, but the caller
# shown in bonito/util.py (_load_model) already passes
# config["basecaller"]["chunksize"] // model.stride. It looks like the stride
# is applied twice; confirm which side should do the division.
def use_koi(self, **kwargs):
self.encoder = koi.lstm.update_graph(
self.encoder,
batchsize=kwargs["batchsize"],
chunksize=kwargs["chunksize"] // self.stride,
quantize=kwargs["quantize"],
)


class Model(SeqdistModel):

def __init__(self, config):
Expand All @@ -189,5 +209,6 @@ def __init__(self, config):
encoder = from_dict(config['encoder'])
else: #old-style
encoder = rnn_encoder(seqdist.n_base, seqdist.state_len, insize=config['input']['features'], **config['encoder'])
super().__init__(encoder, seqdist)

super().__init__(encoder, seqdist, n_pre_post_context_bases=config['input'].get('n_pre_post_context_bases'))
self.config = config
2 changes: 1 addition & 1 deletion bonito/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def load_script(directory, name="dataset", suffix=".py", **kwargs):
spec = importlib.util.spec_from_file_location(name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
loader = module.Loader()
loader = module.Loader(**kwargs)
return loader.train_loader_kwargs(**kwargs), loader.valid_loader_kwargs(**kwargs)


Expand Down
30 changes: 29 additions & 1 deletion bonito/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,34 @@ def register(layer):
register(torch.nn.Tanh)


@register
class Linear(Module):
    """
    Registered layer wrapping `torch.nn.Linear`.

    The constructor arguments are kept on the instance so that `to_dict` can
    report them; `bias` is the boolean flag passed to `torch.nn.Linear`.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.linear = torch.nn.Linear(
            in_features=in_features, out_features=out_features, bias=bias
        )

    def forward(self, x):
        """Apply the wrapped linear layer to `x`."""
        return self.linear(x)

    def to_dict(self, include_weights=False):
        """
        Return the layer's constructor arguments as a dict.

        When `include_weights` is true, a `params` entry is added holding the
        weight under `W` and the bias under `b`; `b` is an empty list when the
        layer has no bias parameter.
        """
        res = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }
        if include_weights:
            # self.bias is a bool flag and is never None, so check the actual
            # parameter: torch.nn.Linear sets it to None when bias=False.
            linear_bias = self.linear.bias
            res['params'] = {
                'W': self.linear.weight,
                'b': linear_bias if linear_bias is not None else []
            }
        return res


@register
class Swish(torch.nn.SiLU):
    """
    Swish activation: a registered subclass of `torch.nn.SiLU` that adds
    no behaviour of its own.
    """
Expand Down Expand Up @@ -245,7 +273,7 @@ def disable_state_bias(self):
class LSTM(RNNWrapper):

def __init__(self, size, insize, bias=True, reverse=False):
super().__init__(torch.nn.LSTM, size, insize, bias=bias, reverse=reverse)
super().__init__(torch.nn.LSTM, insize, size, bias=bias, reverse=reverse)

def to_dict(self, include_weights=False):
res = {
Expand Down
6 changes: 6 additions & 0 deletions bonito/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def validate_one_step(self, batch):
else:
seqs = [self.model.decode(x) for x in permute(scores, 'TNC', 'NTC')]
refs = [decode_ref(target, self.model.alphabet) for target in targets]

n_pre = getattr(self.model, "n_pre_context_bases", 0)
n_post = getattr(self.model, "n_post_context_bases", 0)
if n_pre > 0 or n_post > 0:
refs = [ref[n_pre:len(ref)-n_post] for ref in refs]

accs = [
accuracy(ref, seq, min_coverage=0.5) if len(seq) else 0. for ref, seq in zip(refs, seqs)
]
Expand Down
7 changes: 3 additions & 4 deletions bonito/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def set_config_defaults(config, chunksize=None, batchsize=None, overlap=None, qu
# use `value or dict.get(key)` rather than `dict.get(key, value)` to make
# flags override values in config
basecall_params["chunksize"] = chunksize or basecall_params.get("chunksize", 4000)
basecall_params["overlap"] = overlap or basecall_params.get("overlap", 500)
basecall_params["overlap"] = overlap if overlap is not None else basecall_params.get("overlap", 500)
basecall_params["batchsize"] = batchsize or basecall_params.get("batchsize", 64)
basecall_params["quantize"] = basecall_params.get("quantize") if quantize is None else quantize
config["basecaller"] = basecall_params
Expand Down Expand Up @@ -293,9 +293,8 @@ def _load_model(model_file, config, device, half=None, use_koi=False):
# overlap must be an even multiple of stride for correct stitching
config["basecaller"]["overlap"] -= config["basecaller"]["overlap"] % (model.stride * 2)

if config["model"]["package"] == "bonito.crf" and use_koi:
model.encoder = koi.lstm.update_graph(
model.encoder,
if use_koi:
model.use_koi(
batchsize=config["basecaller"]["batchsize"],
chunksize=config["basecaller"]["chunksize"] // model.stride,
quantize=config["basecaller"]["quantize"],
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pysam==0.19.1
parasail==1.2
pandas>1,<2
requests==2.25.1
ont-koi==0.1.1
ont-koi==0.1.4
onnxruntime==1.12.1
ont-remora==1.1.1
ont-fast5-api==3.3.0
Expand Down

0 comments on commit 804baf0

Please sign in to comment.