-
Notifications
You must be signed in to change notification settings - Fork 65
Open
Description
I'm working on an NMT task. I use my own data-loading function rather than a torch Dataset, and I get an "'int' object has no attribute 'size'" error.
Here's my data loading code:
def _pad_to_tensor(rows, pad):
    """Right-pad every row with ``pad`` to a common length and return a long Variable."""
    width = max(len(row) for row in rows)
    # Pre-fill with the pad value in an integer dtype so that no cell is left
    # uninitialised and the token ids never round-trip through float.
    grid = np.full((len(rows), width), pad, dtype=np.int64)
    for r, row in enumerate(rows):
        grid[r, :len(row)] = row
    return Variable(torch.from_numpy(grid), requires_grad=False).long()


def get_batches(sz, pad=0):
    """Yield padded ``Batch`` objects built from the module-level ``datatmp``.

    ``datatmp`` is indexed as a sequence of pairs: element ``[0]`` is used as
    the source sequence and element ``[1]`` as the target sequence. Each
    sequence is presumably a list of integer token ids -- TODO confirm
    against the code that fills ``datatmp``.

    Consecutive groups of ``sz`` pairs are right-padded with ``pad`` to the
    longest source (respectively target) sequence in the group and wrapped
    in a ``Batch``.

    Args:
        sz: Number of pairs per batch.
        pad: Value used for padding; also forwarded to ``Batch``.

    Yields:
        One ``Batch`` per group. The final batch holds fewer than ``sz``
        rows when ``len(datatmp)`` is not a multiple of ``sz``.
    """
    total = len(datatmp)
    for start in range(0, total, sz):
        # Clamp the group so a trailing partial batch does not index past
        # the end of the data.
        stop = min(start + sz, total)
        src_rows = [datatmp[k][0] for k in range(start, stop)]
        trg_rows = [datatmp[k][1] for k in range(start, stop)]
        src = _pad_to_tensor(src_rows, pad)
        trg = _pad_to_tensor(trg_rows, pad)
        yield Batch(src, trg, pad)
class Batch:
    """Container for one batch of training data together with its masks."""

    def __init__(self, src, trg=None, pad=0):
        """Store ``src`` and, when given, the shifted views of ``trg``.

        Args:
            src: Source tensor; positions equal to ``pad`` are masked out.
            trg: Optional target tensor indexed as ``[:, :-1]`` / ``[:, 1:]``.
            pad: Value treated as padding when building the masks.
        """
        not_pad_src = src != pad
        self.src = src
        # Insert an extra axis just before the last one.
        self.src_mask = not_pad_src.unsqueeze(-2)
        if trg is None:
            # Without a target, none of the trg* / ntokens attributes are set.
            return
        shifted_in = trg[:, :-1]   # every column except the last
        shifted_out = trg[:, 1:]   # every column except the first
        self.trg = shifted_in
        self.trg_y = shifted_out
        self.trg_mask = self.make_std_mask(shifted_in, pad)
        # Count of non-pad entries in trg_y.
        self.ntokens = (shifted_out != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Create a mask to hide padding and future words."""
        not_pad = (tgt != pad).unsqueeze(-2)
        length = tgt.size(-1)
        # Cast the mask from subsequent_mask to the same type as the padding
        # mask before combining the two with a bitwise AND.
        future = subsequent_mask(length).type_as(not_pad.data)
        return not_pad & Variable(future)
P.S. The code is adapted from "The Annotated Transformer".
Metadata
Metadata
Assignees
Labels
No labels