@@ -1,95 +1,16 @@
-""" Official evaluation script for v1.1 of the SQuAD dataset. """
-from __future__ import print_function
-from collections import Counter
-import string
-import re
 import argparse
 import json
-import sys
-
-
-def normalize_answer(s):
-    """Lower text and remove punctuation, articles and extra whitespace."""
-    def remove_articles(text):
-        return re.sub(r'\b(a|an|the)\b', ' ', text)
-
-    def white_space_fix(text):
-        return ' '.join(text.split())
-
-    def remove_punc(text):
-        exclude = set(string.punctuation)
-        return ''.join(ch for ch in text if ch not in exclude)
-
-    def lower(text):
-        return text.lower()
-
-    return white_space_fix(remove_articles(remove_punc(lower(s))))
-
-
-def f1_score(prediction, ground_truth):
-    prediction_tokens = normalize_answer(prediction).split()
-    ground_truth_tokens = normalize_answer(ground_truth).split()
-    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
-    num_same = sum(common.values())
-    if num_same == 0:
-        return 0
-    precision = 1.0 * num_same / len(prediction_tokens)
-    recall = 1.0 * num_same / len(ground_truth_tokens)
-    f1 = (2 * precision * recall) / (precision + recall)
-    return f1
-
-
-def exact_match_score(prediction, ground_truth):
-    return (normalize_answer(prediction) == normalize_answer(ground_truth))
-
-
-def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
-    scores_for_ground_truths = []
-    for ground_truth in ground_truths:
-        score = metric_fn(prediction, ground_truth)
-        scores_for_ground_truths.append(score)
-    return max(scores_for_ground_truths)
-
-
-def evaluate(dataset, predictions):
-    f1 = exact_match = total = 0
-    for article in dataset:
-        for paragraph in article['paragraphs']:
-            for qa in paragraph['qas']:
-                total += 1
-                if qa['id'] not in predictions:
-                    message = 'Unanswered question ' + qa['id'] + \
-                              ' will receive score 0.'
-                    print(message, file=sys.stderr)
-                    continue
-                ground_truths = list(map(lambda x: x['text'], qa['answers']))
-                prediction = predictions[qa['id']]
-                exact_match += metric_max_over_ground_truths(
-                    exact_match_score, prediction, ground_truths)
-                f1 += metric_max_over_ground_truths(
-                    f1_score, prediction, ground_truths)
-
-    exact_match = 100.0 * exact_match / total
-    f1 = 100.0 * f1 / total
-
-    return {'exact_match': exact_match, 'f1': f1}
-
+import evaluate as eval
 
 if __name__ == '__main__':
     expected_version = '1.1'
-    parser = argparse.ArgumentParser(
-        description='Evaluation for SQuAD ' + expected_version)
+    parser = argparse.ArgumentParser(description='Evaluation for SQuAD ' +
+                                     expected_version)
     parser.add_argument('dataset_file', help='Dataset file')
     parser.add_argument('prediction_file', help='Prediction File')
     args = parser.parse_args()
-    with open(args.dataset_file) as dataset_file:
-        dataset_json = json.load(dataset_file)
-        if (dataset_json['version'] != expected_version):
-            print('Evaluation expects v-' + expected_version +
-                  ', but got dataset with v-' + dataset_json['version'],
-                  file=sys.stderr)
-        dataset = dataset_json['data']
-    with open(args.prediction_file) as prediction_file:
-        predictions = json.load(prediction_file)
-    print(json.dumps(evaluate(dataset, predictions)))
 
+    print(
+        json.dumps(
+            eval.evaluate(expected_version, args.dataset_file,
+                args.prediction_file)))
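The commit replaces the inlined scoring logic with a call to an external evaluate module, but that module itself is not part of this diff. As a reference point, here is a minimal sketch of what evaluate.py could look like, assuming it simply repackages the deleted code (answer normalization, the EM and F1 metrics, file loading, and the version check) behind the three-argument evaluate(expected_version, dataset_file, prediction_file) entry point the new __main__ calls. Everything except that signature is carried over from the deleted lines above; the condensed normalize_answer is an equivalent rewrite, not the actual companion file.

# evaluate.py -- hypothetical sketch; the real companion file is not shown
# in this commit. The logic below is lifted from the deleted script.
from __future__ import print_function
from collections import Counter
import json
import re
import string
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    s = ''.join(ch for ch in s.lower() if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Each question has several reference answers; score against the best one.
    return max(metric_fn(prediction, gt) for gt in ground_truths)


def evaluate(expected_version, dataset_file, prediction_file):
    # Signature matches the eval.evaluate(...) call in the new __main__.
    with open(dataset_file) as f:
        dataset_json = json.load(f)
    if dataset_json['version'] != expected_version:
        print('Evaluation expects v-' + expected_version +
              ', but got dataset with v-' + dataset_json['version'],
              file=sys.stderr)
    with open(prediction_file) as f:
        predictions = json.load(f)

    f1 = exact_match = total = 0
    for article in dataset_json['data']:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    print('Unanswered question ' + qa['id'] +
                          ' will receive score 0.', file=sys.stderr)
                    continue
                ground_truths = [answer['text'] for answer in qa['answers']]
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    return {'exact_match': 100.0 * exact_match / total,
            'f1': 100.0 * f1 / total}

However the real module is written, the intent of the refactor is visible from the call site: argument parsing stays in the thin command-line script, while the scoring logic becomes importable, so other code can call evaluate.evaluate(...) or the individual metric helpers directly instead of shelling out to the script.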
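Invocation is unchanged by the refactor; the script name below is illustrative, since the commit does not show the file name:

    $ python evaluate-v1.1.py <dataset_file> <prediction_file>
    {"exact_match": ..., "f1": ...}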