Skip to content

Commit 9fa2806

Browse files
neggert authored and williamFalcon committed
Fix ModelCheckpoint default paths (#413)
* Make name and version properties required
* Warn before deleting files in checkpoint directory
* Get default checkpoint path from any logger
* Fix typos
* Uncomment logger tests
* Whitespace
* Update callback_config_mixin.py — checkpoint and version file names would otherwise just have a number; it is easier to tell what you are looking at with "version_" prepended
* Address comments
* Fix broken tests
1 parent 3e38005 commit 9fa2806

File tree

6 files changed

+157
-120
lines changed

6 files changed

+157
-120
lines changed

pytorch_lightning/callbacks/pt_callbacks.py

+10
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,16 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
179179
save_best_only=True, save_weights_only=False,
180180
mode='auto', period=1, prefix=''):
181181
super(ModelCheckpoint, self).__init__()
182+
if (
183+
save_best_only and
184+
os.path.isdir(filepath) and
185+
len(os.listdir(filepath)) > 0
186+
):
187+
warnings.warn(
188+
f"Checkpoint directory {filepath} exists and is not empty with save_best_only=True."
189+
"All files in this directory will be deleted when a checkpoint is saved!"
190+
)
191+
182192
self.monitor = monitor
183193
self.verbose = verbose
184194
self.filepath = filepath

pytorch_lightning/logging/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def rank(self, value):
6565
"""Set the process rank"""
6666
self._rank = value
6767

68+
@property
69+
def name(self):
70+
"""Return the experiment name"""
71+
raise NotImplementedError("Sub-classes must provide a name property")
72+
6873
@property
6974
def version(self):
7075
"""Return the experiment version"""
71-
return None
76+
raise NotImplementedError("Sub-classes must provide a version property")

pytorch_lightning/logging/mlflow_logger.py

+8
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,11 @@ def finalize(self, status="FINISHED"):
6060
if status == 'success':
6161
status = 'FINISHED'
6262
self.experiment.set_terminated(self.run_id, status)
63+
64+
@property
65+
def name(self):
66+
return self.experiment_name
67+
68+
@property
69+
def version(self):
70+
return self._run_id

pytorch_lightning/logging/test_tube_logger.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
):
1616
super().__init__()
1717
self.save_dir = save_dir
18-
self.name = name
18+
self._name = name
1919
self.description = description
2020
self.debug = debug
2121
self._version = version
@@ -29,7 +29,7 @@ def experiment(self):
2929

