# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
'''
Generic sentence evaluation scripts wrapper
'''
from __future__ import absolute_import, division, unicode_literals
from senteval import utils
from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
from senteval.snli import SNLIEval
from senteval.trec import TRECEval
from senteval.sick import SICKEntailmentEval, SICKEval
from senteval.mrpc import MRPCEval
from senteval.sts import (STS12Eval, STS13Eval, STS14Eval, STS15Eval,
                          STS16Eval, STSBenchmarkEval, SICKRelatednessEval,
                          STSBenchmarkFinetune)
from senteval.sst import SSTEval
from senteval.rank import ImageCaptionRetrievalEval
from senteval.probing import *
from senteval.hans import HANSEval  # newly added


class SE(object):
    def __init__(self, params, batcher, prepare=None):
        # parameters
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed
        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher and prepare
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'HANS',  # newly added
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion',
                           'SICKRelatedness-finetune', 'STSBenchmark-finetune',
                           'STSBenchmark-fix']

    def eval(self, name):
        # Evaluate on the task(s) given by `name`; accepts either a single
        # task name (string) or a list of task names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR', seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR', seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA', seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ', seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary', nclasses=2, seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine', nclasses=5, seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC', seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'STSBenchmark-fix':
            self.evaluation = STSBenchmarkEval(tpath + '/downstream/STS/STSBenchmark-fix', seed=self.params.seed)
        elif name == 'STSBenchmark-finetune':
            self.evaluation = STSBenchmarkFinetune(tpath + '/downstream/STS/STSBenchmark', seed=self.params.seed)
        elif name == 'SICKRelatedness-finetune':
            self.evaluation = SICKEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK', seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI', seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            # Resolve the task class (e.g. STS12Eval) from the task name.
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' + fpath, seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath + '/downstream/COCO', seed=self.params.seed)
        elif name == 'HANS':  # newly added
            self.evaluation = HANSEval(tpath + '/downstream/HANS', seed=self.params.seed)
        # Probing tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing', seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing', seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
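

# A minimal usage sketch (illustrative, not part of the original file). It
# relies only on SentEval's documented interface as used above:
# `batcher(params, batch)` receives a list of tokenized sentences and returns
# one embedding per sentence as a 2D numpy array. The `random_batcher` and
# the 'data' task path below are placeholder assumptions, not components of
# this repository; point `task_path` at a directory containing the downloaded
# `downstream/` and `probing/` task data.
if __name__ == '__main__':
    import numpy as np

    def random_batcher(params, batch):
        # A real encoder would embed each sentence here; fixed-dimensional
        # random vectors stand in as a trivial baseline.
        return np.random.rand(len(batch), 300)

    params = {'task_path': 'data', 'usepytorch': False, 'kfold': 5,
              'classifier': {'nhid': 0}}
    se = SE(params, random_batcher)
    # 'HANS' is the newly added task; any name in se.list_tasks works, and
    # passing a list returns a dict keyed by task name.
    print(se.eval(['MR', 'HANS']))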