Commit cac7e33

add GEM
1 parent 9e793dc commit cac7e33

File tree

4 files changed: +154 −12 lines changed


README.md

+2 −1

@@ -50,7 +50,7 @@ In order to unify the format of all the dataset, we first ran the code in https:
 3. Make a copy of `env.example` and save it as `env`. In `env`, set the value of DATA_DIR to your data directory and the value of MODEL_ROOT_DIR to your model directory.
 4. Before training or testing, load the DATA_DIR and MODEL_ROOT_DIR variables into the shell environment with the following command:
 ```bash
-source env
+source ./env
 ```

 ## Training and Testing
@@ -115,4 +115,5 @@ After running testing program, the metrics: `metrics.json` will be dumped in the
 ## Acknowledgements:
 - We use the language model offered by [transformers](https://github.com/huggingface/transformers), a great state-of-the-art natural language processing library by Thomas Wolf et al.
 - The implementation of MAS refers to [MAS-Memory-Aware-Synapses](https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses), a great implementation of the Memory Aware Synapses method by Aljundi R. et al.
+- The implementation of GEM refers to [GradientEpisodicMemory](https://github.com/facebookresearch/GradientEpisodicMemory), a great implementation of the Gradient Episodic Memory method by Lopez-Paz, David et al.
 - Data format conversion refers to [decaNLP](https://github.com/salesforce/decaNLP), a great implementation of The Natural Language Decathlon: Multitask Learning as Question Answering by Bryan McCann et al.

settings.py

+12 −4

@@ -20,6 +20,7 @@
     "lll": 0.35,
     "ewc": 0.30,
     "mas": 0.18,
+    "gem": 0.50,
 }
 TURING_ARCHS = {'Tesla V100', '2080 Ti'}
 MODEL_CLASSES = {
@@ -53,12 +54,13 @@ def parse_args():
     parser.add_argument("--model_name", type=str, default="gpt2", choices=["gpt2", "openai-gpt"])
     parser.add_argument("--n_gpus", type=int, default=1)
     parser.add_argument("--n_train_epochs", type=int, default=3)
+    parser.add_argument("--dynamic_epochs", action="store_true")
     parser.add_argument("--n_warmup_ratio", type=float, default=0.005)
     parser.add_argument("--n_workers", type=int, default=4)
     parser.add_argument("--use_sep", action="store_true")
     parser.add_argument("--reg_lambda", type=float, default=1.)
     parser.add_argument("--seed", type=int, default=42)
-    parser.add_argument("--seq_train_type", type=str, default="lll", choices=["lll","finetune","multitask","mas","ewc"])
+    parser.add_argument("--seq_train_type", type=str, default="lll", choices=["lll","finetune","multitask","mas","ewc","gem"])
     parser.add_argument("--tasks", nargs='+', default=["squad2"])
     parser.add_argument("--skip_tasks", nargs='+')
     parser.add_argument("--temperature_lm", type=float, default=1.0)
@@ -71,6 +73,7 @@ def parse_args():
     parser.add_argument("--top_p_qa", type=float, default=0.)
     parser.add_argument("--train_batch_size", type=int, default=0)
     parser.add_argument("--weight_decay", type=float, default=0.01)
+    parser.add_argument("--qp_margin", type=float, default=0.5)
     args = parser.parse_args()

     if args.debug:
@@ -141,9 +144,14 @@ def parse_args():
     elif args.unbound:
         pass
     else:
-        data_sizes = {task: data_attrs[task]["train"]["data_size"] for task in args.tasks}
-        max_total_data_size = max(data_sizes.values()) * args.n_train_epochs
-        args.n_train_epochs = {d[0]: min(args.max_n_epochs, max_total_data_size//d[1]) for d in data_sizes.items()}
+        if "gem" in args.seq_train_type:
+            args.memory_data = []
+        if args.dynamic_epochs:
+            data_sizes = {task: data_attrs[task]["train"]["data_size"] for task in args.tasks}
+            max_total_data_size = max(data_sizes.values()) * args.n_train_epochs
+            args.n_train_epochs = {d[0]: min(args.max_n_epochs, max_total_data_size//d[1]) for d in data_sizes.items()}
+        else:
+            args.n_train_epochs = {task: args.n_train_epochs for task in args.tasks}

     return args, model_config, model_class, tokenizer, config_class, special_token_ids, special_tokens, data_attrs, tokens_weight
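With `--dynamic_epochs` unset, every task now trains for exactly `--n_train_epochs` epochs; with it set, smaller datasets get proportionally more epochs, capped by `max_n_epochs`. Below is a minimal sketch of that per-task schedule with made-up dataset sizes and an assumed cap of 9; only the dictionary comprehensions mirror the code above.

```python
# Minimal sketch of the per-task epoch schedule above, with hypothetical numbers.
n_train_epochs = 3          # --n_train_epochs
max_n_epochs = 9            # assumed cap, standing in for args.max_n_epochs
data_sizes = {"sst": 6_920, "srl": 6_414, "woz.en": 2_536}   # hypothetical train sizes

dynamic_epochs = True
if dynamic_epochs:
    # every task sees roughly the same total number of training examples
    max_total = max(data_sizes.values()) * n_train_epochs
    epochs = {task: min(max_n_epochs, max_total // size) for task, size in data_sizes.items()}
else:
    # fixed schedule: the same epoch count for every task
    epochs = {task: n_train_epochs for task in data_sizes}

print(epochs)   # dynamic: {'sst': 3, 'srl': 3, 'woz.en': 8}
```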

train.py

+20 −1

@@ -31,6 +31,10 @@ def train(task_ids, model):
         prev_task = args.tasks[task_ids[0]-1]
         with torch.no_grad():
             create_extra_data(tasks[0], prev_task, model, train_extra_data)
+    elif "gem" in args.seq_train_type and task_ids[0] > 0:
+        get_real_data(tasks[0], train_extra_data, accum=False, encode=True)
+        args.memory_data.append(train_extra_data)
+        train_extra_data = []
     logger.info('extra training data size: {}'.format(len(train_extra_data)))

     if not model:
@@ -90,7 +94,8 @@ def train(task_ids, model):
     max_train_batch_size = max(len(train_qadata) // args.min_n_steps, args.min_batch_size)
     train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)
     if not args.unbound and args.seq_train_type != "multitask":
-        n_train_epochs = TASK_DICT[tasks[0]]["n_train_epochs"]
+        #n_train_epochs = TASK_DICT[tasks[0]]["n_train_epochs"]
+        n_train_epochs = args.n_train_epochs[tasks[0]]
     else:
         n_train_epochs = args.n_train_epochs['_'.join(tasks)]
     n_train_optimization_steps = len(train_qadata) * n_train_epochs
@@ -104,6 +109,16 @@ def train(task_ids, model):
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
     ]

+    if "gem" in args.seq_train_type:
+        model.task_id = task_ids[0]
+        if not hasattr(model, "grad_dims"):
+            model.grad_dims = []
+            for param in model.parameters():
+                model.grad_dims.append(param.data.numel())
+        if not hasattr(model, "grads"):
+            model.grads = torch.zeros(sum(model.grad_dims), len(args.tasks))
+            model.grads = model.grads.cuda()
+
     if args.seq_train_type in REG_TYPE_KEYS:
         optimizer = Weight_Regularized_AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     else:
@@ -124,6 +139,8 @@ def train(task_ids, model):

     tot_n_steps = 0
     train_once = TrainStep(model, optimizer, scheduler)
+    if "gem" in args.seq_train_type and task_ids[0] != 0:
+        gem_step = GEMStep(model, parallel_model, train_loss_fct, optimizer)
     model.train()
     for ep in range(n_train_epochs):
         cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
@@ -139,6 +156,8 @@ def train(task_ids, model):

             losses = get_losses(parallel_model, cqa, Y, gen_X, gen_Y, train_loss_fct)
             loss = sum(losses)
+            if "gem" in args.seq_train_type and task_ids[0] != 0:
+                gem_step(task_ids[0])
             train_once(loss, n_inputs)

             qa_loss = losses[0].item() * n_inputs
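The GEM bookkeeping attached to the model here is a flat gradient buffer with one column per task: `gem_step` refills the past-task columns from stored real data, and `TrainStep` (see utils.py below) writes the current gradient into column `task_id` and projects it if it conflicts. A minimal sketch of that buffer setup, using a toy linear module as a stand-in for the actual GPT-2 and made-up task counts:

```python
# Sketch of the GEM buffers that train() attaches to the model, using a tiny
# stand-in module instead of the real GPT-2 (all shapes are illustrative only).
import torch

model = torch.nn.Linear(4, 2)                    # hypothetical stand-in for the LM
n_tasks = 3                                      # stand-in for len(args.tasks)

model.task_id = 1                                # second task in the sequence
model.grad_dims = [p.data.numel() for p in model.parameters()]   # [8, 2] here
# one flattened-gradient column per task; moved to the GPU in the real code
model.grads = torch.zeros(sum(model.grad_dims), n_tasks)

# per training step when task_id > 0:
#   gem_step(task_id)   -> rewrites columns 0..task_id-1 with averaged gradients
#                          of each past task's stored real data
#   train_once(loss, n) -> backprop on the current batch; TrainStep stores the
#                          current gradient in column task_id, projects it if it
#                          conflicts with any past column, then takes the step
print(model.grads.shape)                         # torch.Size([10, 3])
```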

utils.py

+120 −6

@@ -17,6 +17,8 @@
 from settings import TOKENIZER, LEN_FACTOR, DATA_ATTRS, MEMORY_FACTOR, MODEL_CONFIG, MODEL_CLASS
 from multiprocessing import Pool
 import sys
+import time
+import quadprog
 import io
 sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="UTF-8")
 logger = logging.getLogger(__name__)
@@ -164,6 +166,8 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):

         data = []
         for data_path in data_paths:
+            if not data_path:
+                continue
             with open(data_path, "r") as f:
                 raw_ds = json.load(f)
             raw_ds = map(lambda x: x["paragraphs"], raw_ds["data"])
@@ -174,7 +178,7 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):

         self.data = []
         self.max_a_len = 0
-        if len(data_paths)==1 and ('wiki' in data_paths[0] or 'woz' in data_paths[0]):
+        if len(data_paths)==1 and data_paths[0] is not None and ('wiki' in data_paths[0] or 'woz' in data_paths[0]):
             #data = self._sort_by_index(data)
             #args.n_workers = 1
             if 'wiki' in data_paths[0]:
@@ -183,7 +187,8 @@ def __init__(self, data_paths, data_type, gen_token, extra_data=[]):
                 answers_file = "woz.en_answers.json"
             with open(os.path.join(args.data_dir,answers_file),"r") as f:
                 self.answers = json.load(f)
-        self.data_tokenization(data)
+        if len(data) > 0:
+            self.data_tokenization(data)

         if len(extra_data) > 0:
             extra_data = map(lambda x: self.etl_single_extra_data(x), extra_data)
@@ -345,11 +350,26 @@ def __call__(self, loss, scheduler_steps):
             self.optimizer.backward(loss, update_master_grads=False)
         else:
             loss.backward()
+
         if not args.fp32:
             self.optimizer.update_master_grads()
             self.optimizer.clip_master_grads(args.max_grad_norm)
         else:
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
+
+        if "gem" in args.seq_train_type and self.model.task_id > 0:
+            store_grad(self.model.parameters, self.model.grads, self.model.grad_dims, self.model.task_id)
+            indx = torch.cuda.LongTensor([i for i in range(self.model.task_id)])
+            dotp = torch.mm(self.model.grads[:, self.model.task_id].unsqueeze(0),
+                            self.model.grads.index_select(1, indx))
+            if (dotp < 0).sum() != 0:
+                project2cone2(self.model.grads[:, self.model.task_id].unsqueeze(1),
+                              self.model.grads.index_select(1, indx), args.qp_margin)
+                # copy gradients back
+                overwrite_grad(self.model.parameters,
+                               self.model.grads[:, self.model.task_id],
+                               self.model.grad_dims)
+
         if args.seq_train_type in args.REG_TYPE_KEYS:
             self.optimizer.step(self.model.reg_params)
         else:
@@ -360,6 +380,58 @@ def __call__(self, loss, scheduler_steps):
         self.optimizer.zero_grad()


+class GEMStep:
+    def __init__(self, model, parallel_model, train_loss_fct, optimizer):
+        self.model = model
+        self.parallel_model = parallel_model
+        self.train_loss_fct = train_loss_fct
+        self.optimizer = optimizer
+
+    def __call__(self, current_task_id):
+        for past_task_id, md in enumerate(args.memory_data):
+            # Not saving current task's grads.
+            if past_task_id >= current_task_id: return
+            qadata = QADataset(None, "test", "gen", md)[:90]
+            dataloader = create_dataloader(qadata, "test")
+            grads_tmp = torch.zeros(sum(self.model.grad_dims),).cuda()
+            if not args.fp32:
+                grads_tmp = grads_tmp.half()
+            for _, _, cqa, _, Y, gen_X, gen_Y in dataloader:
+                #CHECK
+                n_inputs = sum(_cqa.shape[0] for _cqa in cqa)
+                self.optimizer.zero_grad()
+                for i in range(len(cqa)):
+                    cqa[i] = (cqa[i].to(args.device_ids[i]),)
+                    Y[i] = Y[i].to(args.device_ids[i])
+                    gen_X[i] = (gen_X[i].to(args.device_ids[i]),)
+                    gen_Y[i] = gen_Y[i].to(args.device_ids[i])
+
+                losses = get_losses(self.parallel_model, cqa, Y, gen_X, gen_Y, self.train_loss_fct)
+                loss = sum(losses)
+                if not args.fp32:
+                    self.optimizer.backward(loss, update_master_grads=False)
+                else:
+                    loss.backward()
+
+                if not args.fp32:
+                    #copy fp16 grads to fp32 grads
+                    self.optimizer.update_master_grads()
+                    self.optimizer.clip_master_grads(args.max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
+                i = 0
+                for param in self.model.parameters():
+                    if param.grad is not None:
+                        beg = 0 if i == 0 else sum(self.model.grad_dims[:i])
+                        end = sum(self.model.grad_dims[:i+1])
+                        grads_tmp[beg: end] += param.grad.data.view(-1)*n_inputs
+                    i += 1
+
+            grads_tmp /= len(qadata)
+            self.model.grads[:, past_task_id].copy_(grads_tmp)
+            self.optimizer.zero_grad()
+
+
 class DynamicBatchSampler(Sampler):
     def __init__(self, dataset, data_type, max_batch_size):
         self.dataset = dataset
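What `GEMStep` leaves behind is one reference gradient per past task: each memory batch's gradient is weighted by its `n_inputs`, summed, averaged over the number of memory examples, and written into that task's column of `model.grads`. A small sketch of just the accumulation, with random flat vectors standing in for real backward passes (no fp16 path, hypothetical sizes):

```python
# Sketch of GEMStep's accumulation with made-up flat gradients: batch gradients
# are weighted by their batch size, summed, and averaged over the number of
# memory examples before being stored in the past task's column.
import torch

grad_dims = [8, 2]                                 # per-parameter sizes, as in model.grad_dims
grads = torch.zeros(sum(grad_dims), 2)             # column 0: past task, column 1: current task

batch_sizes = [4, 4, 2]                            # n_inputs per memory batch
batch_grads = [torch.randn(sum(grad_dims)) for _ in batch_sizes]   # stand-ins for param.grad

grads_tmp = torch.zeros(sum(grad_dims))
for g, n in zip(batch_grads, batch_sizes):
    grads_tmp += g * n                             # grads_tmp[beg:end] += grad.view(-1) * n_inputs
grads_tmp /= sum(batch_sizes)                      # divided by len(qadata) in the class above
grads[:, 0].copy_(grads_tmp)                       # reference gradient for past task 0
```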
@@ -523,11 +595,15 @@ def parse_single_real_data(data,task):
     return data


-def get_real_data(task, train_extra_data):
+def get_real_data(task, train_extra_data, accum=True, encode=True):
     task_idx = args.tasks.index(task)
-    prev_tasks = args.tasks[:task_idx]
     gen_size = DATA_ATTRS[task]["train"]["data_size"]
-    gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))//len(prev_tasks)
+    if accum:
+        prev_tasks = args.tasks[:task_idx]
+        gen_size = int(np.ceil(gen_size * args.gen_lm_sample_percentage))//len(prev_tasks)
+    else:
+        prev_tasks = [args.tasks[task_idx-1]]
+        gen_size = int(gen_size * args.gen_lm_sample_percentage)

     datum = []
     for prev_task in prev_tasks:
@@ -537,11 +613,13 @@ def get_real_data(task, train_extra_data):
         for i in indices:
             d = parse_single_real_data(data[i],prev_task)
             datum.append(d)
-            train_extra_data.append(TOKENIZER.encode(d))
+            if encode:
+                train_extra_data.append(TOKENIZER.encode(d))

     model_dir = get_model_dir([prev_task])
     dump_path = os.path.join(model_dir,"real.csv")
     write_extra_data(dump_path, datum)
+    return dump_path


 def read_extra_data(gen_path, train_extra_data):
@@ -728,3 +806,39 @@ def get_split_indices(data_sizes,chunk_sizes):
             chunk_sizes.pop(0)
         i+=1
     return records
+
+
+def store_grad(get_ps, grads, grad_dims, task_id):
+    i = 0
+    for param in get_ps():
+        if param.grad is not None:
+            beg = 0 if i == 0 else sum(grad_dims[:i])
+            end = sum(grad_dims[:i+1])
+            grads[beg: end, task_id].copy_(param.grad.data.view(-1))
+        i += 1
+
+
+def overwrite_grad(pp, newgrad, grad_dims):
+    cnt = 0
+    for param in pp():
+        if param.grad is not None:
+            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
+            en = sum(grad_dims[:cnt + 1])
+            this_grad = newgrad[beg: en].contiguous().view(
+                param.grad.data.size())
+            param.grad.data.copy_(this_grad)
+        cnt += 1
+
+
+def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
+    memories_np = memories.cpu().t().double().numpy()
+    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
+    t = memories_np.shape[0]
+    P = np.dot(memories_np, memories_np.transpose())
+    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
+    q = np.dot(memories_np, gradient_np) * -1
+    G = np.eye(t)
+    h = np.zeros(t) + margin
+    v = quadprog.solve_qp(P, q, G, h)[0]
+    x = np.dot(v, memories_np) + gradient_np
+    gradient.copy_(torch.Tensor(x).view(-1, 1))
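`project2cone2` follows the reference GEM implementation: it solves the dual quadratic program over the past-task gradients with `quadprog` and adds the resulting combination back onto the current gradient. For a single memory gradient and zero margin this reduces to a plain halfspace projection, which the sketch below illustrates with hypothetical 2-D vectors and no `quadprog` dependency:

```python
# Illustration of what project2cone2 does in the one-memory-gradient case with
# margin 0: if the current gradient g conflicts with the memory gradient m
# (negative dot product), project g onto the halfspace <g', m> >= 0. The
# multi-constraint case in the function above is handled by the quadprog solve.
import numpy as np

g = np.array([1.0, -2.0])        # hypothetical current-task gradient
m = np.array([0.0, 1.0])         # hypothetical memory (past-task) gradient

dotp = g @ m                     # same conflict check as (dotp < 0) in TrainStep
if dotp < 0:
    g = g - (dotp / (m @ m)) * m # closed-form single-constraint projection

print(g, g @ m)                  # [1. 0.] 0.0 -- no longer conflicts with m
```

A positive `--qp_margin` biases the QP solution further toward the memory gradients, as in the original GEM code, rather than only removing the conflict.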
