Skip to content

Commit e678f8b

Browse files
authored
Merge pull request #108 from JetBrains-Research/integrate_commode_utils
Integrate commode utils
2 parents 88b1a32 + ef8229d commit e678f8b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+1407
-5399
lines changed

Diff for: .gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ __pycache__/
77
*.py[cod]
88
*$py.class
99

10-
data/
1110
wandb/
1211
notebooks/
1312
outputs/

Diff for: README.md

+26-14
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,43 @@ pip install code2seq
1717
## Usage
1818

1919
Minimal code example to run the model:
20+
2021
```python
21-
from os.path import join
22+
from argparse import ArgumentParser
2223

23-
import hydra
24-
from code2seq.dataset import PathContextDataModule
25-
from code2seq.model import Code2Seq
26-
from code2seq.utils.vocabulary import Vocabulary
27-
from omegaconf import DictConfig
24+
from omegaconf import DictConfig, OmegaConf
2825
from pytorch_lightning import Trainer
2926

27+
from code2seq.data.path_context_data_module import PathContextDataModule
28+
from code2seq.model import Code2Seq
29+
3030

31-
@hydra.main(config_path="configs")
3231
def train(config: DictConfig):
33-
vocabulary_path = join(config.data_folder, config.dataset.name, config.vocabulary_name)
34-
vocabulary = Vocabulary.load_vocabulary(vocabulary_path)
35-
model = Code2Seq(config, vocabulary)
36-
data_module = PathContextDataModule(config, vocabulary)
32+
# Load data module
33+
data_module = PathContextDataModule(config.data_folder, config.data)
34+
data_module.prepare_data()
35+
data_module.setup()
36+
37+
# Load model
38+
model = Code2Seq(
39+
config.model,
40+
config.optimizer,
41+
data_module.vocabulary,
42+
config.train.teacher_forcing
43+
)
3744

3845
trainer = Trainer(max_epochs=config.hyper_parameters.n_epochs)
3946
trainer.fit(model, datamodule=data_module)
4047

4148

4249
if __name__ == "__main__":
43-
train()
50+
__arg_parser = ArgumentParser()
51+
__arg_parser.add_argument("config", help="Path to YAML configuration file", type=str)
52+
__args = __arg_parser.parse_args()
53+
54+
__config = OmegaConf.load(__args.config)
55+
train(__config)
4456
```
4557

46-
Navigate to [code2seq/configs](code2seq/configs) to see examples of configs.
47-
If you had any questions then feel free to open the issue.
58+
Navigate to [config](config) directory to see examples of configs.
59+
If you have any questions, then feel free to open the issue.

Diff for: code2seq/code2class_wrapper.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from argparse import ArgumentParser
2+
from typing import cast
3+
4+
import torch
5+
from commode_utils.common import print_config
6+
from omegaconf import DictConfig, OmegaConf
7+
8+
from code2seq.data.path_context_data_module import PathContextDataModule
9+
from code2seq.model import Code2Class
10+
from code2seq.utils.common import filter_warnings
11+
from code2seq.utils.test import test
12+
from code2seq.utils.train import train
13+
14+
15+
def configure_arg_parser() -> ArgumentParser:
16+
arg_parser = ArgumentParser()
17+
arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"])
18+
arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str)
19+
return arg_parser
20+
21+
22+
def train_code2class(config: DictConfig):
23+
filter_warnings()
24+
25+
if config.print_config:
26+
print_config(config, fields=["model", "data", "train", "optimizer"])
27+
28+
# Load data module
29+
data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
30+
data_module.prepare_data()
31+
data_module.setup()
32+
33+
# Load model
34+
code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary)
35+
36+
train(code2class, data_module, config)
37+
38+
39+
def test_code2class(config: DictConfig):
40+
filter_warnings()
41+
42+
# Load data module
43+
data_module = PathContextDataModule(config.data_folder, config.data)
44+
data_module.prepare_data()
45+
data_module.setup()
46+
47+
# Load model
48+
code2class = Code2Class.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu"))
49+
50+
test(code2class, data_module, config.seed)
51+
52+
53+
if __name__ == "__main__":
54+
__arg_parser = configure_arg_parser()
55+
__args = __arg_parser.parse_args()
56+
57+
__config = cast(DictConfig, OmegaConf.load(__args.config))
58+
if __args.mode == "train":
59+
train_code2class(__config)
60+
else:
61+
test_code2class(__config)

Diff for: code2seq/code2seq_wrapper.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from argparse import ArgumentParser
2+
from typing import cast
3+
4+
import torch
5+
from commode_utils.common import print_config
6+
from omegaconf import DictConfig, OmegaConf
7+
8+
from code2seq.data.path_context_data_module import PathContextDataModule
9+
from code2seq.model import Code2Seq
10+
from code2seq.utils.common import filter_warnings
11+
from code2seq.utils.test import test
12+
from code2seq.utils.train import train
13+
14+
15+
def configure_arg_parser() -> ArgumentParser:
16+
arg_parser = ArgumentParser()
17+
arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"])
18+
arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str)
19+
return arg_parser
20+
21+
22+
def train_code2seq(config: DictConfig):
23+
filter_warnings()
24+
25+
if config.print_config:
26+
print_config(config, fields=["model", "data", "train", "optimizer"])
27+
28+
# Load data module
29+
data_module = PathContextDataModule(config.data_folder, config.data)
30+
data_module.prepare_data()
31+
data_module.setup()
32+
33+
# Load model
34+
code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
35+
36+
train(code2seq, data_module, config)
37+
38+
39+
def test_code2seq(config: DictConfig):
40+
filter_warnings()
41+
42+
# Load data module
43+
data_module = PathContextDataModule(config.data_folder, config.data)
44+
data_module.prepare_data()
45+
data_module.setup()
46+
47+
# Load model
48+
code2seq = Code2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu"))
49+
50+
test(code2seq, data_module, config.seed)
51+
52+
53+
if __name__ == "__main__":
54+
__arg_parser = configure_arg_parser()
55+
__args = __arg_parser.parse_args()
56+
57+
__config = cast(DictConfig, OmegaConf.load(__args.config))
58+
if __args.mode == "train":
59+
train_code2seq(__config)
60+
else:
61+
test_code2seq(__config)

Diff for: code2seq/configs/code2class-poj104.yaml

-72
This file was deleted.

Diff for: code2seq/configs/code2seq-java-small.yaml

-76
This file was deleted.

0 commit comments

Comments
 (0)