Single head attention, decoupled LR, autoregressive auxiliary loss, and gradient accumulation #191

Open · wants to merge 63 commits into master

Changes from 19 commits

Commits (63)
d73a38d
add single-head attention https://arxiv.org/abs/1911.11423
lucidrains Sep 3, 2021
52aa3d5
add all stability related logic from public SHA code
lucidrains Sep 10, 2021
e0ec562
allow for removal of ff, given conclusions of https://arxiv.org/abs/2…
lucidrains Sep 10, 2021
59096ba
update to faithful SHA with feedforward instead of Boom
lucidrains Sep 16, 2021
c31866f
bidirectional gives best results
lucidrains Sep 18, 2021
e8d4d2c
cleanup
lucidrains Sep 18, 2021
47dfbae
Merge branch 'nanoporetech:master' into master
lucidrains Oct 7, 2021
45f5205
add SHA with differential learning rates
lucidrains Oct 7, 2021
7bb19db
lint
lucidrains Oct 7, 2021
f75cc14
import SHABlock in trainer file
lucidrains Oct 7, 2021
b2ab746
lower SHA learning rate to 1e-4
lucidrains Oct 8, 2021
f762f7b
add head scale from normformer paper
lucidrains Oct 8, 2021
e891d92
modify single_head_layers hparams to indicate the layer number after …
lucidrains Oct 8, 2021
57d281a
add sandwich norm to SHA
lucidrains Oct 8, 2021
36c903a
add ability to adjust grad clip max norm, in light of attention exper…
lucidrains Oct 9, 2021
c03785a
make feedforward sandwich norm as well, for extra stability
lucidrains Oct 14, 2021
79a9336
first commit for decoder autoregressive auxiliary loss
lucidrains Oct 21, 2021
4f8d6c2
fix alibi
lucidrains Oct 21, 2021
9a60831
decouple gradient clipping of attention parameters from non-attention…
lucidrains Oct 22, 2021
4b59e68
add ability to do gradient accumulation, with effective batch size be…
lucidrains Oct 22, 2021
a3be656
clean up
lucidrains Oct 22, 2021
92d6f10
Decoder module parameters must be designated as attention parameters …
lucidrains Oct 22, 2021
73a44ec
make default AR loss weight a bit higher
lucidrains Oct 22, 2021
b6cff9d
handle losses being a dictionary already
lucidrains Oct 22, 2021
07f3e7e
make sure falling back to not using an AR decoder actually works
lucidrains Oct 23, 2021
decfd2b
bug fix for grad accumulation
lucidrains Oct 24, 2021
f18f5d2
fix all issues
lucidrains Oct 25, 2021
78cd4ad
move alibi to rotary positional embeddings
lucidrains Oct 28, 2021
ce1fa04
fix error with casting to float32
lucidrains Oct 28, 2021
e4a7df9
make sure decoder AR gets 0.1 dropout loss for attention and feedforward
lucidrains Oct 28, 2021
ab87687
make sure relu squared from primer paper is used (accidentally had th…
lucidrains Oct 28, 2021
2b888c3
fix rotary positional embedding
lucidrains Oct 28, 2021
39a4825
address problem with absolute positional embedding and rotary embeddi…
lucidrains Oct 28, 2021
e28fac3
cleanup
lucidrains Oct 29, 2021
8c8a64e
Merge branch 'master' into sha-attn
iiSeymour Nov 3, 2021
0f2fc65
add ability to specify more aggressive gradient clipping for attentio…
lucidrains Nov 5, 2021
2626685
make sure attention uses stable softmax
lucidrains Nov 5, 2021
dba58f2
use --attn-clip instead
lucidrains Nov 5, 2021
771f274
one more stability measure for final layernorm in decoder
lucidrains Nov 6, 2021
9c7111f
add yet another stability measure, from cogview paper
lucidrains Nov 6, 2021
d34dd2f
cross entropy for auxiliary decoder loss should be done in float32
lucidrains Nov 8, 2021
1b70e1b
use amp.autocast to disable mixed precision for cross entropy calc
lucidrains Nov 8, 2021
b6ba887
make sure gradients do not go through numerical stability measures
lucidrains Nov 11, 2021
5c6e200
remove head scaling
lucidrains Nov 11, 2021
9ab7fc4
add pb relax stable softmax technique from cogview paper
lucidrains Nov 11, 2021
b2c0447
remove layerscale
lucidrains Nov 11, 2021
2a2205b
add ability to turn off AR auxiliary loss at a certain epoch, or with…
lucidrains Nov 12, 2021
5ac0bac
use ff-geglu over relu squared for now
lucidrains Nov 12, 2021
498c79d
use stable layernorm from cogview paper for norming the encoder embed…
lucidrains Nov 12, 2021
79d803b
fix bug with stable softmax
lucidrains Nov 13, 2021
0ae286c
add ability to have decoder attend to all encoder layers by means of …
lucidrains Nov 19, 2021
ad2828b
add ability to turn on scaled cosine sim attention
lucidrains Nov 20, 2021
ee5c880
fix bug
lucidrains Nov 20, 2021
48135c6
better init for cosine sim attention learned temp
lucidrains Nov 20, 2021
4231882
prepare for fitting in induced set attention block
lucidrains Nov 25, 2021
0ff4aa1
make learned initial temperature for cosine sim attention customizable
lucidrains Nov 25, 2021
68a6f0d
add induced-set attention blocks, which can be turned on with use_isa…
lucidrains Nov 25, 2021
5cf5be0
ISAB block needs to be included as a module containing attention para…
lucidrains Nov 26, 2021
1520e73
add weight tying feature across transformer blocks, and also set ISAB…
lucidrains Nov 26, 2021
324c9d5
make sure attention head dimension is configurable through toml
lucidrains Nov 26, 2021
e1f7a7c
add comments and docs
lucidrains Nov 26, 2021
0a3cbcd
docstrings for SHA and MHA
lucidrains Nov 26, 2021
541f0c3
set some guardrails
lucidrains Nov 26, 2021
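Many of the commits above (single-head attention, sandwich norm, head scale, stable softmax) touch the SHABlock itself, which lives in bonito/nn.py and is not among the hunks shown below. As orientation only, here is a minimal sketch of a single-head self-attention block with sandwich norm; the class name, tensor layout and details are assumptions for illustration and do not reproduce the PR's actual SHABlock.

import torch
from torch import nn

class SingleHeadAttention(nn.Module):
    # Illustrative sketch only: a single attention head with "sandwich" norm
    # (LayerNorm before the attention and again before the residual add).
    # This is not the SHABlock implemented in this PR.
    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.post_norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.scale = dim ** -0.5

    def forward(self, x):
        # x: (seq, batch, dim), matching the layout produced by Permute([2, 0, 1])
        residual = x
        x = self.pre_norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        sim = torch.einsum('ibd,jbd->bij', q, k) * self.scale
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  # stable softmax; detach so no gradient flows through the max
        attn = self.dropout(sim.softmax(dim=-1))
        out = torch.einsum('bij,jbd->ibd', attn, v)
        return residual + self.post_norm(out)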
bonito/cli/train.py (6 changes: 4 additions & 2 deletions)

@@ -72,8 +72,8 @@ def main(args):
         model.decode = model.module.decode
         model.alphabet = model.module.alphabet

-    trainer = Trainer(model, device, train_loader, valid_loader, use_amp=half_supported() and not args.no_amp)
-    trainer.fit(workdir, args.epochs, args.lr, last_epoch=last_epoch)
+    trainer = Trainer(model, device, train_loader, valid_loader, grad_clip_max_norm=args.clip, use_amp=half_supported() and not args.no_amp)
+    trainer.fit(workdir, args.epochs, args.lr, last_epoch=last_epoch, sha_lr=args.sha_lr)

 def argparser():
     parser = ArgumentParser(

@@ -87,6 +87,8 @@ def argparser():
     parser.add_argument("--directory", default=default_data)
     parser.add_argument("--device", default="cuda")
     parser.add_argument("--lr", default=2e-3, type=float)
+    parser.add_argument("--sha-lr", default=1e-4, type=float)
+    parser.add_argument("--clip", default=2., type=float)
     parser.add_argument("--seed", default=25, type=int)
     parser.add_argument("--epochs", default=5, type=int)
     parser.add_argument("--batch", default=64, type=int)
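The new --sha-lr and --clip flags drive the decoupled learning rate and gradient clipping mentioned in the commit history; the Trainer changes that consume them are not part of the hunks shown here. The following is a minimal sketch of the general pattern, assuming attention parameters are split into their own optimizer group; build_optimizer, clip_gradients and is_attention_param are hypothetical names, and AdamW is a stand-in for whichever optimizer the Trainer actually builds.

import torch

def build_optimizer(model, lr, sha_lr, is_attention_param):
    # Illustrative only: give the attention (SHA) parameters their own group
    # so they can train at a lower learning rate (e.g. --sha-lr 1e-4 vs --lr 2e-3).
    attn_params = [p for n, p in model.named_parameters() if is_attention_param(n)]
    other_params = [p for n, p in model.named_parameters() if not is_attention_param(n)]
    return torch.optim.AdamW([
        {'params': other_params, 'lr': lr},
        {'params': attn_params, 'lr': sha_lr},
    ])

def clip_gradients(model, clip, attn_clip, is_attention_param):
    # Illustrative only: clip attention and non-attention gradients separately,
    # mirroring the "decouple gradient clipping" commits above.
    attn_params = [p for n, p in model.named_parameters() if is_attention_param(n)]
    other_params = [p for n, p in model.named_parameters() if not is_attention_param(n)]
    torch.nn.utils.clip_grad_norm_(other_params, max_norm=clip)
    torch.nn.utils.clip_grad_norm_(attn_params, max_norm=attn_clip)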
bonito/crf/model.py (60 changes: 43 additions & 17 deletions)

@@ -4,12 +4,12 @@

 import torch
 import numpy as np
-from bonito.nn import Module, Convolution, LinearCRFEncoder, Serial, Permute, layers, from_dict
+from bonito.nn import Module, Convolution, SHABlock, LinearCRFEncoder, Serial, Permute, layers, Decoder, from_dict

 import seqdist.sparse
 from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
 from seqdist.core import SequenceDist, Max, Log, semiring

+from collections import Counter

 def get_stride(m):
     if hasattr(m, 'stride'):

@@ -139,30 +139,55 @@ def conv(c_in, c_out, ks, stride=1, bias=False, activation=None):
     return Convolution(c_in, c_out, ks, stride=stride, padding=ks//2, bias=bias, activation=activation)


-def rnn_encoder(n_base, state_len, insize=1, stride=5, winlen=19, activation='swish', rnn_type='lstm', features=768, scale=5.0, blank_score=None):
+def rnn_encoder(n_base, state_len, insize=1, stride=5, winlen=19, activation='swish', rnn_type='lstm', features=768, scale=5.0, blank_score=None, single_head_layers=[], num_attn_heads=1, attn_dropout=0., ff_dropout=0., sha_sandwich_norm=False):
     rnn = layers[rnn_type]
-    return Serial([
-        conv(insize, 4, ks=5, bias=True, activation=activation),
-        conv(4, 16, ks=5, bias=True, activation=activation),
-        conv(16, features, ks=winlen, stride=stride, bias=True, activation=activation),
-        Permute([2, 0, 1]),
-        rnn(features, features, reverse=True), rnn(features, features),
-        rnn(features, features, reverse=True), rnn(features, features),
-        rnn(features, features, reverse=True),
-        LinearCRFEncoder(features, n_base, state_len, bias=True, activation='tanh', scale=scale, blank_score=blank_score)
-    ])
+
+    rnns = [
+        rnn(features, features, reverse=True), rnn(features, features),
+        rnn(features, features, reverse=True), rnn(features, features),
+        rnn(features, features, reverse=True)
+    ]
+
+    backbone = []
+    single_head_layers_count = Counter(single_head_layers) # allows for multiple SHA blocks per layer
+
+    for layer, rnn in enumerate(rnns):
+        layer_num = layer + 1
+        backbone.append(rnn)
+
+        if layer_num in single_head_layers_count:
+            backbone.extend([SHABlock(features, attn_dropout=attn_dropout, ff_dropout=ff_dropout, num_attn_heads=num_attn_heads, sha_sandwich_norm=sha_sandwich_norm) for _ in range(single_head_layers_count[layer_num])])
+
+    encoder = Serial([
+        conv(insize, 4, ks=5, bias=True, activation=activation),
+        conv(4, 16, ks=5, bias=True, activation=activation),
+        conv(16, features, ks=winlen, stride=stride, bias=True, activation=activation),
+        Permute([2, 0, 1]),
+        *backbone
+    ])
+
+    linear_crf = LinearCRFEncoder(features, n_base, state_len, bias=True, activation='tanh', scale=scale, blank_score=blank_score)
+    return encoder, linear_crf

 class SeqdistModel(Module):
-    def __init__(self, encoder, seqdist):
+    def __init__(self, encoder, linear_crf, decoder, seqdist):
         super().__init__()
         self.seqdist = seqdist
         self.encoder = encoder
+        self.decoder = decoder
+        self.linear_crf = linear_crf
         self.stride = get_stride(encoder)
         self.alphabet = seqdist.alphabet

-    def forward(self, x):
-        return self.encoder(x).to(torch.float32)
+    def forward(self, x, targets = None):
+        encoded = self.encoder(x)
+        scores = self.linear_crf(encoded.to(torch.float32))
+
+        if targets is not None:
+            aux_loss = self.decoder(targets, encoded, return_loss=True) if self.decoder is not None else 0
+            return scores, aux_loss
+
+        return scores

     def decode_batch(self, x):
         scores = self.seqdist.posteriors(x.to(torch.float32)) + 1e-8

@@ -183,6 +208,7 @@ def __init__(self, config):
         if 'type' in config['encoder']: #new-style config
             encoder = from_dict(config['encoder'])
         else: #old-style
-            encoder = rnn_encoder(seqdist.n_base, seqdist.state_len, insize=config['input']['features'], **config['encoder'])
-        super().__init__(encoder, seqdist)
+            encoder, linear_crf = rnn_encoder(seqdist.n_base, seqdist.state_len, insize=config['input']['features'], **config['encoder'])
+        decoder = Decoder(config['encoder']['features'], **config['aux_decoder']) if config['aux_decoder']['loss_weight'] > 0 else None
+        super().__init__(encoder, linear_crf, decoder, seqdist)
         self.config = config
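With forward() now returning both the CRF scores and an auxiliary autoregressive loss, the training loop can weight the two losses and, per the commit history, accumulate gradients over several sub-batches. A simplified sketch of that pattern follows, with ctc_loss, loss_weight and accum_steps as stand-in names rather than the Trainer's actual interface.

def training_step(model, ctc_loss, optimizer, sub_batches, loss_weight=0.25, accum_steps=1):
    # Illustrative only: combine the main CRF loss with the weighted auxiliary
    # AR decoder loss, accumulating gradients over accum_steps sub-batches so the
    # effective batch size is accum_steps times the sub-batch size.
    optimizer.zero_grad()
    for data, targets, lengths in sub_batches[:accum_steps]:
        scores, aux_loss = model(data, targets)         # SeqdistModel.forward with targets
        loss = ctc_loss(scores, targets, lengths) + loss_weight * aux_loss
        (loss / accum_steps).backward()                 # average gradients over the sub-batches
    optimizer.step()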
bonito/models/configs/[email protected] (12 changes: 12 additions & 0 deletions)

@@ -22,5 +22,17 @@ rnn_type = "lstm"
 activation = "swish"
 blank_score = 2.0

+single_head_layers = [ 3, 4 ]
+attn_dropout = 0.1
+ff_dropout = 0.1
+num_attn_heads = 1
+sha_sandwich_norm = true
+
+[aux_decoder]
+loss_weight = 0.25
Author comment on loss_weight: set this to 0 to turn off the auxiliary AR loss.

Author comment on loss_weight: the protocol should be to start at 0.25 and search for higher values, up to 1.0, if you see continued improvement.

+depth = 2
+heads = 4
+max_seq_len = 1024

 [global_norm]
 state_len = 5
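For reference, single_head_layers lists the RNN layers after which SHA blocks are appended, and repeating a layer number stacks more than one block there (rnn_encoder counts the entries with a Counter). A self-contained toy sketch of that interleaving, using placeholder strings instead of the real modules:

from collections import Counter

# Placeholder strings stand in for the real LSTM and SHABlock modules.
rnns = ['lstm1(rev)', 'lstm2', 'lstm3(rev)', 'lstm4', 'lstm5(rev)']
single_head_layers = [3, 4]                    # as in the config above

counts = Counter(single_head_layers)           # repeating a layer number stacks more SHA blocks
backbone = []
for layer_num, rnn in enumerate(rnns, start=1):
    backbone.append(rnn)
    backbone.extend(['SHABlock'] * counts[layer_num])

print(backbone)
# ['lstm1(rev)', 'lstm2', 'lstm3(rev)', 'SHABlock', 'lstm4', 'SHABlock', 'lstm5(rev)']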