from os.path import exists
from random import shuffle
from typing import Dict, List, Optional

import torch
from commode_utils.filesystem import get_lines_offsets, get_line_by_offset
from omegaconf import DictConfig
from torch.utils.data import Dataset

from code2seq.data.path_context import LabeledPathContext, Path
from code2seq.data.vocabulary import Vocabulary


class PathContextDataset(Dataset):
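    """Map-style dataset that lazily reads code2seq path contexts from a text file.

    Each line is expected to hold a target label followed by whitespace-separated
    path contexts of the form ``from_token,path_nodes,to_token``, with subtokens
    joined by ``|`` (this layout is inferred from the parsing code below).
    """
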
    _log_file = "bad_samples.log"
    _separator = "|"

    def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool):
        if not exists(data_file):
            raise ValueError(f"Can't find file with data: {data_file}")
        self._data_file = data_file
        self._config = config
        self._vocab = vocabulary
        self._random_context = random_context

        self._label_unk = vocabulary.label_to_id[vocabulary.UNK]

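        # Pre-compute the byte offset of every line so samples can be read lazily by index.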
        self._line_offsets = get_lines_offsets(data_file)
        self._n_samples = len(self._line_offsets)

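        # Truncate the log of malformed samples from any previous run.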
        open(self._log_file, "w").close()

    def __len__(self):
        return self._n_samples

    def __getitem__(self, index) -> Optional[LabeledPathContext]:
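        # Read one raw line by its stored offset and parse it; malformed lines are
        # logged to the bad-samples file and yield None instead of raising.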
        raw_sample = get_line_by_offset(self._data_file, self._line_offsets[index])
        try:
            raw_label, *raw_path_contexts = raw_sample.split()
        except ValueError as e:
            with open(self._log_file, "a") as f_out:
                f_out.write(f"Error reading sample from line #{index}: {e}\n")
            return None

        # Choose paths for current data sample
        n_contexts = min(len(raw_path_contexts), self._config.max_context)
        if self._random_context:
            shuffle(raw_path_contexts)
        raw_path_contexts = raw_path_contexts[:n_contexts]

        # Tokenize label
        label = self._get_label(raw_label)

        # Tokenize paths
        try:
            paths = [self._get_path(raw_path.split(",")) for raw_path in raw_path_contexts]
        except ValueError as e:
            with open(self._log_file, "a") as f_out:
                f_out.write(f"Error parsing sample from line #{index}: {e}\n")
            return None

        return LabeledPathContext(label, paths)

    def _get_label(self, raw_label: str) -> torch.Tensor:
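        # Encode the label as a [max_label_parts + 1, 1] column: <SOS>, then up to
        # max_label_parts subtoken ids (unknown ones map to <UNK>), an <EOS> if there
        # is room, and <PAD> for the rest.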
        label = torch.full((self._config.max_label_parts + 1, 1), self._vocab.label_to_id[self._vocab.PAD])
        label[0, 0] = self._vocab.label_to_id[self._vocab.SOS]
        sublabels = raw_label.split(self._separator)[: self._config.max_label_parts]
        label[1 : len(sublabels) + 1, 0] = torch.tensor(
            [self._vocab.label_to_id.get(sl, self._label_unk) for sl in sublabels]
        )
        if len(sublabels) < self._config.max_label_parts:
            label[len(sublabels) + 1, 0] = self._vocab.label_to_id[self._vocab.EOS]
        return label

    def _tokenize_token(self, token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
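        # Split the token into subtokens on "|" and map each to its id, truncating or
        # padding with <PAD> to max_parts; unknown subtokens fall back to <UNK>.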
        sub_tokens = token.split(self._separator)
        max_parts = max_parts or len(sub_tokens)
        token_unk = vocab[self._vocab.UNK]

        result = torch.full((max_parts,), vocab[self._vocab.PAD], dtype=torch.long)
        sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
        result[: len(sub_tokens_ids)] = torch.tensor(sub_tokens_ids)
        return result

    def _get_path(self, raw_path: List[str]) -> Path:
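        # A raw path context is [from_token, path_nodes, to_token]; the terminal tokens
        # are clipped to max_token_parts, while the node sequence keeps its full length.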
        return Path(
            from_token=self._tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
            path_node=self._tokenize_token(raw_path[1], self._vocab.node_to_id, None),
            to_token=self._tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
        )
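

# Illustrative sketch (an addition, not part of the original class): one plausible
# way to wire this dataset into a DataLoader. The config keys mirror the fields
# read above; the batch size, file name, and how `vocabulary` is obtained are
# assumptions that will differ per project.
def build_example_loader(data_file: str, vocabulary: Vocabulary):
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    config = OmegaConf.create({"max_context": 200, "max_label_parts": 7, "max_token_parts": 5})
    dataset = PathContextDataset(data_file, config, vocabulary, random_context=True)

    def keep_valid(samples: List[Optional[LabeledPathContext]]) -> List[LabeledPathContext]:
        # __getitem__ returns None for malformed lines, so drop them before batching.
        return [s for s in samples if s is not None]

    return DataLoader(dataset, batch_size=512, collate_fn=keep_valid)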