 from settings import TOKENIZER, LEN_FACTOR, DATA_ATTRS, MEMORY_FACTOR, MODEL_CONFIG, MODEL_CLASS
 from multiprocessing import Pool
 import sys
+import time
+import quadprog
 import io
 sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="UTF-8")
 logger = logging.getLogger(__name__)
@@ -164,6 +166,8 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):

         data = []
         for data_path in data_paths:
+            if not data_path:
+                continue
             with open(data_path, "r") as f:
                 raw_ds = json.load(f)
             raw_ds = map(lambda x: x["paragraphs"], raw_ds["data"])
@@ -174,7 +178,7 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):

         self.data = []
         self.max_a_len = 0
-        if len(data_paths)==1 and ('wiki' in data_paths[0] or 'woz' in data_paths[0]):
+        if len(data_paths)==1 and data_paths[0] is not None and ('wiki' in data_paths[0] or 'woz' in data_paths[0]):
             #data = self._sort_by_index(data)
             #args.n_workers = 1
             if 'wiki' in data_paths[0]:
@@ -183,7 +187,8 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):
                 answers_file = "woz.en_answers.json"
             with open(os.path.join(args.data_dir,answers_file),"r") as f:
                 self.answers = json.load(f)
-        self.data_tokenization(data)
+        if len(data) > 0:
+            self.data_tokenization(data)

         if len(extra_data) > 0:
             extra_data = map(lambda x: self.etl_single_extra_data(x), extra_data)
@@ -345,11 +350,26 @@ def __call__(self, loss, scheduler_steps):
             self.optimizer.backward(loss, update_master_grads=False)
         else:
             loss.backward()
+
         if not args.fp32:
             self.optimizer.update_master_grads()
             self.optimizer.clip_master_grads(args.max_grad_norm)
         else:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
+
+        if "gem" in args.seq_train_type and self.model.task_id > 0:
+            store_grad(self.model.parameters, self.model.grads, self.model.grad_dims, self.model.task_id)
+            indx = torch.cuda.LongTensor([i for i in range(self.model.task_id)])
+            dotp = torch.mm(self.model.grads[:, self.model.task_id].unsqueeze(0),
+                            self.model.grads.index_select(1, indx))
+            if (dotp < 0).sum() != 0:
+                project2cone2(self.model.grads[:, self.model.task_id].unsqueeze(1),
+                              self.model.grads.index_select(1, indx), args.qp_margin)
+                # copy gradients back
+                overwrite_grad(self.model.parameters,
+                               self.model.grads[:, self.model.task_id],
+                               self.model.grad_dims)
+
         if args.seq_train_type in args.REG_TYPE_KEYS:
             self.optimizer.step(self.model.reg_params)
         else:
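The hunk above adds the GEM (Gradient Episodic Memory, Lopez-Paz & Ranzato, 2017) check to the training step: the current task's flattened gradient g_t is written into model.grads[:, task_id], compared against the stored past-task gradients g_k, and projected only when some dot product is negative. The projection is the standard GEM quadratic program, restated here for reference:

    \min_{\tilde g}\ \tfrac{1}{2}\lVert g_t - \tilde g\rVert_2^2
    \quad\text{s.t.}\quad \langle \tilde g,\, g_k\rangle \ge 0 \;\;\forall k < t

project2cone2, added at the bottom of this diff, solves the dual of this QP over v in R^t with quadprog, using args.qp_margin as the lower bound v >= margin, and recovers the projected gradient as g~ = G^T v + g_t, where the rows of G are the past-task gradients.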
@@ -360,6 +380,58 @@ def __call__(self, loss, scheduler_steps):
         self.optimizer.zero_grad()


+class GEMStep:
+    def __init__(self, model, parallel_model, train_loss_fct, optimizer):
+        self.model = model
+        self.parallel_model = parallel_model
+        self.train_loss_fct = train_loss_fct
+        self.optimizer = optimizer
+
+    def __call__(self, current_task_id):
+        for past_task_id, md in enumerate(args.memory_data):
+            # Not saving current task's grads.
+            if past_task_id >= current_task_id: return
+            qadata = QADataset(None, "test", "gen", md)[:90]
+            dataloader = create_dataloader(qadata, "test")
+            grads_tmp = torch.zeros(sum(self.model.grad_dims),).cuda()
+            if not args.fp32:
+                grads_tmp = grads_tmp.half()
+            for _, _, cqa, _, Y, gen_X, gen_Y in dataloader:
+                #CHECK
+                n_inputs = sum(_cqa.shape[0] for _cqa in cqa)
+                self.optimizer.zero_grad()
+                for i in range(len(cqa)):
+                    cqa[i] = (cqa[i].to(args.device_ids[i]),)
+                    Y[i] = Y[i].to(args.device_ids[i])
+                    gen_X[i] = (gen_X[i].to(args.device_ids[i]),)
+                    gen_Y[i] = gen_Y[i].to(args.device_ids[i])
+
+                losses = get_losses(self.parallel_model, cqa, Y, gen_X, gen_Y, self.train_loss_fct)
+                loss = sum(losses)
+                if not args.fp32:
+                    self.optimizer.backward(loss, update_master_grads=False)
+                else:
+                    loss.backward()
+
+                if not args.fp32:
+                    #copy fp16 grads to fp32 grads
+                    self.optimizer.update_master_grads()
+                    self.optimizer.clip_master_grads(args.max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
+                i = 0
+                for param in self.model.parameters():
+                    if param.grad is not None:
+                        beg = 0 if i == 0 else sum(self.model.grad_dims[:i])
+                        end = sum(self.model.grad_dims[:i+1])
+                        grads_tmp[beg: end] += param.grad.data.view(-1)*n_inputs
+                    i += 1
+
+            grads_tmp /= len(qadata)
+            self.model.grads[:, past_task_id].copy_(grads_tmp)
+            self.optimizer.zero_grad()
+
+
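For context, a rough sketch of how GEMStep is presumably driven from the training loop. The call site is not part of this diff, so the loop, the compute_loss helper and the train_step name (the object whose __call__(loss, scheduler_steps) appears above) are illustrative assumptions only:

    # Hypothetical wiring -- not taken from this diff.
    gem_step = GEMStep(model, parallel_model, train_loss_fct, optimizer)
    for batch in train_dataloader:                        # assumed loop
        loss = compute_loss(batch)                        # assumed helper
        if "gem" in args.seq_train_type and model.task_id > 0:
            gem_step(model.task_id)       # refresh model.grads[:, k] for every past task k
        train_step(loss, scheduler_steps) # backward, GEM projection, optimizer step

GEMStep rebuilds the reference gradients from the episodic memory (args.memory_data) on every call, which is why it zeroes the optimizer's gradients before handing control back to the regular step.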
 class DynamicBatchSampler(Sampler):
     def __init__(self, dataset, data_type, max_batch_size):
         self.dataset = dataset
@@ -523,11 +595,15 @@ def parse_single_real_data(data,task):
     return data


-def get_real_data(task, train_extra_data):
+def get_real_data(task, train_extra_data, accum=True, encode=True):
     task_idx = args.tasks.index(task)
-    prev_tasks = args.tasks[:task_idx]
     gen_size = DATA_ATTRS[task]["train"]["data_size"]
-    gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))//len(prev_tasks)
+    if accum:
+        prev_tasks = args.tasks[:task_idx]
+        gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))//len(prev_tasks)
+    else:
+        prev_tasks = [args.tasks[task_idx-1]]
+        gen_size = int(gen_size * args.gen_lm_sample_percentage)

     datum = []
     for prev_task in prev_tasks:
@@ -537,11 +613,13 @@ def get_real_data(task, train_extra_data):
         for i in indices:
             d = parse_single_real_data(data[i],prev_task)
             datum.append(d)
-            train_extra_data.append(TOKENIZER.encode(d))
+            if encode:
+                train_extra_data.append(TOKENIZER.encode(d))

     model_dir = get_model_dir([prev_task])
     dump_path = os.path.join(model_dir,"real.csv")
     write_extra_data(dump_path, datum)
+    return dump_path


 def read_extra_data(gen_path, train_extra_data):
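The new accum/encode flags change how real replay data is gathered. A usage sketch based only on the signature above (the actual call sites are outside this diff):

    # Previous behaviour: sample from all earlier tasks and append the
    # TOKENIZER-encoded examples to train_extra_data.
    get_real_data(task, train_extra_data)

    # New options: take samples only from the immediately preceding task,
    # skip the encoding, and just use the returned path of the dumped real.csv.
    dump_path = get_real_data(task, [], accum=False, encode=False)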
@@ -728,3 +806,39 @@ def get_split_indices(data_sizes,chunk_sizes):
         chunk_sizes.pop(0)
         i += 1
     return records
+
+
+def store_grad(get_ps, grads, grad_dims, task_id):
+    i = 0
+    for param in get_ps():
+        if param.grad is not None:
+            beg = 0 if i == 0 else sum(grad_dims[:i])
+            end = sum(grad_dims[:i+1])
+            grads[beg: end, task_id].copy_(param.grad.data.view(-1))
+        i += 1
+
+
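store_grad and overwrite_grad (like GEMStep above) assume the model carries two GEM buffers, grads and grad_dims, which are not allocated anywhere in this diff. A sketch of the expected setup, following the reference GEM implementation (the buffer names match this diff; n_tasks and the device/dtype handling are assumptions):

    # One flattened-gradient column per task, plus the per-parameter sizes
    # used to slice the flat vector back into parameter shapes.
    model.grad_dims = [param.data.numel() for param in model.parameters()]
    model.grads = torch.zeros(sum(model.grad_dims), n_tasks).cuda()  # match training dtype if fp16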
+def overwrite_grad(pp, newgrad, grad_dims):
+    cnt = 0
+    for param in pp():
+        if param.grad is not None:
+            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
+            en = sum(grad_dims[:cnt+1])
+            this_grad = newgrad[beg: en].contiguous().view(
+                param.grad.data.size())
+            param.grad.data.copy_(this_grad)
+        cnt += 1
+
+
+def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
+    memories_np = memories.cpu().t().double().numpy()
+    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
+    t = memories_np.shape[0]
+    P = np.dot(memories_np, memories_np.transpose())
+    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
+    q = np.dot(memories_np, gradient_np) * -1
+    G = np.eye(t)
+    h = np.zeros(t) + margin
+    v = quadprog.solve_qp(P, q, G, h)[0]
+    x = np.dot(v, memories_np) + gradient_np
+    gradient.copy_(torch.Tensor(x).view(-1, 1))
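A minimal numeric check of project2cone2 (a hypothetical example, assuming numpy, torch and quadprog are importable alongside this module): a 2-D gradient that conflicts with a single stored memory gradient is projected so that the conflict disappears.

    import torch

    g = torch.tensor([[1.0], [-3.0]])   # current-task gradient, shape (dim, 1)
    m = torch.tensor([[2.0], [1.0]])    # one past-task gradient, shape (dim, 1)

    print(torch.mm(g.t(), m).item())    # -1.0: this update would hurt the past task
    project2cone2(g, m)                 # in-place projection, default margin=0.5
    print(torch.mm(g.t(), m).item())    # roughly 1.5: constraint now satisfied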