From 9f974c9c855535e1d1435236f7367cbefe6bbdda Mon Sep 17 00:00:00 2001
From: Lukas Blecher
Date: Sat, 21 May 2022 11:49:09 +0200
Subject: [PATCH] small adjustments

---
 pix2tex/model/settings/config-vit.yaml |  2 +-
 pix2tex/model/settings/config.yaml     |  2 +-
 pix2tex/train.py                       | 30 +++++++-------------------
 pix2tex/utils/utils.py                 | 28 +++++++++++++++++++-----
 4 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/pix2tex/model/settings/config-vit.yaml b/pix2tex/model/settings/config-vit.yaml
index f434be2..3d94e84 100644
--- a/pix2tex/model/settings/config-vit.yaml
+++ b/pix2tex/model/settings/config-vit.yaml
@@ -29,7 +29,7 @@ max_seq_len: 512
 max_width: 672
 min_height: 32
 min_width: 32
-micro_batchsize: 64
+micro_batchsize: -1
 model_path: checkpoints_add
 name: pix2tex-vit
 num_layers: 4
diff --git a/pix2tex/model/settings/config.yaml b/pix2tex/model/settings/config.yaml
index fa1b3b7..a579f9e 100644
--- a/pix2tex/model/settings/config.yaml
+++ b/pix2tex/model/settings/config.yaml
@@ -6,7 +6,7 @@ backbone_layers:
 betas:
 - 0.9
 - 0.999
-batchsize: 10
+batchsize: 64
 bos_token: 1
 channels: 1
 data: dataset/data/train.pkl
diff --git a/pix2tex/train.py b/pix2tex/train.py
index a226abd..bd2f599 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -12,23 +12,7 @@
 from pix2tex.eval import evaluate
 from pix2tex.models import get_model
 # from pix2tex.utils import *
-from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
-
-
-def gpu_memory_check(model, args):
-    # check if largest batch can be handled by system
-    try:
-        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-        for _ in range(5):
-            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
-            loss.sum().backward()
-    except RuntimeError:
-        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
-    model.zero_grad()
-    with torch.cuda.device(args.device):torch.cuda.empty_cache()
-    del im, seq
+from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
 
 
 def train(args):
@@ -40,13 +24,15 @@ def train(args):
     valdataloader.update(**valargs)
     device = args.device
     model = get_model(args)
-    gpu_memory_check(model, args)
-    if args.load_chkpt is not None:
-        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
+    if torch.cuda.is_available() and not args.no_cuda:
+        gpu_memory_check(model, args)
     max_bleu, max_token_acc = 0, 0
     out_path = os.path.join(args.model_path, args.name)
     os.makedirs(out_path, exist_ok=True)
 
+    if args.load_chkpt is not None:
+        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
+
     def save_models(e, step=0):
         torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
         yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
@@ -88,9 +74,9 @@ def save_models(e, step=0):
                 wandb.log({'train/epoch': e+1})
     except KeyboardInterrupt:
         if e >= 2:
-            save_models(e)
+            save_models(e, step=i)
         raise KeyboardInterrupt
-    save_models(e)
+    save_models(e, step=len(dataloader))
 
 
 if __name__ == '__main__':
diff --git a/pix2tex/utils/utils.py b/pix2tex/utils/utils.py
index 07d4708..2b5f920 100644
--- a/pix2tex/utils/utils.py
+++ b/pix2tex/utils/utils.py
@@ -52,8 +52,9 @@ def seed_everything(seed: int):
 def parse_args(args, **kwargs) -> Munch:
     args = Munch({'epoch': 0}, **args)
     kwargs = Munch({'no_cuda': False, 'debug': False}, **kwargs)
+    args.update(kwargs)
     args.wandb = not kwargs.debug and not args.debug
-    args.device = get_device(args, kwargs)
+    args.device = get_device(args, kwargs.no_cuda)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
@@ -61,17 +62,34 @@ def parse_args(args, **kwargs) -> Munch:
     return args
 
 
-def get_device(args, kwargs):
+def get_device(args, no_cuda=False):
     device = 'cpu'
     available_gpus = torch.cuda.device_count()
-    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else range(available_gpus)
-    if available_gpus > 0 and not kwargs.no_cuda:
+    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else list(range(available_gpus))
+    if available_gpus > 0 and not no_cuda:
         device = 'cuda:%d' % args.gpu_devices[0] if args.gpu_devices else 0
         assert available_gpus >= len(args.gpu_devices), "Available %d gpu, but specified gpu %s." % (available_gpus, ','.join(map(str, args.gpu_devices)))
-        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))),','.join(map(str, args.gpu_devices)))
+        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))), ','.join(map(str, args.gpu_devices)))
     return device
 
 
+def gpu_memory_check(model, args):
+    # check if largest batch can be handled by system
+    try:
+        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
+        for _ in range(5):
+            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
+            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss.sum().backward()
+    except RuntimeError:
+        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
+    model.zero_grad()
+    with torch.cuda.device(args.device):
+        torch.cuda.empty_cache()
+    del im, seq
+
+
 def token2str(tokens, tokenizer) -> list:
     if len(tokens.shape) == 1:
         tokens = tokens[None, :]
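
For context, a minimal sketch of how the relocated gpu_memory_check is meant to be driven once this patch is applied. The config path, the yaml loading and the no_cuda/debug keyword values are illustrative; parse_args, get_model and gpu_memory_check are the functions touched above, and the new args.update(kwargs) line in parse_args is what makes args.no_cuda readable in train.py's new guard:

    import yaml
    import torch
    from munch import Munch
    from pix2tex.models import get_model
    from pix2tex.utils import parse_args, gpu_memory_check

    # load the training config edited above, roughly as train.py's __main__ does
    params = yaml.safe_load(open('pix2tex/model/settings/config.yaml'))
    args = parse_args(Munch(params), no_cuda=False, debug=True)  # kwargs now land on args via args.update(kwargs)

    model = get_model(args)
    # probe GPU memory only when a GPU will actually be used (mirrors the new train.py check)
    if torch.cuda.is_available() and not args.no_cuda:
        gpu_memory_check(model, args)  # raises RuntimeError if the largest batch does not fit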