Commit 9e2c735

New apex compatible squad (deepspeedai#19)
* change squad baseline to use new apex
1 parent 6a698b2 commit 9e2c735
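
In outline, "use new apex" means dropping the deprecated FP16_Optimizer wrapper in favor of apex's amp API and doing gradient clipping explicitly on the fp32 master parameters. A minimal sketch of the resulting fp16 path, condensed from the diffs below (the real script builds optimizer_grouped_parameters, the dataloader, and the loss elsewhere; the loop here is simplified):

from apex import amp
from apex.optimizers import FusedAdam

# max_grad_norm is no longer passed to FusedAdam; clipping happens explicitly below
optimizer = FusedAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      bias_correction=False)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                  keep_batchnorm_fp32=False, loss_scale="dynamic")

gradClipper = GradientClipper(max_grad_norm=1.0)    # helper class added by this commit

for step, batch in enumerate(train_dataloader):
    loss = model(*batch)                            # forward pass, simplified
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                      # amp applies dynamic loss scaling
    gradClipper.step(amp.master_params(optimizer))  # clip fp32 master grads to norm 1.0
    optimizer.step()
    optimizer.zero_grad()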

File tree

2 files changed: +47 -18 lines changed


BingBertSquad/deepspeed_bsz24_config.json

+1 -1
@@ -6,11 +6,11 @@
     "type": "Adam",
     "params": {
       "lr": 3e-5,
-      "max_grad_norm": 1.0,
       "weight_decay": 0.0,
       "bias_correction": false
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true
   }

BingBertSquad/nvidia_run_squad_baseline.py

+46 -17
@@ -42,6 +42,7 @@
 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

+from apex import amp
 from turing.nvidia_modeling import BertForQuestionAnswering, BertConfig

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -712,6 +713,31 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n
     return is_nan


+from apex.multi_tensor_apply import multi_tensor_applier
+class GradientClipper:
+    """
+    Clips gradient norm of an iterable of parameters.
+    """
+    def __init__(self, max_grad_norm):
+        self.max_norm = max_grad_norm
+        if multi_tensor_applier.available:
+            import amp_C
+            self._overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+        else:
+            raise RuntimeError('Gradient clipping requires cuda extensions')
+
+    def step(self, parameters):
+        l = [p.grad for p in parameters if p.grad is not None]
+        total_norm, _ = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [l], False)
+        total_norm = total_norm.item()
+        if (total_norm == float('inf')): return
+        clip_coef = self.max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            multi_tensor_applier(self.multi_tensor_scale, self._overflow_buf, [l, l], clip_coef)
+
+
 def main():
     parser = get_argument_parser()
     args = parser.parse_args()
@@ -813,18 +839,7 @@ def main():
         #model.bert.load_state_dict(bert_state_dict, strict=False)
         logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")

-    if args.fp16:
-        model.half()
     model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)

     # Prepare optimizer
     param_optimizer = list(model.named_parameters())
@@ -844,25 +859,33 @@ def main():
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
+                              bias_correction=False)
         if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale="dynamic")
         else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+            raise NotImplementedError("dynamic loss scale is only supported in baseline, please set loss_scale=0")
     else:
         optimizer = BertAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup_proportion,
                              t_total=t_total)

+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     global_step = 0
     if args.do_train:
         cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
@@ -901,6 +924,8 @@ def main():
             train_sampler = DistributedSampler(train_data)
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

+        gradClipper = GradientClipper(max_grad_norm=1.0)
+
         model.train()
         ema_loss = 0.
         sample_count = 0
@@ -928,10 +953,14 @@ def main():
                     model.enable_allreduce()

                 if args.fp16:
-                    optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                 else:
                     loss.backward()

+                # gradient clipping
+                gradClipper.step(amp.master_params(optimizer))
+
                 sample_count += (args.train_batch_size * torch.distributed.get_world_size())

                 if (step + 1) % args.gradient_accumulation_steps == 0:
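
As a standalone illustration (not part of the commit), the GradientClipper added above can be exercised on any CUDA parameters once gradients exist; the toy linear layer below is hypothetical and assumes apex was built with its CUDA extensions (amp_C):

import torch

layer = torch.nn.Linear(16, 16).cuda()                      # made-up module, just to produce gradients
layer(torch.randn(4, 16, device="cuda")).sum().backward()

clipper = GradientClipper(max_grad_norm=1.0)
clipper.step(list(layer.parameters()))                      # rescales grads in place if their global L2 norm exceeds 1.0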
