# rnn_classifier.py
import sys
import torch
from torch import nn
from holder import *
from util import *
from locked_dropout import *
from bert_loader import *
from bert_encoder import *
class RnnClassifier(torch.nn.Module):
    def __init__(self, opt, shared):
        super(RnnClassifier, self).__init__()
        self.opt = opt
        self.shared = shared
        bidir = True
        # With a bidirectional RNN, give each direction half the hidden size so the
        # concatenated output still has opt.hidden_size features.
        hidden_state = opt.hidden_size if not bidir else opt.hidden_size // 2
        self.rnn = build_rnn(
            opt.rnn_type,
            input_size=opt.bert_size,
            hidden_size=hidden_state,
            num_layers=1,
            bias=True,
            batch_first=True,
            dropout=opt.dropout,
            bidirectional=bidir)
        self.drop = LockedDropout(opt.dropout)
        self.linear = nn.Sequential(
            # nn.Dropout(opt.dropout),  # no dropout here according to huggingface
            nn.Linear(opt.hidden_size + opt.bert_size, 2))  # 1 for start, 1 for end
    def rnn_over(self, rnn, enc):
        # Apply locked (variational) dropout to the input, then run the RNN;
        # only the per-step outputs are needed, so the final hidden state is dropped.
        enc, _ = rnn(self.drop(enc))
        return enc

    def fp32(self, x):
        # Cast half-precision (or other) inputs up to float32 before the RNN.
        if x.dtype != torch.float32:
            return x.float()
        return x

    def forward(self, concated):
        # concated: (batch_l, concated_l, bert_size) BERT encodings of the concatenated input
        batch_l, concated_l, bert_size = concated.shape

        concated = self.fp32(concated)
        rnn_enc = self.rnn_over(self.rnn, concated).contiguous()
        # Concatenate the RNN outputs with the original BERT features, then score
        # every position with two logits (start and end).
        phi = torch.cat([rnn_enc, concated], 2)
        scores = self.linear(phi.view(-1, bert_size + self.opt.hidden_size)).view(batch_l, concated_l, 2)

        # Normalize over the sequence dimension: each channel is a distribution over positions.
        log_p = nn.LogSoftmax(1)(scores)
        log_p1 = log_p[:, :, 0]  # log P(start position)
        log_p2 = log_p[:, :, 1]  # log P(end position)

        self.shared.y_scores = scores
        return [log_p1, log_p2]

    def begin_pass(self):
        pass

    def end_pass(self):
        pass
if __name__ == '__main__':
    pass
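
    # A minimal smoke test (a sketch, not part of the original file). It assumes
    # (a) build_rnn() from util accepts 'lstm' as opt.rnn_type, and (b) a plain
    # attribute container is enough for `opt` and `shared`; the real training code
    # presumably builds these from its own option parser and holder objects.
    import types
    opt = types.SimpleNamespace(rnn_type='lstm', bert_size=768, hidden_size=200, dropout=0.0)
    shared = types.SimpleNamespace()
    model = RnnClassifier(opt, shared)

    # Fake BERT output: (batch_l=2, concated_l=16, bert_size=768)
    dummy = torch.randn(2, 16, opt.bert_size)
    log_p1, log_p2 = model(dummy)
    print(log_p1.shape, log_p2.shape)  # expected: torch.Size([2, 16]) twice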