model.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
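

# Bidirectional LSTM/GRU sentence classifier: embeds token indices, runs a
# packed bidirectional RNN over the padded batch, and classifies from either
# the last valid time step or the mean over time.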

class RNN(nn.Module):
    def __init__(self, vocab_size, embed_size, num_output, rnn_model='LSTM', use_last=True, embedding_tensor=None,
                 padding_index=0, hidden_size=64, num_layers=1, batch_first=True):
        """
        Args:
            vocab_size: size of the vocabulary
            embed_size: dimension of the word embeddings
            num_output: number of output classes
            rnn_model: 'LSTM' or 'GRU'
            use_last: if True, classify from the last valid time step;
                otherwise mean-pool the RNN outputs over time
            embedding_tensor: optional pre-trained embedding weights; frozen when given
            padding_index: index of the padding token in the vocabulary
            hidden_size: hidden size of the RNN module
            num_layers: number of stacked RNN layers
            batch_first: batch-first option (the module operates batch-first internally)
        """
        super(RNN, self).__init__()
        self.use_last = use_last

        # embedding layer; pre-trained weights are frozen if provided
        if torch.is_tensor(embedding_tensor):
            self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index,
                                        _weight=embedding_tensor)
            self.encoder.weight.requires_grad = False
        else:
            self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index)

        self.drop_en = nn.Dropout(p=0.6)

        # bidirectional RNN; note that inter-layer dropout only takes effect
        # when num_layers > 1
        if rnn_model == 'LSTM':
            self.rnn = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers,
                               dropout=0.5, batch_first=True, bidirectional=True)
        elif rnn_model == 'GRU':
            self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers,
                              dropout=0.5, batch_first=True, bidirectional=True)
        else:
            raise ValueError("rnn_model must be 'LSTM' or 'GRU'")

        # batch norm and classifier over the concatenated forward/backward states
        self.bn2 = nn.BatchNorm1d(hidden_size * 2)
        self.fc = nn.Linear(hidden_size * 2, num_output)
    def forward(self, x, seq_lengths):
        """
        Args:
            x: LongTensor of token indices, shape (batch, seq_len)
            seq_lengths: true (unpadded) length of each sequence, sorted in
                decreasing order to satisfy pack_padded_sequence
        Returns:
            logits of shape (batch, num_output)
        """
        x_embed = self.encoder(x)
        x_embed = self.drop_en(x_embed)
        packed_input = pack_padded_sequence(x_embed, seq_lengths.cpu().numpy(), batch_first=True)

        # out_rnn shape: (batch, seq_len, 2 * hidden_size)
        # None stands for a zero initial hidden state
        packed_output, ht = self.rnn(packed_input, None)
        out_rnn, _ = pad_packed_sequence(packed_output, batch_first=True)

        if self.use_last:
            # output at the last valid time step of each sequence (for the
            # backward direction this step has seen only the final token)
            row_indices = torch.arange(x.size(0), device=x.device)
            col_indices = (seq_lengths - 1).to(x.device)
            last_tensor = out_rnn[row_indices, col_indices, :]
        else:
            # mean over the valid time steps only: padded positions are zero
            # after pad_packed_sequence, so summing and dividing by the true
            # lengths avoids biasing the mean toward zero
            last_tensor = out_rnn.sum(dim=1) / seq_lengths.unsqueeze(1).to(out_rnn)

        fc_input = self.bn2(last_tensor)
        out = self.fc(fc_input)
        return out
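

# ----------------------------------------------------------------------------
# A minimal usage sketch showing the expected input format. The hyperparameters
# and token values below are illustrative assumptions, not part of the model
# definition. The batch must be padded with padding_index and sorted by
# decreasing length, since pack_padded_sequence is called with its default
# enforce_sorted behavior.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)

    model = RNN(vocab_size=100, embed_size=32, num_output=4,
                rnn_model='LSTM', use_last=True, hidden_size=16)

    # two sequences of true lengths 5 and 3, padded to length 5 with index 0
    # and already sorted by decreasing length
    x = torch.tensor([[4, 8, 15, 16, 23],
                      [42, 7, 9, 0, 0]], dtype=torch.long)
    seq_lengths = torch.tensor([5, 3], dtype=torch.long)

    model.eval()  # disable dropout for a deterministic forward pass
    with torch.no_grad():
        logits = model(x, seq_lengths)
    print(logits.shape)  # expected: torch.Size([2, 4])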