Added contextualized representation to ner model #33

Open · wants to merge 1 commit into master
60 changes: 57 additions & 3 deletions model_partial_ner/dataset.py
@@ -112,11 +112,15 @@ class NERDataset(object):
"""
def __init__(self,
dataset: list,
flm_pad: int,
blm_pad: int,
w_pad: int,
c_pad: int,
token_per_batch: int):
super(NERDataset, self).__init__()
self.dataset = dataset
self.flm_pad = flm_pad
self.blm_pad = blm_pad
self.w_pad = w_pad
self.c_pad = c_pad
self.token_per_batch = token_per_batch
@@ -176,15 +180,39 @@ def reader(self, device):
batch_idx = self.shuffle_list[cur_idx]
batch = self.dataset[self.index_list[batch_idx]: self.index_list[batch_idx + 1]]
cur_seq_length = len(batch[0][0])
batch_size = len(batch)
batch_idx = range(batch_size)

flm_t, blm_t, blm_ind, lm_index = None, None, None, None
if batch[0][6] is not None:
flm_t = list()
blm_t = list()
blm_ind = list()
word_padded_len = max([len(tup[6]) for tup in batch])
for instance_ind in range(batch_size):
instance = batch[instance_ind]
word_padded_len_ins = word_padded_len - len(instance[6])
flm_t.append(instance[6] + [self.flm_pad] + [self.flm_pad] * word_padded_len_ins)
blm_t.append([self.blm_pad] + instance[7][::-1] + [self.blm_pad] * word_padded_len_ins)
tmp_p = list(range(len(instance[7]), -1, -1)) + list(range(len(instance[7])+1, word_padded_len+1))
blm_ind.append([x * batch_size + instance_ind for x in tmp_p])
flm_t = torch.LongTensor(flm_t).transpose(0, 1).contiguous().to(device)
blm_t = torch.LongTensor(blm_t).transpose(0, 1).contiguous().to(device)
blm_ind = torch.LongTensor(blm_ind).transpose(0, 1).contiguous().view(-1).to(device)
lm_index = torch.LongTensor([tup[8] + [word_padded_len] * (cur_seq_length - len(tup[0])) for tup in batch])
lm_index[lm_index==-1] = word_padded_len + 1
lm_index = lm_index.to(device)

word_t = torch.LongTensor([tup[0] + [self.w_pad] * (cur_seq_length - len(tup[0])) for tup in batch]).to(device)
char_t = torch.LongTensor([tup[1] + [self.c_pad] * (cur_seq_length - len(tup[0])) for tup in batch]).to(device)
chunk_mask = torch.ByteTensor([tup[2] + [0] * (cur_seq_length - len(tup[2])) for tup in batch]).to(device)
chunk_label = torch.FloatTensor([label for tup in batch for label in tup[3]]).to(device)
type_mask = torch.ByteTensor([mask for tup in batch for mask in tup[4]]).to(device)
label_list = [label for tup in batch for label in tup[5]]
type_label = torch.FloatTensor(label_list[0:-1]).to(device)

cur_idx += 1
yield word_t, char_t, chunk_mask, chunk_label, type_mask, type_label
yield word_t, char_t, chunk_mask, chunk_label, type_mask, type_label, flm_t, blm_t, blm_ind, lm_index
self.shuffle()

class TrainDataset(object):
@@ -206,6 +234,8 @@ class TrainDataset(object):
"""
def __init__(self,
dataset_name: str,
flm_pad: int,
blm_pad: int,
w_pad: int,
c_pad: int,
token_per_batch: int,
@@ -215,7 +245,8 @@ def __init__(self,
self.sample_ratio = sample_ratio

self.dataset_name = dataset_name

self.flm_pad = flm_pad
self.blm_pad = blm_pad
self.w_pad = w_pad
self.c_pad = c_pad
self.token_per_batch = token_per_batch
@@ -258,6 +289,28 @@ def reader(self, device):
batch = self.dataset[self.index_list[batch_idx]: self.index_list[batch_idx + 1]]

cur_seq_length = len(batch[0][0])
batch_size = len(batch)
batch_idx = range(batch_size)

flm_t, blm_t, blm_ind, lm_index = None, None, None, None
if batch[0][6] is not None:
flm_t = list()
blm_t = list()
blm_ind = list()
word_padded_len = max([len(tup[6]) for tup in batch])
for instance_ind in range(batch_size):
instance = batch[instance_ind]
word_padded_len_ins = word_padded_len - len(instance[6])
flm_t.append(instance[6] + [self.flm_pad] + [self.flm_pad] * word_padded_len_ins)
blm_t.append([self.blm_pad] + instance[7][::-1] + [self.blm_pad] * word_padded_len_ins)
tmp_p = list(range(len(instance[7]), -1, -1)) + list(range(len(instance[7])+1, word_padded_len+1))
blm_ind.append([x * batch_size + instance_ind for x in tmp_p])
flm_t = torch.LongTensor(flm_t).transpose(0, 1).contiguous().to(device)
blm_t = torch.LongTensor(blm_t).transpose(0, 1).contiguous().to(device)
blm_ind = torch.LongTensor(blm_ind).transpose(0, 1).contiguous().view(-1).to(device)
lm_index = torch.LongTensor([tup[8] + [word_padded_len] * (cur_seq_length - len(tup[0])) for tup in batch])
lm_index[lm_index==-1] = word_padded_len + 1
lm_index = lm_index.to(device)
word_t = torch.LongTensor([tup[0] + [self.w_pad] * (cur_seq_length - len(tup[0])) for tup in batch]).to(device)
char_t = torch.LongTensor([tup[1] + [self.c_pad] * (cur_seq_length - len(tup[0])) for tup in batch]).to(device)
chunk_mask = torch.ByteTensor([tup[2] + [0] * (cur_seq_length - len(tup[2])) for tup in batch]).to(device)
@@ -266,9 +319,10 @@ def reader(self, device):
label_list = [label for tup in batch for label in tup[5]]
type_label = torch.FloatTensor(label_list[0:-1]).to(device)


cur_idx += 1

yield word_t, char_t, chunk_mask, chunk_label, type_mask, type_label
yield word_t, char_t, chunk_mask, chunk_label, type_mask, type_label, flm_t, blm_t, blm_ind, lm_index

random.shuffle(self.shuffle_list)

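The new language-model batching above is duplicated in NERDataset.reader and TrainDataset.reader. Below is a minimal standalone sketch of that padding and index bookkeeping, using made-up word ids and pad ids (the real code draws forward and backward ids from separate flm/blm vocabularies), to show the shapes the model will receive:

import torch

flm_pad, blm_pad = 0, 0                      # assumed pad ids
batch = [[4, 9, 7], [5, 3]]                  # toy word-level LM ids, one list per sentence

batch_size = len(batch)
word_padded_len = max(len(sent) for sent in batch)

flm_t, blm_t, blm_ind = [], [], []
for i, sent in enumerate(batch):
    pad_len = word_padded_len - len(sent)
    flm_t.append(sent + [flm_pad] + [flm_pad] * pad_len)        # forward: sentence, end pad, padding
    blm_t.append([blm_pad] + sent[::-1] + [blm_pad] * pad_len)  # backward: start pad, reversed sentence, padding
    # offsets that put the reversed backward outputs back into forward word order,
    # expressed as flat indices into a tensor viewed as (seq_len * batch_size, dim)
    tmp_p = list(range(len(sent), -1, -1)) + list(range(len(sent) + 1, word_padded_len + 1))
    blm_ind.append([x * batch_size + i for x in tmp_p])

flm_t = torch.LongTensor(flm_t).transpose(0, 1).contiguous()    # (word_padded_len + 1, batch_size)
blm_t = torch.LongTensor(blm_t).transpose(0, 1).contiguous()    # same shape as flm_t
blm_ind = torch.LongTensor(blm_ind).transpose(0, 1).contiguous().view(-1)
print(flm_t.shape, blm_t.shape, blm_ind.shape)                  # [4, 2], [4, 2], [8]

Both readers then move these tensors to the target device and yield them after the original six fields.
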
43 changes: 39 additions & 4 deletions model_partial_ner/ner.py
@@ -16,6 +16,10 @@ class NER(nn.Module):

Parameters
----------
flm : ``torch.nn.Module``, optional.
The forward language model for contextualized representations.
blm : ``torch.nn.Module``, optional.
The backward language model for contextualized representations.
rnn : ``torch.nn.Module``, required.
The RNN unit.
w_num : ``int`` , required.
@@ -33,7 +37,7 @@ class NER(nn.Module):
droprate : ``float`` , required
The dropout ratio.
"""
def __init__(self, rnn,
def __init__(self, flm, blm, rnn,
w_num: int,
w_dim: int,
c_num: int,
@@ -44,6 +48,8 @@ def __init__(self, rnn,

super(NER, self).__init__()

self.flm = flm
self.blm = blm
self.rnn = rnn
self.rnn_outdim = self.rnn.output_dim
self.one_direction_dim = self.rnn_outdim // 2
@@ -109,12 +115,20 @@ def rand_ini(self):
utils.init_linear(self.to_chunk_proj)
utils.init_linear(self.to_type_proj)

def forward(self, w_in, c_in, mask):
def forward(self, flm_in, blm_in, blm_ind, lm_idx, w_in, c_in, mask):
"""
Sequence labeling model.

Parameters
----------
flm_in : ``torch.LongTensor``, optional.
Forward contextualized language model input.
blm_in : ``torch.LongTensor``, optional.
Backward contextualized language model input.
blm_ind : ``torch.LongTensor``, optional.
Backward contextualized language model index.
lm_idx : ``torch.LongTensor``, optional.
Contextualized language model index.
w_in : ``torch.LongTensor``, required.
The word-level input.
c_in : ``torch.LongTensor`` , required.
@@ -126,7 +140,28 @@ def forward(self, w_in, c_in, mask):

c_emb = self.char_embed(c_in)

emb = self.drop( torch.cat([w_emb, c_emb], 2) )
context_o = []
if self.flm is not None:
self.flm.init_hidden()
self.blm.init_hidden()
flm_o = self.flm(flm_in)
blm_o = self.blm(blm_in, blm_ind)

tmp_lm_pad = torch.zeros(1, flm_o.size()[1], flm_o.size()[2]).to(flm_o.device)
flm_o = torch.cat([flm_o, tmp_lm_pad], 0)
blm_o = torch.cat([blm_o, tmp_lm_pad], 0)

flm_o = flm_o.permute(1,0,2)
blm_o = blm_o.permute(1,0,2)

flm_o = flm_o[torch.arange(flm_o.shape[0]).unsqueeze(-1), lm_idx]
blm_o = blm_o[torch.arange(blm_o.shape[0]).unsqueeze(-1), lm_idx]

context_o.append(flm_o)
context_o.append(blm_o)


emb = self.drop( torch.cat([w_emb, c_emb] + context_o, 2) )

out = self.rnn(emb)

@@ -236,4 +271,4 @@ def to_typed_span(self, chunk_label, type_label, none_idx, id2label):

assert type_idx == len(type_label)

return set(span_list)
return set(span_list)
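
The key step added to NER.forward is the advanced-indexing lookup: every character position picks up the LM hidden state of the word it belongs to, while positions without a word (the -1 entries produced during preprocessing and remapped by the readers) point at an appended zero row. A small self-contained sketch of that lookup, with invented shapes and indices:

import torch

batch_size, word_len, char_len, lm_dim = 2, 4, 7, 3
lm_o = torch.randn(word_len, batch_size, lm_dim)            # LM output: (words, batch, dim)

# append a zero row so that "no word here" indices pick up a zero vector
pad = torch.zeros(1, batch_size, lm_dim)
lm_o = torch.cat([lm_o, pad], 0).permute(1, 0, 2)           # (batch, words + 1, dim)

# toy lm_idx: for every character position, the row of the word it belongs to
# (index word_len == 4 points at the zero row appended above)
lm_idx = torch.LongTensor([[0, 0, 1, 1, 2, 3, 4],
                           [0, 1, 1, 2, 4, 4, 4]])          # (batch, chars)

# advanced indexing: rows come from arange(batch), columns from lm_idx
ctx = lm_o[torch.arange(batch_size).unsqueeze(-1), lm_idx]  # (batch, chars, dim)
print(ctx.shape)                                            # torch.Size([2, 7, 3])

In the model itself the gathered forward and backward contexts are collected in context_o and concatenated with the word and character embeddings before the RNN.
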
14 changes: 7 additions & 7 deletions model_partial_ner/utils.py
@@ -45,8 +45,8 @@ def evaluate_chunking(iterator, ner_model, none_idx):

ner_model.eval()

for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label in iterator:
output = ner_model(word_t, char_t, chunk_mask)
for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label, flm_t, blm_t, blm_ind, lm_idx in iterator:
output = ner_model(flm_t, blm_t, blm_ind, lm_idx, word_t, char_t, chunk_mask)
chunk_score = ner_model.chunking(output)
pred_chunk = (chunk_score < 0.0)

@@ -91,8 +91,8 @@ def evaluate_typing(iterator, ner_model, none_idx):

ner_model.eval()

for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label in iterator:
output = ner_model(word_t, char_t, chunk_mask)
for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label, flm_t, blm_t, blm_ind, lm_idx in iterator:
output = ner_model(flm_t, blm_t, blm_ind, lm_idx, word_t, char_t, chunk_mask)
pred_chunk = (chunk_label <= 0.0)

if pred_chunk.data.float().sum() <= 1:
@@ -137,8 +137,8 @@ def evaluate_ner(iterator, ner_model, none_idx, id2label):

type2gold, type2guess, type2overlap = {}, {}, {}

for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label in iterator:
output = ner_model(word_t, char_t, chunk_mask)
for word_t, char_t, chunk_mask, chunk_label, type_mask, type_label, flm_t, blm_t, blm_ind, lm_idx in iterator:
output = ner_model(flm_t, blm_t, blm_ind, lm_idx, word_t, char_t, chunk_mask)
chunk_score = ner_model.chunking(output)
pred_chunk = (chunk_score < 0.0)

@@ -214,4 +214,4 @@ def init_lstm(input_lstm):
weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1
weight = eval('input_lstm.bias_hh_l'+str(ind))
weight.data.zero_()
weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1
weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1
48 changes: 35 additions & 13 deletions preprocess_partial_ner/encode_folder.py
@@ -144,10 +144,12 @@ def read_corpus(lines):
return features, labels_chunk, labels_point, labels_typing


def encode_folder(input_folder, output_folder, w_map, c_map, cl_map, tl_map, c_threshold = -1):
def encode_folder(input_folder, output_folder, flm_map, blm_map, w_map, c_map, cl_map, tl_map, c_threshold = -1):

w_st, w_unk, w_con, w_pad = w_map['<s>'], w_map['<unk>'], w_map['< >'], w_map['<\n>']
c_st, c_unk, c_con, c_pad = c_map['<s>'], c_map['<unk>'], c_map['< >'], c_map['<\n>']
flm_unk = flm_map['<unk>']
blm_unk = blm_map['<unk>']

# list_dirs = os.walk(input_folder)

@@ -179,16 +181,23 @@ def encode_folder(input_folder, output_folder, w_map, c_map, cl_map, tl_map, c_t
tmp_w = [w_st, w_con]
tmp_c = [c_st, c_con]
tmp_mc = [0, 1]
tmp_flm = [flm_map.get(token, flm_map.get(token.lower(), flm_unk)) for token in f_l[1:-1]]
tmp_blm = [blm_map.get(token, blm_map.get(token.lower(), blm_unk)) for token in f_l[1:-1]]
tmp_lm_idx = [-1, -1]
idx = 0

for i_f, i_m in zip(f_l[1:-1], l_c_m[1:-1]):
tmp_w = tmp_w + [w_map.get(i_f, w_map.get(i_f.lower(), w_unk))] * len(i_f) + [w_con]
tmp_c = tmp_c + [c_map.get(t, c_unk) for t in i_f] + [c_con]
tmp_mc = tmp_mc + [0] * len(i_f) + [i_m]
tmp_lm_idx = tmp_lm_idx + [idx] * len(i_f) + [-1]
idx += 1

tmp_w.append(w_pad)
tmp_c.append(c_pad)
tmp_lm_idx.append(-1)
tmp_mc.append(0)

assert len(tmp_w) == len(tmp_lm_idx)

tmp_lc = [cl_map[tup] for tup in l_c[1:]]
tmp_mt = l_m[1:]
@@ -199,7 +208,7 @@ def encode_folder(input_folder, output_folder, w_map, c_map, cl_map, tl_map, c_t
tmp_mask[tl_map[tup]] = 1
tmp_lt.append(tmp_mask)

dataset.append([tmp_w, tmp_c, tmp_mc, tmp_lc, tmp_mt, tmp_lt])
dataset.append([tmp_w, tmp_c, tmp_mc, tmp_lc, tmp_mt, tmp_lt, tmp_flm, tmp_blm, tmp_lm_idx])

dataset.sort(key=lambda t: len(t[0]), reverse=True)

@@ -211,7 +220,7 @@ def encode_folder(input_folder, output_folder, w_map, c_map, cl_map, tl_map, c_t
return range_ind


def encode_dataset(input_file, w_map, c_map, cl_map, tl_map):
def encode_dataset(input_file, flm_map, blm_map, w_map, c_map, cl_map, tl_map):

print('loading from ' + input_file)

@@ -222,42 +231,50 @@ def encode_dataset(input_file, w_map, c_map, cl_map, tl_map):

w_st, w_unk, w_con, w_pad = w_map['<s>'], w_map['<unk>'], w_map['< >'], w_map['<\n>']
c_st, c_unk, c_con, c_pad = c_map['<s>'], c_map['<unk>'], c_map['< >'], c_map['<\n>']

flm_unk = flm_map['<unk>']
blm_unk = blm_map['<unk>']
dataset = list()

for f_l, l_c, l_m, l_t in zip(features, labels_chunk, labels_point, labels_typing):
tmp_w = [w_st, w_con]
tmp_c = [c_st, c_con]
tmp_mc = [0, 1]
tmp_lc = [cl_map[l_c[1]]]
tmp_flm = [flm_map.get(token, flm_map.get(token.lower(), flm_unk)) for token in f_l[1:-1]]
tmp_blm = [blm_map.get(token, blm_map.get(token.lower(), blm_unk)) for token in f_l[1:-1]]
tmp_lm_idx = [-1, -1]
idx = 0

for i_f, i_c in zip(f_l[1:-1], l_c[2:]):
tmp_w = tmp_w + [w_map.get(i_f, w_map.get(i_f.lower(), w_unk))] * len(i_f) + [w_con]
tmp_c = tmp_c + [c_map.get(t, c_unk) for t in i_f] + [c_con]
tmp_mc = tmp_mc + [0] * len(i_f) + [1]
tmp_lc = tmp_lc + [cl_map[i_c]]
tmp_lm_idx = tmp_lm_idx + [idx] * len(i_f) + [-1]
idx += 1

tmp_w.append(w_pad)
tmp_c.append(c_pad)
tmp_lm_idx.append(-1)
tmp_mc.append(0)

tmp_mt = l_m[1:]
tmp_lt = [tl_map[tup] for tup in l_t]

dataset.append([tmp_w, tmp_c, tmp_mc, tmp_lc, tmp_mt, tmp_lt])

dataset.append([tmp_w, tmp_c, tmp_mc, tmp_lc, tmp_mt, tmp_lt, tmp_flm, tmp_blm, tmp_lm_idx])
dataset.sort(key=lambda t: len(t[0]), reverse=True)

return dataset
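
In both encode_folder and encode_dataset, tmp_lm_idx records, for every character-level position, the index of the word-level LM token it belongs to; the -1 markers at the <s>, connector, and <\n> positions are later remapped by the dataset readers. A tiny sketch with a hypothetical two-token sentence, mirroring the loop above:

tokens = ["New", "York"]          # hypothetical sentence, already split into words

tmp_lm_idx = [-1, -1]             # the <s> and first connector positions carry no word index
idx = 0
for tok in tokens:
    tmp_lm_idx += [idx] * len(tok) + [-1]   # each character points to its word; the connector gets -1
    idx += 1
tmp_lm_idx.append(-1)             # trailing <\n> position

print(tmp_lm_idx)                 # [-1, -1, 0, 0, 0, -1, 1, 1, 1, 1, -1, -1]

The resulting list has the same length as tmp_w, which is exactly what the added assert checks.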


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_train', default="./annotations/debug.ck")
parser.add_argument('--input_train', default="./data/ner/eng.train.ck")
parser.add_argument('--input_testa', default="./data/ner/eng.testa.ck")
parser.add_argument('--input_testb', default="./data/ner/eng.testb.ck")
parser.add_argument('--pre_word_emb', default="./data/glove.100.pk")
parser.add_argument('--output_folder', default="./data/hqner/")
parser.add_argument('--lm_dataset', default="./data/ner/ner_dataset.pk")
parser.add_argument('--output_folder', default="./data/contextner/")
args = parser.parse_args()

with open(args.pre_word_emb, 'rb') as f:
@@ -267,6 +284,11 @@ def encode_dataset(input_file, w_map, c_map, cl_map, tl_map):

w_map, emb_array = filter_words(w_map, emb_array, [args.input_train, args.input_testa, args.input_testb])
assert len(w_map) == len(emb_array)

with open(args.lm_dataset, 'rb') as f:
dataset = pickle.load(f)
flm_map = dataset['flm_map']
blm_map = dataset['blm_map']

#four special char/word, <s>, <unk>, < > and <\n>
c_map = {'<s>': 0, '<unk>': 1, '< >': 2, '<\n>': 3}
@@ -277,12 +299,12 @@ def encode_dataset(input_file, w_map, c_map, cl_map, tl_map):
tl_map = build_label_mapping(args.input_train, args.input_testa, args.input_testb)
cl_map = {'I': 0, 'O': 1}

range_ind = encode_folder(args.input_train, args.output_folder, w_map, c_map, cl_map, tl_map, 5)
testa_dataset = encode_dataset(args.input_testa, w_map, c_map, cl_map, tl_map)
testb_dataset = encode_dataset(args.input_testb, w_map, c_map, cl_map, tl_map)
range_ind = encode_folder(args.input_train, args.output_folder, flm_map, blm_map, w_map, c_map, cl_map, tl_map, 5)
testa_dataset = encode_dataset(args.input_testa, flm_map, blm_map, w_map, c_map, cl_map, tl_map)
testb_dataset = encode_dataset(args.input_testb, flm_map, blm_map, w_map, c_map, cl_map, tl_map)

with open(args.output_folder+'test.pk', 'wb') as f:
pickle.dump({'emb_array': emb_array, 'w_map': w_map, 'c_map': c_map, 'tl_map': tl_map, 'cl_map': cl_map, 'range': range_ind, 'test_data':testb_dataset, 'dev_data': testa_dataset}, f)
pickle.dump({'emb_array': emb_array, 'flm_map': flm_map, 'blm_map': blm_map, 'w_map': w_map, 'c_map': c_map, 'tl_map': tl_map, 'cl_map': cl_map, 'range': range_ind, 'test_data':testb_dataset, 'dev_data': testa_dataset}, f)

print('dumped to the folder: ' + args.output_folder)
print('done!')
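
For reference, a hypothetical consumer of the dump written above might look like the sketch below. The dictionary keys come from the pickle.dump call; the choice of pad ids and the token_per_batch value are assumptions for illustration, not part of this PR:

import pickle

from model_partial_ner.dataset import NERDataset

with open('./data/contextner/test.pk', 'rb') as f:
    data = pickle.load(f)

# assumption: pad the word-level LM inputs with the <unk> id; the real training
# script may use a dedicated pad entry from flm_map / blm_map instead
dev_data = NERDataset(data['dev_data'],
                      data['flm_map']['<unk>'],
                      data['blm_map']['<unk>'],
                      data['w_map']['<\n>'],
                      data['c_map']['<\n>'],
                      3000)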