3030
self._experiment = Experiment(
3131
save_dir=self.save_dir,
32-
name=self.name,
32+
name=self._name,
3333
debug=self.debug,
3434
version=self.version,
3535
description=self.description,
@@ -80,6 +80,13 @@ def rank(self, value):
8080
if self._experiment is not None:
8181
self.experiment.rank = value
8282

83+
@property
84+
def name(self):
85+
if self._experiment is None:
86+
return self._name
87+
else:
88+
return self.experiment.name
89+
8390
@property
8491
def version(self):
8592
if self._experiment is None:

pytorch_lightning/trainer/callback_config_mixin.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
24
from pytorch_lightning.logging import TestTubeLogger
35

@@ -12,14 +14,15 @@ def configure_checkpoint_callback(self):
1214
"""
1315
if self.checkpoint_callback is True:
1416
# init a default one
15-
if isinstance(self.logger, TestTubeLogger):
16-
ckpt_path = '{}/{}/version_{}/{}'.format(
17+
if self.logger is not None:
18+
ckpt_path = os.path.join(
1719
self.default_save_path,
18-
self.logger.experiment.name,
19-
self.logger.experiment.version,
20-
'checkpoints')
20+
self.logger.name,
21+
f'version_{self.logger.version}',
22+
"checkpoints"
23+
)
2124
else:
22-
ckpt_path = self.default_save_path
25+
ckpt_path = os.path.join(self.default_save_path, "checkpoints")
2326

2427
self.checkpoint_callback = ModelCheckpoint(
2528
filepath=ckpt_path

tests/test_y_logging.py

+115-111
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import os
12
import pickle
23

34
import numpy as np
45
import torch
56

67
from pytorch_lightning import Trainer
78
from pytorch_lightning.testing import LightningTestModel
9+
from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only
810
from . import testing_utils
911

1012
RANDOM_FILE_PATHS = list(np.random.randint(12000, 19000, 1000))
@@ -69,117 +71,119 @@ def test_testtube_pickle():
6971
testing_utils.clear_save_dir()
7072

7173

72-
# def test_mlflow_logger():
73-
# """
74-
# verify that basic functionality of mlflow logger works
75-
# """
76-
# reset_seed()
77-
#
78-
# try:
79-
# from pytorch_lightning.logging import MLFlowLogger
80-
# except ModuleNotFoundError:
81-
# return
82-
#
83-
# hparams = get_hparams()
84-
# model = LightningTestModel(hparams)
85-
#
86-
# root_dir = os.path.dirname(os.path.realpath(__file__))
87-
# mlflow_dir = os.path.join(root_dir, "mlruns")
88-
# import pdb
89-
# pdb.set_trace()
90-
#
91-
# logger = MLFlowLogger("test", f"file://{mlflow_dir}")
92-
# logger.log_hyperparams(hparams)
93-
# logger.save()
94-
#
95-
# trainer_options = dict(
96-
# max_nb_epochs=1,
97-
# train_percent_check=0.01,
98-
# logger=logger
99-
# )
100-
#
101-
# trainer = Trainer(**trainer_options)
102-
# result = trainer.fit(model)
103-
#
104-
# print('result finished')
105-
# assert result == 1, "Training failed"
106-
#
107-
# shutil.move(mlflow_dir, mlflow_dir + f'_{n}')
108-
109-
110-
# def test_mlflow_pickle():
111-
# """
112-
# verify that pickling trainer with mlflow logger works
113-
# """
114-
# reset_seed()
115-
#
116-
# try:
117-
# from pytorch_lightning.logging import MLFlowLogger
118-
# except ModuleNotFoundError:
119-
# return
120-
#
121-
# hparams = get_hparams()
122-
# model = LightningTestModel(hparams)
123-
#
124-
# root_dir = os.path.dirname(os.path.realpath(__file__))
125-
# mlflow_dir = os.path.join(root_dir, "mlruns")
126-
#
127-
# logger = MLFlowLogger("test", f"file://{mlflow_dir}")
128-
# logger.log_hyperparams(hparams)
129-
# logger.save()
130-
#
131-
# trainer_options = dict(
132-
# max_nb_epochs=1,
133-
# logger=logger
134-
# )
135-
#
136-
# trainer = Trainer(**trainer_options)
137-
# pkl_bytes = pickle.dumps(trainer)
138-
# trainer2 = pickle.loads(pkl_bytes)
139-
# trainer2.logger.log_metrics({"acc": 1.0})
140-
#
141-
# n = RANDOM_FILE_PATHS.pop()
142-
# shutil.move(mlflow_dir, mlflow_dir + f'_{n}')
143-
144-
145-
# def test_custom_logger():
146-
#
147-
# class CustomLogger(LightningLoggerBase):
148-
# def __init__(self):
149-
# super().__init__()
150-
# self.hparams_logged = None
151-
# self.metrics_logged = None
152-
# self.finalized = False
153-
#
154-
# @rank_zero_only
155-
# def log_hyperparams(self, params):
156-
# self.hparams_logged = params
157-
#
158-
# @rank_zero_only
159-
# def log_metrics(self, metrics, step_num):
160-
# self.metrics_logged = metrics
161-
#
162-
# @rank_zero_only
163-
# def finalize(self, status):
164-
# self.finalized_status = status
165-
#
166-
# hparams = get_hparams()
167-
# model = LightningTestModel(hparams)
168-
#
169-
# logger = CustomLogger()
170-
#
171-
# trainer_options = dict(
172-
# max_nb_epochs=1,
173-
# train_percent_check=0.01,
174-
# logger=logger
175-
# )
176-
#
177-
# trainer = Trainer(**trainer_options)
178-
# result = trainer.fit(model)
179-
# assert result == 1, "Training failed"
180-
# assert logger.hparams_logged == hparams
181-
# assert logger.metrics_logged != {}
182-
# assert logger.finalized_status == "success"
74+
def test_mlflow_logger():
75+
"""
76+
verify that basic functionality of mlflow logger works
77+
"""
78+
reset_seed()
79+
80+
try:
81+
from pytorch_lightning.logging import MLFlowLogger
82+
except ModuleNotFoundError:
83+
return
84+
85+
hparams = testing_utils.get_hparams()
86+
model = LightningTestModel(hparams)
87+
88+
root_dir = os.path.dirname(os.path.realpath(__file__))
89+
mlflow_dir = os.path.join(root_dir, "mlruns")
90+
91+
logger = MLFlowLogger("test", f"file://{mlflow_dir}")
92+
93+
trainer_options = dict(
94+
max_nb_epochs=1,
95+
train_percent_check=0.01,
96+
logger=logger
97+
)
98+
99+
trainer = Trainer(**trainer_options)
100+
result = trainer.fit(model)
101+
102+
print('result finished')
103+
assert result == 1, "Training failed"
104+
105+
testing_utils.clear_save_dir()
106+
107+
108+
def test_mlflow_pickle():
109+
"""
110+
verify that pickling trainer with mlflow logger works
111+
"""
112+
reset_seed()
113+
114+
try:
115+
from pytorch_lightning.logging import MLFlowLogger
116+
except ModuleNotFoundError:
117+
return
118+
119+
hparams = testing_utils.get_hparams()
120+
model = LightningTestModel(hparams)
121+
122+
root_dir = os.path.dirname(os.path.realpath(__file__))
123+
mlflow_dir = os.path.join(root_dir, "mlruns")
124+
125+
logger = MLFlowLogger("test", f"file://{mlflow_dir}")
126+
127+
trainer_options = dict(
128+
max_nb_epochs=1,
129+
logger=logger
130+
)
131+
132+
trainer = Trainer(**trainer_options)
133+
pkl_bytes = pickle.dumps(trainer)
134+
trainer2 = pickle.loads(pkl_bytes)
135+
trainer2.logger.log_metrics({"acc": 1.0})
136+
137+
testing_utils.clear_save_dir()
138+
139+
140+
def test_custom_logger(tmpdir):
141+
142+
class CustomLogger(LightningLoggerBase):
143+
def __init__(self):
144+
super().__init__()
145+
self.hparams_logged = None
146+
self.metrics_logged = None
147+
self.finalized = False
148+
149+
@rank_zero_only
150+
def log_hyperparams(self, params):
151+
self.hparams_logged = params
152+
153+
@rank_zero_only
154+
def log_metrics(self, metrics, step_num):
155+
self.metrics_logged = metrics
156+
157+
@rank_zero_only
158+
def finalize(self, status):
159+
self.finalized_status = status
160+
161+
@property
162+
def name(self):
163+
return "name"
164+
165+
@property
166+
def version(self):
167+
return "1"
168+
169+
hparams = testing_utils.get_hparams()
170+
model = LightningTestModel(hparams)
171+
172+
logger = CustomLogger()
173+
174+
trainer_options = dict(
175+
max_nb_epochs=1,
176+
train_percent_check=0.01,
177+
logger=logger,
178+
default_save_path=tmpdir
179+
)
180+
181+
trainer = Trainer(**trainer_options)
182+
result = trainer.fit(model)
183+
assert result == 1, "Training failed"
184+
assert logger.hparams_logged == hparams
185+
assert logger.metrics_logged != {}
186+
assert logger.finalized_status == "success"
183187

184188

185189
def reset_seed():

0 commit comments

Comments (0)