
Commit 5426d18

jeffra, eltonzheng, samyam, RezaYazdaniAminabadi, and arashashari authored
DeepSpeed 0.2 support (deepspeedai#21)
* ZeRO-2 support
* DeepSpeed transformer kernel support

Co-authored-by: Elton Zheng <[email protected]>
Co-authored-by: Samyam Rajbhandari <[email protected]>
Co-authored-by: RezaYazdaniAminabadi <[email protected]>
Co-authored-by: Arash Ashari <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Tunji Ruwase <[email protected]>
1 parent 9e2c735 commit 5426d18
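ZeRO-2 partitions optimizer state and gradients across the data-parallel ranks, which is what lets these examples grow batch and model sizes without exhausting GPU memory. As a sketch only (not taken from this commit), this is roughly how a ZeRO-2 config is wired into a training script through the DeepSpeed Python API; the zero_optimization keys follow the config schema documented for later DeepSpeed releases, and the model and hyperparameters are placeholders:

# Sketch only: enabling ZeRO stage 2 through a DeepSpeed config dict.
# The zero_optimization schema matches later DeepSpeed releases; whether the
# 0.2-era release accepted a dict via config= (rather than a JSON file passed
# with --deepspeed_config) is an assumption. Model and LR are placeholders.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 24,
    "train_micro_batch_size_per_gpu": 3,
    "optimizer": {"type": "Adam", "params": {"lr": 3e-5}},
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,  # stage 2 = partition optimizer states and gradients
    },
}

model = torch.nn.Linear(1024, 1024)  # placeholder for the real model

# deepspeed.initialize wraps the model and optimizer for ZeRO training.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)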


55 files changed (+5798, -3237 lines)

.pre-commit-config.yaml (+18)
@@ -0,0 +1,18 @@
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v1.2.3
+  hooks:
+  - id: trailing-whitespace
+    exclude: "Megatron-LM/"
+  - id: check-yaml
+    exclude: "Megatron-LM/"
+  - id: end-of-file-fixer
+    exclude: "Megatron-LM/"
+
+
+- repo: https://github.com/pre-commit/mirrors-yapf
+  rev: v0.29.0
+  hooks:
+  - id: yapf
+    exclude: "Megatron-LM/"

BingBertSquad/deepspeed_bsz24_config.json (+2, -2)

@@ -1,6 +1,6 @@
 {
   "train_batch_size": 24,
-  "train_micro_batch_size_per_gpu": 6,
+  "train_micro_batch_size_per_gpu": 3,
   "steps_per_print": 10,
   "optimizer": {
     "type": "Adam",
@@ -13,6 +13,6 @@
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true
-  }
+  }
 
 }

(The second hunk's change to the closing brace of the fp16 block is whitespace-only.)
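The halved micro-batch still has to multiply back up to the global batch of 24: DeepSpeed enforces train_batch_size = train_micro_batch_size_per_gpu × gradient accumulation steps × data-parallel world size. A quick sanity check, assuming an 8-GPU run (the commit does not state the GPU count):

# DeepSpeed's batch-size invariant, checked for this config.
# The 8-GPU world size is an assumption, not stated in the commit.
train_batch_size = 24
micro_batch_per_gpu = 3  # new value in this commit (was 6)
world_size = 8           # assumed data-parallel GPU count
grad_accum_steps = train_batch_size // (micro_batch_per_gpu * world_size)

assert micro_batch_per_gpu * grad_accum_steps * world_size == train_batch_size
print(grad_accum_steps)  # 1: each of the 8 GPUs sees 3 samples per step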

BingBertSquad/evaluate-v1.1.py (+7, -86)

@@ -1,95 +1,16 @@
-""" Official evaluation script for v1.1 of the SQuAD dataset. """
-from __future__ import print_function
-from collections import Counter
-import string
-import re
 import argparse
 import json
-import sys
-
-
-def normalize_answer(s):
-    """Lower text and remove punctuation, articles and extra whitespace."""
-    def remove_articles(text):
-        return re.sub(r'\b(a|an|the)\b', ' ', text)
-
-    def white_space_fix(text):
-        return ' '.join(text.split())
-
-    def remove_punc(text):
-        exclude = set(string.punctuation)
-        return ''.join(ch for ch in text if ch not in exclude)
-
-    def lower(text):
-        return text.lower()
-
-    return white_space_fix(remove_articles(remove_punc(lower(s))))
-
-
-def f1_score(prediction, ground_truth):
-    prediction_tokens = normalize_answer(prediction).split()
-    ground_truth_tokens = normalize_answer(ground_truth).split()
-    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
-    num_same = sum(common.values())
-    if num_same == 0:
-        return 0
-    precision = 1.0 * num_same / len(prediction_tokens)
-    recall = 1.0 * num_same / len(ground_truth_tokens)
-    f1 = (2 * precision * recall) / (precision + recall)
-    return f1
-
-
-def exact_match_score(prediction, ground_truth):
-    return (normalize_answer(prediction) == normalize_answer(ground_truth))
-
-
-def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
-    scores_for_ground_truths = []
-    for ground_truth in ground_truths:
-        score = metric_fn(prediction, ground_truth)
-        scores_for_ground_truths.append(score)
-    return max(scores_for_ground_truths)
-
-
-def evaluate(dataset, predictions):
-    f1 = exact_match = total = 0
-    for article in dataset:
-        for paragraph in article['paragraphs']:
-            for qa in paragraph['qas']:
-                total += 1
-                if qa['id'] not in predictions:
-                    message = 'Unanswered question ' + qa['id'] + \
-                        ' will receive score 0.'
-                    print(message, file=sys.stderr)
-                    continue
-                ground_truths = list(map(lambda x: x['text'], qa['answers']))
-                prediction = predictions[qa['id']]
-                exact_match += metric_max_over_ground_truths(
-                    exact_match_score, prediction, ground_truths)
-                f1 += metric_max_over_ground_truths(
-                    f1_score, prediction, ground_truths)
-
-    exact_match = 100.0 * exact_match / total
-    f1 = 100.0 * f1 / total
-
-    return {'exact_match': exact_match, 'f1': f1}
+import evaluate as eval
 
 if __name__ == '__main__':
     expected_version = '1.1'
-    parser = argparse.ArgumentParser(
-        description='Evaluation for SQuAD ' + expected_version)
+    parser = argparse.ArgumentParser(description='Evaluation for SQuAD ' +
+                                     expected_version)
     parser.add_argument('dataset_file', help='Dataset file')
     parser.add_argument('prediction_file', help='Prediction File')
     args = parser.parse_args()
-    with open(args.dataset_file) as dataset_file:
-        dataset_json = json.load(dataset_file)
-        if (dataset_json['version'] != expected_version):
-            print('Evaluation expects v-' + expected_version +
-                  ', but got dataset with v-' + dataset_json['version'],
-                  file=sys.stderr)
-        dataset = dataset_json['data']
-    with open(args.prediction_file) as prediction_file:
-        predictions = json.load(prediction_file)
-    print(json.dumps(evaluate(dataset, predictions)))
 
+    print(
+        json.dumps(
+            eval.evaluate(expected_version, args.dataset_file,
+                          args.prediction_file)))
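The wrapper keeps the original command line (python evaluate-v1.1.py <dataset_file> <prediction_file>) but delegates the work to the new evaluate module, so other scripts can reuse the metric without spawning a subprocess. (Note that "import evaluate as eval" shadows Python's built-in eval within this script.) A minimal sketch of the programmatic path, with hypothetical file names:

# Calling the refactored metric directly; the file names are hypothetical.
from evaluate import evaluate

metrics = evaluate('1.1', 'dev-v1.1.json', 'predictions.json')
print(metrics['exact_match'], metrics['f1'])  # percentages in [0, 100]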

BingBertSquad/evaluate.py (+85)

@@ -0,0 +1,85 @@
+""" Official evaluation script for v1.1 of the SQuAD dataset. """
+from __future__ import print_function
+from collections import Counter
+import string
+import re
+import argparse
+import json
+import sys
+
+
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
+def exact_match_score(prediction, ground_truth):
+    return (normalize_answer(prediction) == normalize_answer(ground_truth))
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+    scores_for_ground_truths = []
+    for ground_truth in ground_truths:
+        score = metric_fn(prediction, ground_truth)
+        scores_for_ground_truths.append(score)
+    return max(scores_for_ground_truths)
+
+
+def evaluate(expected_version, ds_file, pred_file):
+    with open(ds_file) as dataset_file:
+        dataset_json = json.load(dataset_file)
+        if (dataset_json['version'] != expected_version):
+            print('Evaluation expects v-' + expected_version +
+                  ', but got dataset with v-' + dataset_json['version'],
+                  file=sys.stderr)
+        dataset = dataset_json['data']
+    with open(pred_file) as prediction_file:
+        predictions = json.load(prediction_file)
+
+    f1 = exact_match = total = 0
+    for article in dataset:
+        for paragraph in article['paragraphs']:
+            for qa in paragraph['qas']:
+                total += 1
+                if qa['id'] not in predictions:
+                    message = 'Unanswered question ' + qa['id'] + \
+                        ' will receive score 0.'
+                    print(message, file=sys.stderr)
+                    continue
+                ground_truths = list(map(lambda x: x['text'], qa['answers']))
+                prediction = predictions[qa['id']]
+                exact_match += metric_max_over_ground_truths(
+                    exact_match_score, prediction, ground_truths)
+                f1 += metric_max_over_ground_truths(f1_score, prediction,
+                                                    ground_truths)
+
+    exact_match = 100.0 * exact_match / total
+    f1 = 100.0 * f1 / total
+
+    return {'exact_match': exact_match, 'f1': f1}
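To make the metric concrete, here is a small worked example using the functions above; the strings are invented. normalize_answer drops articles and punctuation, so 'the Eiffel Tower' normalizes to 'eiffel tower' and matches the first ground truth exactly:

# Worked example of the SQuAD metrics defined above; strings are invented.
from evaluate import exact_match_score, f1_score, metric_max_over_ground_truths

prediction = 'the Eiffel Tower'
ground_truths = ['Eiffel Tower', 'the tower']

# The max over ground truths keeps the best score: exact match against the
# first answer (f1 = 1.0), even though the second only overlaps on 'tower'.
em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 = metric_max_over_ground_truths(f1_score, prediction, ground_truths)
print(em, f1)  # True 1.0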
