Commit 756aecb

improve benchmarks on NLP

1 parent 09a35a8 commit 756aecb

File tree

24 files changed: +382 -396 lines changed

flgo/benchmark/agnews_classification/config.py

Lines changed: 20 additions & 1 deletion
@@ -10,13 +10,26 @@
 test_data = torchtext.datasets.AG_NEWS(root=path, split='test')
 ngrams = 2
 tokenizer = get_tokenizer('basic_english')
+
 def yield_tokens(data_iter, ngrams):
     for _, text in data_iter:
         yield ngrams_iterator(tokenizer(text), ngrams)
 
 vocab = build_vocab_from_iterator(yield_tokens(train_data, ngrams), specials=["<unk>"])
 vocab.set_default_index(vocab["<unk>"])
 
+def text_pipeline(x):
+    return vocab(list(ngrams_iterator(tokenizer(x), ngrams)))
+
+def label_pipeline(x):
+    return int(x) - 1
+
+def apply_transform(x):
+    return text_pipeline(x[1]), label_pipeline(x[0])
+
+train_data = train_data.map(apply_transform)
+test_data = test_data.map(apply_transform)
+
 class TextClassificationModel(torch.nn.Module):
     def __init__(self, vocab_size, embed_dim, num_class):
         super(TextClassificationModel, self).__init__()
@@ -30,7 +43,13 @@ def init_weights(self):
         self.fc.weight.data.uniform_(-initrange, initrange)
         self.fc.bias.data.zero_()
 
-    def forward(self, text, offsets):
+    def forward(self, text):
+        offsets = [0]
+        for t in text:
+            offsets.append(t.size(0))
+        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
+        text = torch.cat(text)
+        offsets = offsets.to(text.device)
         embedded = self.embedding(text, offsets)
         return self.fc(embedded)
 
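
Note on the forward change: the model now receives a list of variable-length token tensors and packs them itself into the flat (text, offsets) pair that torch.nn.EmbeddingBag consumes. A minimal standalone sketch of that packing, with made-up token ids:

    import torch

    # hypothetical batch: three sequences of different lengths
    batch = [torch.tensor([4, 8, 15]), torch.tensor([16, 23]), torch.tensor([42])]

    offsets = [0]
    for t in batch:
        offsets.append(t.size(0))
    # drop the last length; the running sum gives each sequence's start index
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # tensor([0, 3, 5])
    flat = torch.cat(batch)                             # tensor([ 4,  8, 15, 16, 23, 42])

    bag = torch.nn.EmbeddingBag(num_embeddings=50, embedding_dim=8)
    pooled = bag(flat, offsets)  # one mean-pooled embedding per sequence: shape (3, 8)

Computing offsets inside forward keeps the collate function model-agnostic, at the cost of a small per-batch loop.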

flgo/benchmark/agnews_classification/core.py

Lines changed: 7 additions & 34 deletions
@@ -3,31 +3,13 @@
 import torch.utils.data
 from flgo.benchmark.toolkits.nlp.classification import GeneralCalculator
 from flgo.benchmark.base import FromDatasetPipe, FromDatasetGenerator
-from torchtext.vocab import build_vocab_from_iterator
-from torchtext.data.utils import get_tokenizer, ngrams_iterator
+from torchtext.data.utils import ngrams_iterator
 from torchtext.data.functional import to_map_style_dataset
 try:
     import ujson as json
 except:
     import json
 from .config import train_data
-try:
-    from .config import tokenizer
-except:
-    tokenizer = None
-try:
-    from .config import ngrams
-except:
-    ngrams = 1
-
-def yield_tokens(data_iter, ngrams):
-    for _, text in data_iter:
-        yield ngrams_iterator(tokenizer(text), ngrams)
-
-try:
-    from .config import vocab
-except:
-    vocab = None
 try:
     from .config import test_data
 except:
@@ -37,30 +19,21 @@ def yield_tokens(data_iter, ngrams):
 except:
     val_data = None
 
-if tokenizer is None: tokenizer = get_tokenizer('basic_english')
-if vocab is None:
-    vocab = build_vocab_from_iterator(yield_tokens(train_data, ngrams), specials=["<unk>"])
-    vocab.set_default_index(vocab["<unk>"])
-
 def collate_batch(batch):
-    label_list, text_list, offsets = [], [], [0]
-    for (_label, _text) in batch:
-        label_list.append(int(_label)-1)
-        processed_text = torch.tensor(vocab(list(ngrams_iterator(tokenizer(_text), ngrams))), dtype=torch.int64)
+    label_list, text_list = [], []
+    for (_text, _label) in batch:
+        label_list.append(_label)
+        processed_text = torch.tensor(_text, dtype=torch.int64)
         text_list.append(processed_text)
-        offsets.append(processed_text.size(0))
     label_list = torch.tensor(label_list, dtype=torch.int64)
-    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
-    text_list = torch.cat(text_list)
-    return label_list, text_list, offsets
+    return text_list, label_list
 
 class TaskGenerator(FromDatasetGenerator):
     def __init__(self):
         super(TaskGenerator, self).__init__(benchmark=os.path.split(os.path.dirname(__file__))[-1],
                                             train_data=train_data, val_data=val_data, test_data=test_data)
 
     def prepare_data_for_partition(self):
-        self.train_data = self.train_data.map(lambda x: (x[1], x[0]))
         return to_map_style_dataset(self.train_data)
 
 class TaskPipe(FromDatasetPipe):
@@ -70,7 +43,7 @@ def __init__(self, task_path):
 
     def save_task(self, generator):
         client_names = self.gen_client_names(len(generator.local_datas))
-        feddata = {'client_names': client_names,}
+        feddata = {'client_names': client_names}
         for cid in range(len(client_names)): feddata[client_names[cid]] = {'data': generator.local_datas[cid],}
         with open(os.path.join(self.task_path, 'data.json'), 'w') as outf:
             json.dump(feddata, outf)
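
Since numericalization moved into config.py's apply_transform, collate_batch only has to tensorize samples that are already (token_ids, label) pairs. A rough usage sketch with toy samples (the DataLoader wiring is assumed, not part of this file):

    import torch
    from torch.utils.data import DataLoader

    # hypothetical pre-numericalized samples: (token ids, label)
    samples = [([4, 8, 15], 0), ([16, 23], 3), ([42], 1)]

    def collate_batch(batch):
        label_list, text_list = [], []
        for (_text, _label) in batch:
            label_list.append(_label)
            text_list.append(torch.tensor(_text, dtype=torch.int64))
        return text_list, torch.tensor(label_list, dtype=torch.int64)

    loader = DataLoader(samples, batch_size=2, collate_fn=collate_batch)
    for text, labels in loader:
        print([t.shape for t in text], labels)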

flgo/benchmark/agnews_classification/model/__init__.py

Whitespace-only changes.

flgo/benchmark/agnews_classification/model/bag_linear.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

flgo/benchmark/imdb_classification/config.py

Lines changed: 19 additions & 1 deletion
@@ -17,6 +17,18 @@ def yield_tokens(data_iter, ngrams):
 vocab = build_vocab_from_iterator(yield_tokens(train_data, ngrams), specials=["<unk>"])
 vocab.set_default_index(vocab["<unk>"])
 
+def text_pipeline(x):
+    return vocab(list(ngrams_iterator(tokenizer(x), ngrams)))
+
+def label_pipeline(x):
+    return int(x) - 1
+
+def apply_transform(x):
+    return text_pipeline(x[1]), label_pipeline(x[0])
+
+train_data = train_data.map(apply_transform)
+test_data = test_data.map(apply_transform)
+
 class TextClassificationModel(torch.nn.Module):
     def __init__(self, vocab_size, embed_dim, num_class):
         super(TextClassificationModel, self).__init__()
@@ -30,7 +42,13 @@ def init_weights(self):
         self.fc.weight.data.uniform_(-initrange, initrange)
         self.fc.bias.data.zero_()
 
-    def forward(self, text, offsets):
+    def forward(self, text):
+        offsets = [0]
+        for t in text:
+            offsets.append(t.size(0))
+        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
+        text = torch.cat(text)
+        offsets = offsets.to(text.device)
         embedded = self.embedding(text, offsets)
         return self.fc(embedded)
 
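
As in the AG_NEWS config, text_pipeline indexes over unigrams plus bigrams (ngrams = 2), so the vocabulary also contains two-token phrases; unseen entries fall back to <unk> via set_default_index. A quick illustration of torchtext's ngrams_iterator:

    from torchtext.data.utils import get_tokenizer, ngrams_iterator

    tokenizer = get_tokenizer('basic_english')
    print(list(ngrams_iterator(tokenizer("This movie was great"), 2)))
    # ['this', 'movie', 'was', 'great', 'this movie', 'movie was', 'was great']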

flgo/benchmark/imdb_classification/core.py

Lines changed: 7 additions & 34 deletions
@@ -3,31 +3,13 @@
 import torch.utils.data
 from flgo.benchmark.toolkits.nlp.classification import GeneralCalculator
 from flgo.benchmark.base import FromDatasetPipe, FromDatasetGenerator
-from torchtext.vocab import build_vocab_from_iterator
-from torchtext.data.utils import get_tokenizer, ngrams_iterator
 from torchtext.data.functional import to_map_style_dataset
+
 try:
     import ujson as json
 except:
     import json
 from .config import train_data
-try:
-    from .config import tokenizer
-except:
-    tokenizer = None
-try:
-    from .config import ngrams
-except:
-    ngrams = 1
-
-def yield_tokens(data_iter, ngrams):
-    for _, text in data_iter:
-        yield ngrams_iterator(tokenizer(text), ngrams)
-
-try:
-    from .config import vocab
-except:
-    vocab = None
 try:
     from .config import test_data
 except:
@@ -37,30 +19,21 @@ def yield_tokens(data_iter, ngrams):
 except:
     val_data = None
 
-if tokenizer is None: tokenizer = get_tokenizer('basic_english')
-if vocab is None:
-    vocab = build_vocab_from_iterator(yield_tokens(train_data, ngrams), specials=["<unk>"])
-    vocab.set_default_index(vocab["<unk>"])
-
 def collate_batch(batch):
-    label_list, text_list, offsets = [], [], [0]
-    for (_label, _text) in batch:
-        label_list.append(int(_label)-1)
-        processed_text = torch.tensor(vocab(list(ngrams_iterator(tokenizer(_text), ngrams))), dtype=torch.int64)
+    label_list, text_list = [], []
+    for (_text, _label) in batch:
+        label_list.append(_label)
+        processed_text = torch.tensor(_text, dtype=torch.int64)
         text_list.append(processed_text)
-        offsets.append(processed_text.size(0))
     label_list = torch.tensor(label_list, dtype=torch.int64)
-    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
-    text_list = torch.cat(text_list)
-    return label_list, text_list, offsets
+    return text_list, label_list
 
 class TaskGenerator(FromDatasetGenerator):
     def __init__(self):
         super(TaskGenerator, self).__init__(benchmark=os.path.split(os.path.dirname(__file__))[-1],
                                             train_data=train_data, val_data=val_data, test_data=test_data)
 
     def prepare_data_for_partition(self):
-        self.train_data = self.train_data.map(lambda x: (x[1], x[0]))
         return to_map_style_dataset(self.train_data)
 
 class TaskPipe(FromDatasetPipe):
@@ -70,7 +43,7 @@ def __init__(self, task_path):
 
     def save_task(self, generator):
         client_names = self.gen_client_names(len(generator.local_datas))
-        feddata = {'client_names': client_names,}
+        feddata = {'client_names': client_names}
         for cid in range(len(client_names)): feddata[client_names[cid]] = {'data': generator.local_datas[cid],}
         with open(os.path.join(self.task_path, 'data.json'), 'w') as outf:
             json.dump(feddata, outf)
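
prepare_data_for_partition no longer swaps (label, text) to (text, label) — apply_transform in config.py already emits samples in (text, label) order — so all that remains is converting the iterable-style pipe into an indexable map-style dataset for partitioning. A small sketch of that conversion, with a toy iterable standing in for the IMDB pipe:

    from torchtext.data.functional import to_map_style_dataset

    # toy iterable standing in for the torchtext datapipe
    def toy_samples():
        for i in range(5):
            yield ([i, i + 1], i % 2)  # (token ids, label)

    map_ds = to_map_style_dataset(toy_samples())
    print(len(map_ds), map_ds[0])  # indexable, so samples can be split across clients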

flgo/benchmark/multi30k_translation/config.py

Lines changed: 37 additions & 8 deletions
@@ -8,19 +8,21 @@
 import torch.nn.functional as F
 import random
 
+# 0. Load the dataset
 language_pair = ['de', 'en']
 path = os.path.join(flgo.benchmark.path, 'RAW_DATA', 'MULTI30K')
 train_data, val_data, test_data = Multi30k(split=('train', 'valid', 'test'), language_pair=language_pair)
 
+# 1. Load the tokenizers and vocabularies
 # init tokenizers
 tokenizers = {}
 tokenizers[language_pair[0]] = get_tokenizer('spacy', language='de_core_news_sm')
 tokenizers[language_pair[1]] = get_tokenizer('spacy', language='en_core_web_sm')
 
 # init vocabs
 vocabs = {}
-UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
-special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
+PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
+special_symbols = ['<pad>', '<unk>', '<bos>', '<eos>']
 for i,ln in enumerate(language_pair):
     # Create torchtext's Vocab object
     tokenizer = tokenizers[ln]
@@ -31,13 +33,33 @@
                                            special_first=True)
 for ln in language_pair: vocabs[ln].set_default_index(UNK_IDX)
 
-def init_weights(m):
-    for name, param in m.named_parameters():
-        if 'weight' in name:
-            nn.init.normal_(param.data, mean=0, std=0.01)
-        else:
-            nn.init.constant_(param.data, 0)
+# 2. Convert the dataset's strings (source and target language) into numeric vectors using the tokenizers and vocabs
+def sequential_transforms(*transforms):
+    def func(txt_input):
+        for transform in transforms:
+            txt_input = transform(txt_input)
+        return txt_input
+    return func
+
+def tensor_transform(token_ids):
+    return torch.cat((torch.tensor([BOS_IDX]),
+                      torch.tensor(token_ids),
+                      torch.tensor([EOS_IDX])))
+
+text_transform = {}
+for ln in language_pair:
+    text_transform[ln] = sequential_transforms(tokenizers[ln],  # Tokenization
+                                               vocabs[ln],  # Numericalization
+                                               tensor_transform)  # Add BOS/EOS and create tensor
+
+def apply_transform(x):
+    return text_transform[language_pair[0]](x[0].rstrip("\n")), text_transform[language_pair[1]](x[1].rstrip("\n"))
 
+train_data = train_data.map(apply_transform)
+val_data = val_data.map(apply_transform)
+test_data = test_data.map(apply_transform)
+
+# 3. Define the model
 def get_model():
     INPUT_DIM = len(vocabs[language_pair[0]])
     OUTPUT_DIM = len(vocabs[language_pair[1]])
@@ -54,6 +76,13 @@ def get_model():
     model.apply(init_weights)
     return model
 
+def init_weights(m):
+    for name, param in m.named_parameters():
+        if 'weight' in name:
+            nn.init.normal_(param.data, mean=0, std=0.01)
+        else:
+            nn.init.constant_(param.data, 0)
+
 class Encoder(nn.Module):
     def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
         super().__init__()
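
The new text_transform chains tokenization, numericalization, and BOS/EOS wrapping into one callable per language, and apply_transform maps it over the raw sentence pairs; the reordered specials put <pad> at index 0, presumably so padding aligns with pad_sequence's default padding value of 0. A rough sketch of the same composition, with a toy tokenizer and vocab standing in for the spaCy/torchtext objects:

    import torch

    PAD_IDX, UNK_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

    # toy stand-ins for tokenizers[ln] and vocabs[ln]
    toy_tokenizer = lambda s: s.lower().split()
    toy_vocab = lambda tokens: [{'ein': 4, 'haus': 5}.get(t, UNK_IDX) for t in tokens]

    def sequential_transforms(*transforms):
        def func(txt_input):
            for transform in transforms:
                txt_input = transform(txt_input)
            return txt_input
        return func

    def tensor_transform(token_ids):
        return torch.cat((torch.tensor([BOS_IDX]),
                          torch.tensor(token_ids),
                          torch.tensor([EOS_IDX])))

    transform = sequential_transforms(toy_tokenizer, toy_vocab, tensor_transform)
    print(transform("Ein Haus\n".rstrip("\n")))  # tensor([2, 4, 5, 3])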
