
Commit c50f2a5
Update requirements, use rich progress bar instead of tqdm

1 parent d64ec4f

12 files changed: +43 -45 lines

code2seq/code2class_wrapper.py (-2)

```diff
@@ -27,8 +27,6 @@ def train_code2class(config: DictConfig):
 
     # Load data module
     data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
-    data_module.prepare_data()
-    data_module.setup()
 
     # Load model
     code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary)
```

code2seq/code2seq_wrapper.py (-2)

```diff
@@ -27,8 +27,6 @@ def train_code2seq(config: DictConfig):
 
     # Load data module
     data_module = PathContextDataModule(config.data_folder, config.data)
-    data_module.prepare_data()
-    data_module.setup()
 
     # Load model
     code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
```
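
The same simplification appears in all three training wrappers: the data module now builds its vocabulary in its constructor (see the data-module diff below), and Lightning's Trainer invokes prepare_data()/setup() itself during fit, so the explicit calls are removed. A minimal sketch of the resulting entry-point flow, with assumed import paths and an assumed config file:

```python
# Sketch only: the import paths and the config file name are assumptions based on this repository's layout.
from omegaconf import OmegaConf

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model.code2seq import Code2Seq
from code2seq.utils.train import train

config = OmegaConf.load("config/code2seq-java-test.yaml")

# The vocabulary is built eagerly inside the constructor, so no prepare_data()/setup() calls are needed here.
data_module = PathContextDataModule(config.data_folder, config.data)
code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
train(code2seq, data_module, config)
```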

code2seq/data/path_context_data_module.py (+9 -8)

```diff
@@ -18,15 +18,15 @@ class PathContextDataModule(LightningDataModule):
     _val = "val"
     _test = "test"
 
-    _vocabulary: Optional[Vocabulary] = None
-
     def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False):
         super().__init__()
         self._config = config
         self._data_dir = data_dir
         self._name = basename(data_dir)
         self._is_class = is_class
 
+        self._vocabulary = self.setup_vocabulary()
+
     @property
     def vocabulary(self) -> Vocabulary:
         if self._vocabulary is None:
@@ -41,14 +41,12 @@ def prepare_data(self):
             raise ValueError(f"Config doesn't contain url for, can't download it automatically")
         download_dataset(self._config.url, self._data_dir, self._name)
 
-    def setup(self, stage: Optional[str] = None):
-        if not exists(join(self._data_dir, Vocabulary.vocab_filename)):
+    def setup_vocabulary(self) -> Vocabulary:
+        vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
+        if not exists(vocabulary_path):
             print("Can't find vocabulary, collect it from train holdout")
             build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary)
-        vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
-        self._vocabulary = Vocabulary(
-            vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class
-        )
+        return Vocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class)
 
     @staticmethod
     def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext:
@@ -88,6 +86,9 @@ def val_dataloader(self, *args, **kwargs) -> DataLoader:
     def test_dataloader(self, *args, **kwargs) -> DataLoader:
         return self._shared_dataloader(self._test)
 
+    def predict_dataloader(self, *args, **kwargs) -> DataLoader:
+        return self.test_dataloader(*args, **kwargs)
+
     def transfer_batch_to_device(
         self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int
     ) -> BatchedLabeledPathContext:
```
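
The new predict_dataloader reuses the test split, so inference can also run through Lightning's predict loop. A minimal sketch, assuming a model and data module constructed as in the wrappers above and a model that defines a suitable predict_step (not part of this commit):

```python
# Sketch only: `model` and `data_module` are assumed to already exist; Trainer.predict
# routes batches through the data module's predict_dataloader().
import torch
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1 if torch.cuda.is_available() else None)
predictions = trainer.predict(model, datamodule=data_module)
```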

code2seq/data/typed_path_context_data_module.py (+3 -3)

```diff
@@ -11,7 +11,7 @@
 
 
 class TypedPathContextDataModule(PathContextDataModule):
-    _vocabulary: Optional[TypedVocabulary] = None
+    _vocabulary: TypedVocabulary
 
     def __init__(self, data_dir: str, config: DictConfig):
         super().__init__(data_dir, config)
@@ -27,12 +27,12 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathC
             raise RuntimeError(f"Setup vocabulary before creating data loaders")
         return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup_vocabulary(self) -> TypedVocabulary:
         if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)):
             print("Can't find vocabulary, collect it from train holdout")
             build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary)
         vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename)
-        self._vocabulary = TypedVocabulary(
+        return TypedVocabulary(
             vocabulary_path, self._config.labels_count, self._config.tokens_count, self._config.types_count
         )
 
```

code2seq/model/code2seq.py (+15 -7)

```diff
@@ -3,6 +3,7 @@
 import torch
 from commode_utils.losses import SequenceCrossEntropyLoss
 from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
+from commode_utils.metrics.chrF import ChrF
 from commode_utils.modules import LSTMDecoderStep, Decoder
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
@@ -41,6 +42,10 @@ def __init__(
             f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
+        id2label = {v: k for k, v in vocabulary.label_to_id.items()}
+        metrics.update(
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+        )
         self.__metrics = MetricCollection(metrics)
 
         self._encoder = self._get_encoder(model_config)
@@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        loss = self.__loss(logits[1:], batch.labels[1:])
+        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}
 
         with torch.no_grad():
             prediction = logits.argmax(-1)
             metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            result.update(
+                {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
+            )
+            if step != "train":
+                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)
 
-        return {
-            f"{step}/loss": loss,
-            f"{step}/f1": metric.f1_score,
-            f"{step}/precision": metric.precision,
-            f"{step}/recall": metric.recall,
-        }
+        return result
 
     def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:  # type: ignore
         result = self._shared_step(batch, "train")
@@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
             f"{step}/recall": metric.recall,
         }
         self.__metrics[f"{step}_f1"].reset()
+        if step != "train":
+            log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
+            self.__metrics[f"{step}_chrf"].reset()
         self.log_dict(log, on_step=False, on_epoch=True)
 
     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
```
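
For reference, chrF scores a predicted label by character n-gram overlap with the reference, which gives partial credit to near-misses that the sub-token F1 treats as plain errors. A toy sketch of the underlying idea (not the commode-utils ChrF implementation; the n-gram order and beta value are assumptions):

```python
# Toy chrF-style score: macro-averaged character n-gram precision/recall combined into an F-beta score.
from collections import Counter


def char_ngrams(text: str, n: int) -> Counter:
    return Counter(text[i : i + n] for i in range(len(text) - n + 1))


def chrf_score(hypothesis: str, reference: str, max_n: int = 6, beta: float = 2.0) -> float:
    precisions, recalls = [], []
    for n in range(1, max_n + 1):
        hyp, ref = char_ngrams(hypothesis, n), char_ngrams(reference, n)
        overlap = sum((hyp & ref).values())
        precisions.append(overlap / max(sum(hyp.values()), 1))
        recalls.append(overlap / max(sum(ref.values()), 1))
    p, r = sum(precisions) / max_n, sum(recalls) / max_n
    return (1 + beta ** 2) * p * r / (beta ** 2 * p + r) if (p + r) > 0 else 0.0


print(chrf_score("get|file|name", "get|name"))
```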

code2seq/typed_code2seq_wrapper.py (-2)

```diff
@@ -27,8 +27,6 @@ def train_typed_code2seq(config: DictConfig):
 
     # Load data module
     data_module = TypedPathContextDataModule(config.data_folder, config.data)
-    data_module.prepare_data()
-    data_module.setup()
 
     # Load model
     typed_code2seq = TypedCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
```

code2seq/utils/common.py (+1 -1)

```diff
@@ -4,7 +4,7 @@
 def filter_warnings():
     # "The dataloader does not have many workers which may be a bottleneck."
     filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.utilities.distributed", lineno=50)
-    filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=105)
+    filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=110)
     # "Please also save or load the state of the optimizer when saving or loading the scheduler."
     filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=216)  # save
     filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=234)  # load
```

code2seq/utils/train.py (+8 -10)

```diff
@@ -1,8 +1,10 @@
+from os.path import join
+
 import torch
 from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 
@@ -21,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
 
     # define model checkpoint callback
     checkpoint_callback = ModelCheckpointWithUpload(
-        dirpath=wandb_logger.experiment.dir,
+        dirpath=join(wandb_logger.experiment.dir, "checkpoints"),
         filename="{epoch:02d}-val_loss={val/loss:.4f}",
         monitor="val/loss",
         every_n_epochs=params.save_every_epoch,
@@ -36,6 +38,8 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
     gpu = 1 if torch.cuda.is_available() else None
     # define learning rate logger
     lr_logger = LearningRateMonitor("step")
+    # define progress bar callback
+    progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
     trainer = Trainer(
         max_epochs=params.n_epochs,
         gradient_clip_val=params.clip_norm,
@@ -44,15 +48,9 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
         log_every_n_steps=params.log_every_n_steps,
         logger=wandb_logger,
         gpus=gpu,
-        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
-        callbacks=[
-            lr_logger,
-            early_stopping_callback,
-            checkpoint_callback,
-            print_epoch_result_callback,
-        ],
+        callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar],
         resume_from_checkpoint=config.get("checkpoint", None),
     )
 
     trainer.fit(model=model, datamodule=data_module)
-    trainer.test()
+    trainer.test(datamodule=data_module, ckpt_path="best")
```
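
The tqdm-based progress_bar_refresh_rate argument of Trainer is replaced by an explicit RichProgressBar callback. A standalone sketch of the same setup (assumes pytorch-lightning 1.5 with the rich package installed; the Trainer arguments here are placeholders):

```python
# Sketch only: refresh_rate_per_second mirrors the config key used in train.py above.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

progress_bar = RichProgressBar(refresh_rate_per_second=1)
trainer = Trainer(max_epochs=10, callbacks=[progress_bar])
```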

config/code2seq-java-med.yaml (+1 -1)

```diff
@@ -28,7 +28,7 @@ data:
   random_context: true
 
   batch_size: 512
-  test_batch_size: 768
+  test_batch_size: 512
 
 model:
   # Encoder
```

config/code2seq-java-test.yaml (-1)

```diff
@@ -3,7 +3,6 @@ data_folder: ../data/code2seq/java-test
 checkpoint: null
 
 seed: 7
-# Training in notebooks (e.g. Google Colab) may crash with too small value
 progress_bar_refresh_rate: 1
 print_config: true
 
```

requirements.txt (+3 -3)

```diff
@@ -1,7 +1,7 @@
 torch==1.10.0
-pytorch-lightning==1.4.9
-torchmetrics==0.5.1
+pytorch-lightning==1.5.1
+torchmetrics==0.6.0
 tqdm==4.62.3
 wandb==0.12.6
 omegaconf==2.1.1
-commode-utils==0.3.12
+commode-utils==0.4.0
```

setup.py (+3 -5)

```diff
@@ -6,13 +6,11 @@
     readme = readme_file.read()
 
 install_requires = [
-    "torch>=1.9.0",
-    "pytorch-lightning~=1.4.2",
-    "torchmetrics~=0.5.0",
-    "tqdm~=4.62.1",
+    "torch>=1.10.0",
+    "pytorch-lightning~=1.5.0",
     "wandb~=0.12.0",
     "omegaconf~=2.1.1",
-    "commode-utils>=0.3.8",
+    "commode-utils>=0.4.0",
 ]
 
 setup_args = dict(
```
