 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 
+from apex import amp
 from turing.nvidia_modeling import BertForQuestionAnswering, BertConfig
 
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -712,6 +713,31 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n
     return is_nan
 
 
+from apex.multi_tensor_apply import multi_tensor_applier
+class GradientClipper:
+    """
+    Clips gradient norm of an iterable of parameters.
+    """
+    def __init__(self, max_grad_norm):
+        self.max_norm = max_grad_norm
+        if multi_tensor_applier.available:
+            import amp_C
+            self._overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+        else:
+            raise RuntimeError('Gradient clipping requires cuda extensions')
+
+    def step(self, parameters):
+        l = [p.grad for p in parameters if p.grad is not None]
+        total_norm, _ = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [l], False)
+        total_norm = total_norm.item()
+        if (total_norm == float('inf')): return
+        clip_coef = self.max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            multi_tensor_applier(self.multi_tensor_scale, self._overflow_buf, [l, l], clip_coef)
+
+
 def main():
     parser = get_argument_parser()
     args = parser.parse_args()
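For reference, a minimal sketch (not part of the patch) of an equivalent fallback when the apex CUDA extensions are unavailable, using PyTorch's built-in clipping utility; `parameters` is assumed to be the same iterable later passed to GradientClipper.step(), e.g. amp.master_params(optimizer):

import torch

def clip_gradients_fallback(parameters, max_grad_norm=1.0):
    # clip_grad_norm_ computes the total L2 norm across all gradients and
    # rescales them in place when that norm exceeds max_grad_norm.
    params = [p for p in parameters if p.grad is not None]
    return torch.nn.utils.clip_grad_norm_(params, max_grad_norm)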
@@ -813,18 +839,7 @@ def main():
     #model.bert.load_state_dict(bert_state_dict, strict=False)
     logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")
 
-    if args.fp16:
-        model.half()
     model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
 
     # Prepare optimizer
     param_optimizer = list(model.named_parameters())
@@ -844,25 +859,33 @@ def main():
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
+                              bias_correction=False)
         if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale="dynamic")
         else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+            raise NotImplementedError("dynamic loss scale is only supported in baseline, please set loss_scale=0")
     else:
         optimizer = BertAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup_proportion,
                              t_total=t_total)
 
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     global_step = 0
     if args.do_train:
         cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
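For context, a minimal standalone sketch of the O2 initialization pattern adopted in the hunk above; the toy model and hyperparameters are placeholders, not values from the patch:

import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 2).cuda()   # stand-in for BertForQuestionAnswering
optimizer = FusedAdam(model.parameters(), lr=3e-5, bias_correction=False)

# opt_level="O2" runs the model in FP16 while amp keeps FP32 master weights;
# loss_scale="dynamic" lets amp back off the scale whenever an overflow occurs.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=False,
                                  loss_scale="dynamic")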
@@ -901,6 +924,8 @@ def main():
             train_sampler = DistributedSampler(train_data)
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
 
+        gradClipper = GradientClipper(max_grad_norm=1.0)
+
         model.train()
         ema_loss = 0.
         sample_count = 0
@@ -928,10 +953,14 @@ def main():
                     model.enable_allreduce()
 
                 if args.fp16:
-                    optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                 else:
                     loss.backward()
 
+                # gradient clipping
+                gradClipper.step(amp.master_params(optimizer))
+
                 sample_count += (args.train_batch_size * torch.distributed.get_world_size())
 
                 if (step + 1) % args.gradient_accumulation_steps == 0:
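Putting the per-step changes together, a compact sketch of one FP16 update under this API; the helper name is hypothetical and assumes a model/optimizer prepared with amp.initialize and the gradClipper created earlier in the patch:

from apex import amp

def fp16_train_step(loss, optimizer, gradClipper):
    # backward on the scaled loss so FP16 gradients do not underflow
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # clip on the FP32 master gradients that amp maintains for the optimizer
    gradClipper.step(amp.master_params(optimizer))
    optimizer.step()
    optimizer.zero_grad()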