Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compatible with torch.utils.data.DataLoader #154

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions pix2tex/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
import imagesize
import logging
Expand All @@ -15,10 +16,10 @@

from pix2tex.utils.utils import in_model_path
from pix2tex.dataset.transforms import train_transform, test_transform
import math



class Im2LatexDataset:
class Im2LatexDataset(IterableDataset):
keep_smaller_batches = False
shuffle = True
batchsize = 16
Expand All @@ -33,6 +34,7 @@ class Im2LatexDataset:
eos_token_id = 2
transform = train_transform
data = defaultdict(lambda: [])
permutation = None

def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024,
max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False):
Expand All @@ -42,7 +44,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
equations (str, optional): Path to equations. Defaults to None.
images (str, optional): Directory where images are saved. Defaults to None.
tokenizer (str, optional): Path to saved tokenizer. Defaults to None.
shuffle (bool, optional): Defaults to True.
shuffle (bool, optional): Defaults to True.
batchsize (int, optional): Defaults to 16.
max_seq_len (int, optional): Defaults to 1024.
max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
Expand Down Expand Up @@ -75,32 +77,39 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
self.data[(width, height)].append((eqs[self.indices[i]], im))
except KeyboardInterrupt:
pass
# formula&image pairs grouped by image size
self.data = dict(self.data)
self._get_size()

self._shuffle()
iter(self)

def __len__(self):
return self.size
return self.size # total number of batches given the batchsize

def __iter__(self):
self.i = 0
self.transform = test_transform if self.test else train_transform
self.pairs = []
for k in self.data:
info = np.array(self.data[k], dtype=object)
p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info))
for i in range(0, len(info), self.batchsize):
batch = info[p[i:i+self.batchsize]]
batch = info[i:i+self.batchsize]
if len(batch.shape) == 1:
batch = batch[None, :]
if len(batch) < self.batchsize and not self.keep_smaller_batches:
continue
self.pairs.append(batch)
if self.shuffle:
self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
# configure the dataset to only process the split workload
per_worker = int(math.ceil(self.size/float(worker_info.num_workers)))
worker_id = worker_info.id
self.start = worker_id * per_worker
self.end = min(self.start + per_worker, self.size)
else:
self.pairs = np.array(self.pairs, dtype=object)
self.start, self.end = 0, self.size

self.pairs = np.array(self.pairs, dtype=object)[self.permutation[self.start:self.end]]
self.size = len(self.pairs)
return self

Expand All @@ -121,6 +130,8 @@ def prepare_data(self, batch):
"""

eqs, ims = batch.T
# for im in ims:
# print(im)
tok = self.tokenizer(list(eqs), return_token_type_ids=False)
# pad with bos and eos token
for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
Expand Down Expand Up @@ -155,6 +166,15 @@ def _get_size(self):
for k in self.data:
div, mod = divmod(len(self.data[k]), self.batchsize)
self.size += div # + (1 if mod > 0 else 0)
if self.permutation is None or len(self.permutation) != self.size:
self._shuffle()

def _shuffle(self):
if self.shuffle:
self.permutation = np.random.permutation(self.size)
else:
self.permutation = np.arange(self.size)
return self

def load(self, filename, args=[]):
"""returns a pickled version of a dataset
Expand All @@ -169,6 +189,7 @@ def load(self, filename, args=[]):
filename = os.path.realpath(tempf)
with open(filename, 'rb') as file:
x = pickle.load(file)
x._get_size()
return x

def combine(self, x):
Expand Down Expand Up @@ -216,7 +237,19 @@ def update(self, **kwargs):
tokenizer_file = os.path.realpath(tokenizer_file)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
self._get_size()
iter(self)
return iter(self)


