-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_prepare.py
113 lines (97 loc) · 3.21 KB
/
data_prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
"""
@author: bai
"""
import pickle
import os
import random
# BIO tag -> integer label mapping for NER training targets:
# "O" = outside any entity; "B-"/"I-" prefixes mark the beginning/inside
# of PERson, LOCation and ORGanization entity spans.
tag2label = {"O": 0,"B-PER": 1, "I-PER": 2,"B-LOC": 3, "I-LOC": 4,"B-ORG": 5, "I-ORG": 6}
def load_vocab(vocab_path):
    """Load a pickled word->id vocabulary.

    Args:
        vocab_path: path to a pickle file produced by get_vocab().

    Returns:
        dict mapping word (str) to integer id.
    """
    # Removed the original `os.path.join(vocab_path)` — with a single
    # argument it is a no-op and only obscured the code.
    with open(vocab_path, 'rb') as fr:
        word2id = pickle.load(fr)
    print('vocab_size:', len(word2id))
    return word2id
def load_data(data_path):
    """Read a CoNLL-style tagged corpus.

    The file holds one "char label" pair per line; sentences are
    separated by blank lines.

    Args:
        data_path: path to the UTF-8 corpus file.

    Returns:
        list of (chars, tags) tuples, one per sentence, where both
        elements are lists of strings of equal length.
    """
    sentences = []
    sent, tag = [], []
    with open(data_path, encoding='utf-8') as fr:
        for line in fr:
            if line != '\n':
                char, label = line.strip().split()
                sent.append(char)
                tag.append(label)
            else:
                sentences.append((sent, tag))
                sent, tag = [], []
    # Bug fix: the original only flushed a sentence on a blank line, so a
    # file that does not end with one silently lost its last sentence.
    if sent:
        sentences.append((sent, tag))
    return sentences
def get_vocab(vocab_path, data_path, min_count):
    """Build a word->id vocabulary from the corpus and pickle it.

    Digits are collapsed into the '<NUM>' token and ASCII letters into
    '<ENG>'. Words seen fewer than *min_count* times are discarded
    (except those two special tokens). Surviving words are numbered
    1..N in first-seen order; '<UNK>' then receives N+1 and '<PAD>' 0.

    Args:
        vocab_path: destination path for the pickled mapping.
        data_path: corpus file readable by load_data().
        min_count: minimum frequency for a word to be kept.
    """
    # First pass: frequency count with digit/letter normalisation.
    counts = {}
    for chars, _tags in load_data(data_path):
        for ch in chars:
            if ch.isdigit():
                ch = '<NUM>'
            elif ('\u0041' <= ch <= '\u005a') or ('\u0061' <= ch <= '\u007a'):
                ch = '<ENG>'
            counts[ch] = counts.get(ch, 0) + 1
    # Second pass: drop rare words and assign compact ids in seen order.
    word2id = {}
    next_id = 1
    for word, freq in counts.items():
        if freq >= min_count or word in ('<NUM>', '<ENG>'):
            word2id[word] = next_id
            next_id += 1
    word2id['<UNK>'] = next_id
    word2id['<PAD>'] = 0
    print("vocab size: ", len(word2id))
    with open(vocab_path, 'wb') as fw:
        pickle.dump(word2id, fw)
def sentence2id(sent, word2id):
    """Convert a sequence of characters into vocabulary ids.

    Digits map to '<NUM>', ASCII letters to '<ENG>', and any character
    absent from *word2id* falls back to '<UNK>'.

    Args:
        sent: iterable of single-character strings.
        word2id: word->id mapping produced by get_vocab().

    Returns:
        list of integer ids, same length as *sent*.
    """
    ids = []
    for ch in sent:
        if ch.isdigit():
            key = '<NUM>'
        elif ('\u0041' <= ch <= '\u005a') or ('\u0061' <= ch <= '\u007a'):
            key = '<ENG>'
        else:
            key = ch
        if key not in word2id:
            key = '<UNK>'
        ids.append(word2id[key])
    return ids
def pad_sequences(sequences, pad_mark=0):
    """Right-pad every sequence to the length of the longest one.

    Args:
        sequences: non-empty iterable of sequences (lists/tuples of ids).
        pad_mark: value used to fill short sequences (default 0).

    Returns:
        (padded, lengths) where *padded* is a list of equal-length lists
        and *lengths* records each sequence's original length.
    """
    longest = max(len(s) for s in sequences)
    padded, lengths = [], []
    for s in sequences:
        s = list(s)
        padded.append(s + [pad_mark] * (longest - len(s)))
        lengths.append(len(s))
    return padded, lengths
def next_batch(data, batch_size, vocab, tag2label, shuffle=False):
    """Yield (seqs, labels) training batches.

    Args:
        data: list of (chars, tags) sentence pairs from load_data().
        batch_size: number of sentences per yielded batch.
        vocab: word->id mapping used by sentence2id().
        tag2label: tag string -> integer label mapping.
        shuffle: if True, shuffles *data* in place first.

    Yields:
        (seqs, labels) — parallel lists of id-sequences and label-sequences.
        When more than one sentence is available, the final partial batch
        is topped up by recycling sentences from the start of *data*.
    """
    if shuffle:
        random.shuffle(data)
    seqs, labels = [], []
    for (sent_, tag_) in data:
        sent_ = sentence2id(sent_, vocab)
        label_ = [tag2label[tag] for tag in tag_]
        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []
        seqs.append(sent_)
        labels.append(label_)
    if len(seqs) != 0:
        if len(data) > 1:
            # Bug fix: the original indexed data[i] directly, which raises
            # IndexError whenever batch_size - len(seqs) > len(data).
            # Wrap around with modulo so small datasets still fill the batch.
            for i in range(batch_size - len(seqs)):
                (sent_, tag_) = data[i % len(data)]
                sent_ = sentence2id(sent_, vocab)
                label_ = [tag2label[tag] for tag in tag_]
                seqs.append(sent_)
                labels.append(label_)
        yield seqs, labels