Commit c3449dc

Refactor data modules, update requirements.txt, move configs to root

1 parent 71d4023 · commit c3449dc

36 files changed: +420 −517 lines

.gitignore (+0 −1)

@@ -7,7 +7,6 @@ __pycache__/
 *.py[cod]
 *$py.class
 
-data/
 wandb/
 notebooks/
 outputs/

README.md (+3 −3)

@@ -21,9 +21,9 @@ Minimal code example to run the model:
 from os.path import join
 
 import hydra
-from code2seq.dataset import PathContextDataModule
+from code2seq.data import PathContextDataModule
 from code2seq.model import Code2Seq
-from code2seq.utils.vocabulary import Vocabulary
+from code2seq.data.vocabulary import Vocabulary
 from omegaconf import DictConfig
 from pytorch_lightning import Trainer
 
@@ -43,5 +43,5 @@ if __name__ == "__main__":
     train()
 ```
 
-Navigate to [code2seq/configs](code2seq/configs) to see examples of configs.
+Navigate to [code2seq/configs](configs) to see examples of configs.
 If you had any questions then feel free to open the issue.
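With the new layout, the README's minimal example wires together as below. This is a hedged sketch: only the import paths and the `PathContextDataModule(data_dir, config)` signature come from this commit; the `@hydra.main` arguments, the config keys (`data_folder`, `data`, `model`, `optimizer`, `train`), and the `Code2Seq` constructor signature are assumptions.

```python
import hydra
from code2seq.data import PathContextDataModule
from code2seq.model import Code2Seq
from omegaconf import DictConfig
from pytorch_lightning import Trainer


# Config name and keys below are hypothetical; configs now live at the repo root.
@hydra.main(config_path="configs", config_name="code2seq")
def train(config: DictConfig):
    data_module = PathContextDataModule(config.data_folder, config.data)
    data_module.prepare_data()  # download the dataset if needed
    data_module.setup()         # build or load the vocabulary

    # Assumed constructor signature; check code2seq/model for the actual one
    model = Code2Seq(config.model, config.optimizer, data_module.vocabulary)
    Trainer(max_epochs=config.train.n_epochs).fit(model, datamodule=data_module)


if __name__ == "__main__":
    train()
```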
code2seq/data/__init__.py (+14 −3)

@@ -1,14 +1,25 @@
+from .path_context import (
+    Path,
+    LabeledPathContext,
+    BatchedLabeledPathContext,
+    TypedPath,
+    LabeledTypedPathContext,
+    BatchedLabeledTypedPathContext,
+)
 from .path_context_dataset import PathContextDataset
-from .data_classes import PathContextSample, PathContextBatch
 from .path_context_data_module import PathContextDataModule
 from .typed_path_context_dataset import TypedPathContextDataset
 from .typed_path_context_data_module import TypedPathContextDataModule
 
 __all__ = [
+    "Path",
+    "LabeledPathContext",
+    "BatchedLabeledPathContext",
     "PathContextDataset",
-    "PathContextSample",
-    "PathContextBatch",
     "PathContextDataModule",
+    "TypedPath",
+    "LabeledTypedPathContext",
+    "BatchedLabeledTypedPathContext",
     "TypedPathContextDataset",
     "TypedPathContextDataModule",
 ]
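For downstream code, the rename boils down to an import change. A hedged migration sketch: the old package path is taken from the README diff above, and the one-to-one mapping of the old data classes onto the new ones is an assumption based on the names.

```python
# Before this commit:
#   from code2seq.dataset import PathContextDataModule
#   from code2seq.dataset.data_classes import PathContextSample, PathContextBatch

# After this commit (LabeledPathContext / BatchedLabeledPathContext presumably
# take over the roles of PathContextSample / PathContextBatch):
from code2seq.data import (
    PathContextDataModule,
    LabeledPathContext,
    BatchedLabeledPathContext,
)
```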

code2seq/data/path_context.py (+71)

@@ -0,0 +1,71 @@
+from dataclasses import dataclass
+from typing import List, Iterable, Tuple, Optional
+
+import torch
+
+
+@dataclass
+class Path:
+    from_token: torch.Tensor  # [max token parts; 1]
+    path_node: torch.Tensor  # [path length; 1]
+    to_token: torch.Tensor  # [max token parts; 1]
+
+
+@dataclass
+class LabeledPathContext:
+    label: torch.Tensor  # [max label parts; 1]
+    path_contexts: List[Path]
+
+
+class BatchedLabeledPathContext:
+    def __init__(self, samples: List[Optional[LabeledPathContext]]):
+        samples = [s for s in samples if s is not None]
+
+        # [max label parts; batch size]
+        self.labels = torch.cat([s.label for s in samples], dim=1)
+        # [batch size]
+        self.contexts_per_label = [len(s.path_contexts) for s in samples]
+
+        # [max token parts; paths in batch]
+        self.from_token = torch.cat([path.from_token for s in samples for path in s.path_contexts], dim=1)
+        # [path length; paths in batch]
+        self.path_node = torch.cat([path.path_node for s in samples for path in s.path_contexts], dim=1)
+        # [max token parts; paths in batch]
+        self.to_token = torch.cat([path.to_token for s in samples for path in s.path_contexts], dim=1)
+
+    def __len__(self) -> int:
+        return len(self.contexts_per_label)
+
+    def __get_all_tensors(self) -> Iterable[Tuple[str, torch.Tensor]]:
+        for name, value in vars(self).items():
+            if isinstance(value, torch.Tensor):
+                yield name, value
+
+    def pin_memory(self) -> "BatchedLabeledPathContext":
+        for name, value in self.__get_all_tensors():
+            setattr(self, name, value.pin_memory())
+        return self
+
+    def move_to_device(self, device: torch.device):
+        for name, value in self.__get_all_tensors():
+            setattr(self, name, value.to(device))
+
+
+@dataclass
+class TypedPath(Path):
+    from_type: torch.Tensor  # [max type parts; 1]
+    to_type: torch.Tensor  # [max type parts; 1]
+
+
+class BatchedLabeledTypedPathContext(BatchedLabeledPathContext):
+    def __init__(self, samples: List[Optional[LabeledTypedPathContext]]):
+        super().__init__(samples)
+        # [max type parts; paths in batch]; skip None samples, as the parent constructor does
+        self.from_type = torch.cat([path.from_type for s in samples if s is not None for path in s.path_contexts], dim=1)
+        # [max type parts; paths in batch]
+        self.to_type = torch.cat([path.to_type for s in samples if s is not None for path in s.path_contexts], dim=1)
+
+
+@dataclass
+class LabeledTypedPathContext(LabeledPathContext):
+    path_contexts: List[TypedPath]
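A quick way to see the sequence-first batching convention is to collate a couple of toy samples. A minimal sketch, assuming per-sample tensors carry a trailing singleton batch dimension as `PathContextDataset` produces them; all shapes and values here are invented:

```python
import torch

from code2seq.data import Path, LabeledPathContext, BatchedLabeledPathContext


def make(length: int) -> torch.Tensor:
    # Per-sample tensors are [length; 1]; torch.cat(dim=1) stacks them into [length; batch]
    return torch.zeros(length, 1, dtype=torch.long)


def toy_sample(n_paths: int) -> LabeledPathContext:
    paths = [Path(from_token=make(5), path_node=make(9), to_token=make(5)) for _ in range(n_paths)]
    return LabeledPathContext(label=make(8), path_contexts=paths)


batch = BatchedLabeledPathContext([toy_sample(3), None, toy_sample(2)])  # None samples are dropped
print(len(batch))                # 2 samples survived
print(batch.labels.shape)        # torch.Size([8, 2]) -- [max label parts; batch size]
print(batch.from_token.shape)    # torch.Size([5, 5]) -- 3 + 2 paths in the batch
print(batch.contexts_per_label)  # [3, 2]
```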
code2seq/data/path_context_data_module.py (+89)

@@ -0,0 +1,89 @@
+from os.path import exists, join, basename
+from typing import List, Optional
+
+import torch
+from commode_utils.common import download_dataset
+from commode_utils.vocabulary import build_from_scratch
+from omegaconf import DictConfig
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
+from code2seq.data import PathContextDataset, LabeledPathContext, BatchedLabeledPathContext
+from code2seq.data.vocabulary import Vocabulary
+
+
+class PathContextDataModule(LightningDataModule):
+    _train = "train"
+    _val = "val"
+    _test = "test"
+
+    _vocabulary: Optional[Vocabulary] = None
+
+    def __init__(self, data_dir: str, config: DictConfig):
+        super().__init__()
+        self._config = config
+        self._data_dir = data_dir
+        self._name = basename(data_dir)
+
+    @property
+    def vocabulary(self) -> Vocabulary:
+        if self._vocabulary is None:
+            raise RuntimeError("Setup the data module to initialize the vocabulary")
+        return self._vocabulary
+
+    def prepare_data(self):
+        if exists(self._data_dir):
+            print("Dataset is already downloaded")
+            return
+        if "url" not in self._config:
+            raise ValueError("Config doesn't contain a dataset URL, can't download it automatically")
+        download_dataset(self._config.url, self._data_dir, self._name)
+
+    def setup(self, stage: Optional[str] = None):
+        if not exists(join(self._data_dir, Vocabulary.vocab_filename)):
+            print("Can't find vocabulary, collecting it from the train holdout")
+            build_from_scratch(join(self._data_dir, f"{self._train}.jsonl"), Vocabulary)
+        vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
+        self._vocabulary = Vocabulary(vocabulary_path, self._config.max_labels, self._config.max_tokens)
+
+    @staticmethod
+    def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext:
+        return BatchedLabeledPathContext(batch)
+
+    def _create_dataset(self, holdout_file: str, random_context: bool) -> PathContextDataset:
+        return PathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
+
+    def _shared_dataloader(self, holdout: str) -> DataLoader:
+        if self._vocabulary is None:
+            raise RuntimeError("Setup the vocabulary before creating data loaders")
+
+        holdout_file = join(self._data_dir, f"{holdout}.jsonl")
+        random_context = self._config.random_context if holdout == self._train else False
+        dataset = self._create_dataset(holdout_file, random_context)
+
+        batch_size = self._config.batch_size if holdout == self._train else self._config.test_batch_size
+        shuffle = holdout == self._train
+
+        return DataLoader(
+            dataset,
+            batch_size,
+            shuffle=shuffle,
+            num_workers=self._config.num_workers,
+            collate_fn=self.collate_wrapper,
+            pin_memory=True,
+        )
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return self._shared_dataloader(self._train)
+
+    def val_dataloader(self, *args, **kwargs) -> DataLoader:
+        return self._shared_dataloader(self._val)
+
+    def test_dataloader(self, *args, **kwargs) -> DataLoader:
+        return self._shared_dataloader(self._test)
+
+    def transfer_batch_to_device(
+        self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int
+    ) -> BatchedLabeledPathContext:
+        batch.move_to_device(device)
+        return batch
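To exercise the module end to end, something like the following should work. A hedged sketch: the hyperparameter values are invented, `data/java-small` is a hypothetical dataset folder assumed to already exist (otherwise `prepare_data` needs a `url` entry in the config), and only the key names are taken from the code above.

```python
from omegaconf import OmegaConf

from code2seq.data import PathContextDataModule

# Key names mirror the attributes the module reads; values are placeholders.
config = OmegaConf.create(
    {
        "max_labels": 27000,    # vocabulary caps, read in setup()
        "max_tokens": 190000,
        "max_context": 200,     # dataset options
        "random_context": True,
        "max_label_parts": 7,
        "max_token_parts": 5,
        "batch_size": 512,      # loader options
        "test_batch_size": 768,
        "num_workers": 4,
    }
)

data_module = PathContextDataModule("data/java-small", config)
data_module.prepare_data()  # no-op if the folder exists; raises without a config url otherwise
data_module.setup()         # builds the vocabulary from train.jsonl if it isn't on disk
batch = next(iter(data_module.train_dataloader()))  # a BatchedLabeledPathContext
```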

code2seq/data/path_context_dataset.py (+89)

@@ -0,0 +1,89 @@
+from os.path import exists
+from typing import Dict, List, Optional
+
+import torch
+from commode_utils.filesystem import get_lines_offsets, get_line_by_offset
+from omegaconf import DictConfig
+from torch.utils.data import Dataset
+
+from code2seq.data.path_context import LabeledPathContext, Path
+from code2seq.data.vocabulary import Vocabulary
+
+
+class PathContextDataset(Dataset):
+    _log_file = "bad_samples.log"
+    _separator = "|"
+
+    def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool):
+        if not exists(data_file):
+            raise ValueError(f"Can't find file with data: {data_file}")
+        self._data_file = data_file
+        self._config = config
+        self._vocab = vocabulary
+        self._random_context = random_context
+
+        self._label_unk = vocabulary.label_to_id[vocabulary.UNK]
+
+        self._line_offsets = get_lines_offsets(data_file)
+        self._n_samples = len(self._line_offsets)
+
+        open(self._log_file, "w").close()
+
+    def __len__(self):
+        return self._n_samples
+
+    def __getitem__(self, index) -> Optional[LabeledPathContext]:
+        raw_sample = get_line_by_offset(self._data_file, self._line_offsets[index])
+        try:
+            raw_label, *raw_path_contexts = raw_sample.split()
+        except ValueError as e:
+            with open(self._log_file, "a") as f_out:
+                f_out.write(f"Error reading sample from line #{index}: {e}\n")
+            return None
+
+        # Choose paths for the current sample (lists have no .shuffle(), so sample a random order)
+        n_contexts = min(len(raw_path_contexts), self._config.max_context)
+        if self._random_context:
+            raw_path_contexts = [raw_path_contexts[i] for i in torch.randperm(len(raw_path_contexts))]
+        raw_path_contexts = raw_path_contexts[:n_contexts]
+
+        # Tokenize label
+        label = self._get_label(raw_label)
+
+        # Tokenize paths
+        try:
+            paths = [self._get_path(raw_path.split(",")) for raw_path in raw_path_contexts]
+        except ValueError as e:
+            with open(self._log_file, "a") as f_out:
+                f_out.write(f"Error parsing sample from line #{index}: {e}\n")
+            return None
+
+        return LabeledPathContext(label, paths)
+
+    def _get_label(self, raw_label: str) -> torch.Tensor:
+        label = torch.full((self._config.max_label_parts + 1, 1), self._vocab.label_to_id[self._vocab.PAD])
+        label[0, 0] = self._vocab.label_to_id[self._vocab.SOS]
+        sublabels = raw_label.split(self._separator)[: self._config.max_label_parts]
+        label[1 : len(sublabels) + 1, 0] = torch.tensor(
+            [self._vocab.label_to_id.get(sl, self._label_unk) for sl in sublabels]
+        )
+        if len(sublabels) < self._config.max_label_parts:
+            label[len(sublabels) + 1, 0] = self._vocab.label_to_id[self._vocab.EOS]
+        return label
+
+    def _tokenize_token(self, token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+        sub_tokens = token.split(self._separator)
+        max_parts = max_parts or len(sub_tokens)
+        token_unk = vocab[self._vocab.UNK]
+
+        result = torch.full((max_parts, 1), vocab[self._vocab.PAD], dtype=torch.long)  # [max parts; 1], matches _get_label
+        sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
+        result[: len(sub_tokens_ids), 0] = torch.tensor(sub_tokens_ids)
+        return result
+
+    def _get_path(self, raw_path: List[str]) -> Path:
+        return Path(
+            from_token=self._tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
+            path_node=self._tokenize_token(raw_path[1], self._vocab.node_to_id, None),
+            to_token=self._tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
+        )
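`__getitem__` expects each holdout line in the classic code2seq path-context format: a label followed by space-separated `from_token,path_nodes,to_token` triples, with `|` joining sub-tokens. A small self-contained illustration (the identifiers and node names are invented):

```python
# One raw sample line, as get_line_by_offset would return it:
raw_sample = "get|name this,Cls0|Mth|Nm1,get|name name,Fld|Nm0,string"

raw_label, *raw_path_contexts = raw_sample.split()
assert raw_label == "get|name"
assert len(raw_path_contexts) == 2
assert raw_path_contexts[0].split(",") == ["this", "Cls0|Mth|Nm1", "get|name"]
```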
code2seq/data/typed_path_context_data_module.py (+25)

@@ -0,0 +1,25 @@
+from typing import List, Optional
+
+from omegaconf import DictConfig
+
+from code2seq.data import (
+    PathContextDataModule,
+    TypedPathContextDataset,
+    BatchedLabeledTypedPathContext,
+    LabeledTypedPathContext,
+)
+from code2seq.data.vocabulary import TypedVocabulary
+
+
+class TypedPathContextDataModule(PathContextDataModule):
+    _vocabulary: Optional[TypedVocabulary] = None
+
+    def __init__(self, data_dir: str, config: DictConfig):
+        super().__init__(data_dir, config)
+
+    @staticmethod
+    def collate_wrapper(batch: List[Optional[LabeledTypedPathContext]]) -> BatchedLabeledTypedPathContext:
+        return BatchedLabeledTypedPathContext(batch)
+
+    def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathContextDataset:
+        return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
code2seq/data/typed_path_context_dataset.py (+21)

@@ -0,0 +1,21 @@
+from typing import List
+
+from omegaconf import DictConfig
+
+from code2seq.data import PathContextDataset, TypedPath
+from code2seq.data.vocabulary import TypedVocabulary
+
+
+class TypedPathContextDataset(PathContextDataset):
+    def __init__(self, data_file: str, config: DictConfig, vocabulary: TypedVocabulary, random_context: bool):
+        super().__init__(data_file, config, vocabulary, random_context)
+        self._vocab = vocabulary
+
+    def _get_path(self, raw_path: List[str]) -> TypedPath:
+        return TypedPath(
+            from_type=self._tokenize_token(raw_path[0], self._vocab.type_to_id, self._config.max_type_parts),
+            from_token=self._tokenize_token(raw_path[1], self._vocab.token_to_id, self._config.max_token_parts),
+            path_node=self._tokenize_token(raw_path[2], self._vocab.node_to_id, None),
+            to_token=self._tokenize_token(raw_path[3], self._vocab.token_to_id, self._config.max_token_parts),
+            to_type=self._tokenize_token(raw_path[4], self._vocab.type_to_id, self._config.max_type_parts),
+        )
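The typed variant only changes the per-context layout: `_get_path` above reads five comma-separated fields instead of three, with the terminal types on the outside. A hedged illustration with invented tokens:

```python
# from_type , from_token , path_nodes , to_token , to_type
raw_path_context = "int,result,Nm0|Mth|Nm1,get|count,int"

from_type, from_token, path_node, to_token, to_type = raw_path_context.split(",")
assert (from_type, to_type) == ("int", "int")
assert to_token.split("|") == ["get", "count"]
```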
