Skip to content

Commit

Permalink
🐛 Fix some data processor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Nov 14, 2024
1 parent 2b147ff commit aa12363
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions hezar/data/dataset_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non

data["pixel_values"] = self.image_processor(path, return_tensors=return_tensors)["pixel_values"]
data["labels"] = tokenized_inputs["token_ids"]
- data["attention_mask"] = tokenized_inputs["attention_mask"]
+ data["decoder_attention_mask"] = tokenized_inputs["attention_mask"]
data["decoder_input_ids"] = self._shift_tokens_right(
- data["labels"],
+ [data["labels"]],
pad_token_id=self.tokenizer.pad_token_id,
decoder_start_token_id=self.tokenizer.bos_token_id,
- )
+ )[0]

return data

Expand Down Expand Up @@ -270,9 +270,7 @@ def process_batch(self, data, return_tensors=None):
pixel_values = self.image_processor(paths, return_tensors=return_tensors)["pixel_values"]

# Process text labels in batch
- labels = []
- for text in texts:
- labels.append(self._text_to_ids(text))
+ labels = [self._text_to_ids(text) for text in texts]

return {"pixel_values": pixel_values, "labels": labels}

Expand Down Expand Up @@ -363,7 +361,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non
padding=padding,
max_length=max_length,
)

+ tokenized_inputs = {k: v[0] for k, v in tokenized_inputs.items()}
data.update(tokenized_inputs)

return data
Expand Down Expand Up @@ -538,7 +536,7 @@ def process_single(self, data, return_tensors=None, padding=None, max_length=Non
return_tensors=return_tensors,
)
data.update(inputs)
- data["labels"] = torch.tensor([label], dtype=torch.long)
+ data["labels"] = torch.tensor(label, dtype=torch.long)

return data

Expand Down

0 comments on commit aa12363

Please sign in to comment.