class Dataloader(DataLoader):
    """Thin ``torch.utils.data.DataLoader`` wrapper for ``Im2LatexDataset``.

    The wrapped dataset groups samples into ready-made batches itself, so
    the torch loader is configured with ``batch_size=None`` and
    ``shuffle=False``: batching and shuffling are delegated to the dataset.
    """
    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, drop_last=True, num_workers=0, pin_memory=False):
        # NOTE(review): attributes are assigned before super().__init__ on
        # purpose — presumably because DataLoader restricts reassigning
        # `dataset` once initialised; keep this statement order.
        self.dataset = dataset
        # Expose the tokenizer so consumers (e.g. evaluation code) can
        # detokenize model output without reaching into the dataset.
        self.tokenizer = dataset.tokenizer
        # Forward loader options into the dataset, which performs its own
        # batching; keep_smaller_batches inverts torch's drop_last semantics.
        self.dataset.update(batchsize=batch_size, shuffle=shuffle, keep_smaller_batches=not drop_last)
        # batch_size=None: the dataset already yields whole batches.
        # shuffle=False: ordering comes from the dataset's own permutation.
        super().__init__(self.dataset, num_workers=num_workers, shuffle=False, batch_size=None, pin_memory=pin_memory)

    def __iter__(self):
        # Draw a fresh batch-order permutation at the start of each epoch,
        # then defer to the standard DataLoader iteration (per-worker
        # splitting is handled inside the dataset's __iter__).
        self.dataset._shuffle()
        return super().__iter__()


def generate_tokenizer(equations, output, vocab_size):
Expand Down
8 changes: 4 additions & 4 deletions pix2tex/eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import argparse
import logging
import yaml
Expand Down Expand Up @@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):


@torch.no_grad()
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
"""evaluates the model. Returns bleu score on the dataset

Args:
model (torch.nn.Module): the model
dataset (Im2LatexDataset): test dataset
dataset (Dataloader): test dataset
args (Munch): arguments
num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
Expand All @@ -46,7 +46,7 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
log = {}
bleus, edit_dists, token_acc = [], [], []
bleu_score, edit_distance, token_accuracy = 0, 1, 0
pbar = tqdm(enumerate(iter(dataset)), total=len(dataset))
pbar = tqdm(enumerate(dataset), total=len(dataset))
for i, (seq, im) in pbar:
if seq is None or im is None:
continue
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config-vit.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
betas:
- 0.9
- 0.999
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
backbone_layers:
- 2
- 3
Expand Down
4 changes: 4 additions & 0 deletions pix2tex/model/settings/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ pad: False
pad_token: 0
bos_token: 1
eos_token: 2

# devices (GPU & CPU)
num_workers: 0
gpu_devices: null #[0,1,2,3,4,5,6,7]
22 changes: 10 additions & 12 deletions pix2tex/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import os
import argparse
import logging
Expand All @@ -16,12 +16,10 @@


def train(args):
dataloader = Im2LatexDataset().load(args.data)
dataloader.update(**args, test=False)
valdataloader = Im2LatexDataset().load(args.valdata)
valargs = args.copy()
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
valdataloader.update(**valargs)
train_dataset = Im2LatexDataset().load(args.data).update(**args, test=False)
train_dataloader = Dataloader(train_dataset, batch_size=args.batchsize, num_workers=args.num_workers, pin_memory=args.pin_memory)
val_dataset = Im2LatexDataset().load(args.valdata).update(**args, test=True)
val_dataloader = Dataloader(val_dataset, batch_size=args.testbatchsize, num_workers=args.num_workers, drop_last=False, pin_memory=args.pin_memory)
device = args.device
model = get_model(args)
if torch.cuda.is_available() and not args.no_cuda:
Expand All @@ -47,7 +45,7 @@ def save_models(e, step=0):
try:
for e in range(args.epoch, args.epochs):
args.epoch = e
dset = tqdm(iter(dataloader))
dset = tqdm(train_dataloader)
for i, (seq, im) in enumerate(dset):
if seq is not None and im is not None:
opt.zero_grad()
Expand All @@ -63,20 +61,20 @@ def save_models(e, step=0):
dset.set_description('Loss: %.4f' % total_loss)
if args.wandb:
wandb.log({'train/loss': total_loss})
if (i+1+len(dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if (i+1+len(train_dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if bleu_score > max_bleu and token_accuracy > max_token_acc:
max_bleu, max_token_acc = bleu_score, token_accuracy
save_models(e, step=i)
if (e+1) % args.save_freq == 0:
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))
if args.wandb:
wandb.log({'train/epoch': e+1})
except KeyboardInterrupt:
if e >= 2:
save_models(e, step=i)
raise KeyboardInterrupt
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions pix2tex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def parse_args(args, **kwargs) -> Munch:
args.update(kwargs)
args.wandb = not kwargs.debug and not args.debug
args.device = get_device(args, kwargs.no_cuda)
args.num_workers = args.get('num_workers', 0)
args.pin_memory = args.get('pin_memory', False)
args.max_dimensions = [args.max_width, args.max_height]
args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
if 'decoder_args' not in args or args.decoder_args is None:
Expand Down