resume.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from pathlib import Path, PosixPath, WindowsPath
from typing import Optional, Union
import lightning.fabric as fl
import lightning.pytorch as pl
from nemo.lightning import io
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import uninject_model_parallel_rank
# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == "nt":
BasePath = WindowsPath
else:
BasePath = PosixPath
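# (pathlib.Path() returns a platform-specific subclass, so AdapterPath at the bottom of this file
# subclasses the concrete flavor via BasePath rather than pathlib.Path directly.)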
def _try_restore_tokenizer(model, ckpt_path):
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.lightning.io import load_context
try:
tokenizer = load_context(ckpt_path, "model.tokenizer")
except ValueError as e:
logging.warning(
f"Encountered error while trying to restore tokenizer. Tokenizer is not restored. " f"Original error: {e}"
)
return model
if isinstance(tokenizer, TokenizerSpec):
model.tokenizer = tokenizer
model.__io__.tokenizer = tokenizer.__io__
else:
# Ignore if the ckpt doesn't have a tokenizer. type(tokenizer)==TrainerContext in this case.
logging.warning("Checkpoint does not have model.tokenizer field. Tokenizer is not restored.")
return model
@dataclass(kw_only=True)
class AutoResume:
"""Class that handles the logic for setting checkpoint paths and restoring from
checkpoints in NeMo.
Attributes:
restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model
weights, optimizer states, etc.
            If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be
            automatically converted to a NeMo-compatible format.
            resume_from_directory or the run's log_dir takes precedence over restore_config.
resume_from_directory (str): Path to the checkpointing directory to restore from.
resume_from_path (str): Path to a specific checkpoint to restore from.
adapter_path (str): Path to any adapter checkpoints.
resume_if_exists (bool): Whether this experiment is resuming from a previous run. If
True, it sets trainer._checkpoint_connector._ckpt_path so that the trainer should
auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}.
Defaults to False.
        resume_past_end (bool): By default, AutoResume throws an error if resume_if_exists is
            True and a checkpoint matching ``*end.ckpt`` is found, indicating that a previous
            training run fully completed. Setting resume_past_end=True disables this behavior
            and loads the last checkpoint.
resume_ignore_no_checkpoint (bool): AutoResume throws an error if resume_if_exists is
True and no checkpoint could be found. Setting resume_ignore_no_checkpoint=True
disables this behavior, in which case exp_manager will print a message and
continue without restoring.
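
    Example:
        A minimal usage sketch (illustrative only; the trainer, model, and path are placeholders)::

            resume = AutoResume(
                resume_if_exists=True,
                resume_ignore_no_checkpoint=True,
                restore_config=RestoreConfig(path="nemo://<model-name>"),
            )
            resume.setup(trainer, model)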
"""
restore_config: Optional[RestoreConfig] = None
resume_from_directory: Optional[str] = None
resume_from_path: Optional[str] = None
adapter_path: Optional[str] = None
resume_if_exists: bool = False
resume_past_end: bool = False
resume_ignore_no_checkpoint: bool = False
WEIGHTS_PATH = "weights"
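    # Checkpoints handled here keep model weights under "<ckpt>/weights" and the serialized trainer
    # context (e.g. the tokenizer) under "<ckpt>/context"; see get_weights_path() and get_context_path().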
def get_weights_path(self, path):
return Path(path) / self.WEIGHTS_PATH
def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
if isinstance(trainer, fl.Fabric):
raise NotImplementedError("Fabric is not supported yet.")
trainer_ckpt_path = self.get_trainer_ckpt_path(model)
if trainer_ckpt_path:
trainer.ckpt_path = trainer_ckpt_path
trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
# Load artifacts
if getattr(self.restore_config, 'load_artifacts', False):
if isinstance(trainer_ckpt_path, AdapterPath):
# load tokenizer from the base model during peft resume, in case the first peft checkpoint
# is deleted before the current peft checkpoint is saved
context_path = trainer_ckpt_path.base_model_path / "context"
if not context_path.exists():
context_path = trainer_ckpt_path.base_model_path
else:
context_path = self.get_context_path(model)
model = _try_restore_tokenizer(model, context_path)
elif self.restore_config:
new_path = self._extract_path(
model=model,
path=self.restore_config.path,
adapter_path=self.restore_config.adapter_path,
)
if isinstance(new_path, AdapterPath):
self.restore_config.path = new_path.base_model_path
self.restore_config.adapter_path = str(new_path)
else:
self.restore_config.path = str(new_path)
trainer.strategy.restore_config = self.restore_config
# Load artifacts
if self.restore_config.load_artifacts:
if isinstance(new_path, AdapterPath):
context_path = Path(new_path.base_model_path) / "context"
else:
context_path = new_path / "context"
if not context_path.is_dir():
context_path = new_path
_try_restore_tokenizer(model, context_path)
def _extract_path(
self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
) -> BasePath:
if "://" in path:
assert path.startswith("nemo://"), "Only NeMo based paths starting with nemo:// are currently supported."
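            # e.g. (illustrative) "nemo://<model-name>" resolves to <NEMO_MODELS_CACHE>/<model-name>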
_, _path = path.split("://")
new_path = os.path.join(NEMO_MODELS_CACHE, _path)
else:
new_path = path
if adapter_path:
maybe_weights_path = self.get_weights_path(adapter_path)
if maybe_weights_path.is_dir():
adapter_path = maybe_weights_path
new_path = AdapterPath(Path(adapter_path), base_model_path=new_path)
if isinstance(new_path, str):
new_path = Path(new_path)
return new_path
def _resume_peft(self, adapter_meta_path, model):
with open(adapter_meta_path, "r") as f:
metadata = json.load(f)
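        # The adapter metadata JSON is expected to contain at least "model_ckpt_path",
        # e.g. (illustrative): {"model_ckpt_path": "/results/base_run/checkpoints/<step>-last"}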
assert self.restore_config, "PEFT resume requires specifying restore_config"
base_model_path = self._extract_path(model, self.restore_config.path)
if base_model_path not in [Path(metadata['model_ckpt_path']), Path(metadata['model_ckpt_path']).parent]:
logging.warning(
f"⚠️ When trying to resume a PEFT training run, found mismatching values: "
f"your specified restore_path points to {base_model_path}, "
f"but the PEFT checkpoint was trained with "
f"model_ckpt_path={metadata['model_ckpt_path']}"
)
return base_model_path
def _find_trainer_ckpt_path(self) -> Optional[Path]:
from nemo.utils.exp_manager import NotFoundError, _filter_out_unfinished_checkpoints
app_state = AppState()
log_dir = app_state.log_dir
checkpoint = None
# Use <log_dir>/checkpoints/ unless `dirpath` is set
if self.resume_from_directory:
checkpoint_dir = Path(self.resume_from_directory)
elif log_dir is not None:
checkpoint_dir = Path(Path(log_dir) / "checkpoints")
        else:  # i.e. if log_dir is None
return None
# when using distributed checkpointing, checkpoint_dir is a directory of directories
# we check for this here
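        # e.g. (illustrative layout):
        #   <checkpoint_dir>/
        #       <run-name>-step=1000-last/   <- each distributed checkpoint is itself a directory
        #       <run-name>-step=2000-end/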
dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()]
end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")]
last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
end_chkpt_cnt = len(end_dist_checkpoints)
end_checkpoints = _filter_out_unfinished_checkpoints(end_dist_checkpoints)
finished_end_chkpt_cnt = len(end_checkpoints)
if end_chkpt_cnt > 0 and finished_end_chkpt_cnt == 0:
raise ValueError(
"End checkpoint is unfinished and cannot be used to resume the training."
" Please remove the checkpoint manually to avoid unexpected cosequences, such as"
" restarting from scratch."
)
last_chkpt_cnt = len(last_dist_checkpoints)
last_checkpoints = _filter_out_unfinished_checkpoints(last_dist_checkpoints)
finished_last_chkpt_cnt = len(last_checkpoints)
if last_chkpt_cnt > 0 and finished_last_chkpt_cnt == 0:
raise ValueError(
"Last checkpoint is unfinished and cannot be used to resume the training."
" Please remove the checkpoint manually to avoid unexpected cosequences, such as"
" restarting from scratch. Hint: Iteration number can be added to the checkpoint name pattern"
" to maximize chance that there is at least one finished last checkpoint to resume from."
)
        if not checkpoint_dir.exists() or (len(end_checkpoints) == 0 and len(last_checkpoints) == 0):
if self.resume_ignore_no_checkpoint:
warn = (
f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir "
f":{checkpoint_dir}. "
)
if checkpoint is None:
warn += "Training from scratch."
logging.warning(warn)
else:
if self.restore_config:
# resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore
# later instead.
return None
else:
raise NotFoundError(
f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir "
f":{checkpoint_dir}. Cannot resume."
)
elif len(end_checkpoints) > 0:
if not self.resume_past_end:
raise ValueError(
f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
)
if len(end_checkpoints) > 1:
if "mp_rank" in str(end_checkpoints[0]):
checkpoint = end_checkpoints[0]
else:
raise ValueError(f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt.")
elif len(last_checkpoints) > 1:
            if any(s in str(last_checkpoints[0]) for s in ["mp_rank", "tp_rank", "fsdp_shard"]):
checkpoint = last_checkpoints[0]
checkpoint = uninject_model_parallel_rank(checkpoint)
else:
# Select the checkpoint with the latest modified time
checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0]
                logging.warning(
                    f"Multiple checkpoints {last_checkpoints} match *last.ckpt. Selecting the one with the "
                    f"latest modified time."
                )
else:
checkpoint = last_checkpoints[0]
return checkpoint
def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
        checkpoint = None
app_state = AppState()
app_state.restore = self.resume_if_exists
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()
if checkpoint:
maybe_context_path = Path(checkpoint) / "context"
if maybe_context_path.is_dir():
checkpoint = maybe_context_path
return checkpoint
def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
if self.resume_from_path:
maybe_weights_path = self.get_weights_path(self.resume_from_path)
if maybe_weights_path.is_dir():
adapter_meta_path = maybe_weights_path / ADAPTER_META_FILENAME
if adapter_meta_path.exists():
# the resume_from_path is an adapter checkpoint
base_model_path = self._resume_peft(adapter_meta_path, model)
return AdapterPath(Path(self.resume_from_path), base_model_path=base_model_path)
else:
# the resume_from_path is not PEFT checkpoint
return maybe_weights_path
else:
return self.resume_from_path
checkpoint = None
app_state = AppState()
app_state.restore = self.resume_if_exists
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()
if checkpoint:
maybe_weights_path = self.get_weights_path(checkpoint)
if maybe_weights_path.is_dir():
checkpoint = maybe_weights_path
if checkpoint:
if self.adapter_path:
return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
else:
adapter_meta_path = checkpoint / ADAPTER_META_FILENAME
if adapter_meta_path.exists():
base_model_path = self._resume_peft(adapter_meta_path, model)
return AdapterPath(checkpoint, base_model_path=base_model_path)
else:
return Path(checkpoint)
return None
class AdapterPath(BasePath):
"""Path object for adapter paths which include a field for the base model the adapters are trained on
to facilitate model loading."""
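    # Illustrative usage (paths are placeholders):
    #   AdapterPath(Path("/results/peft_run/checkpoints/<step>-last/weights"),
    #               base_model_path=Path("/results/base_run/checkpoints/<step>-last"))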
base_model_path: Optional[Path]
def __new__(cls, *args, base_model_path: Optional[Path] = None, **kwargs):
output = super().__new__(cls, *args, **kwargs)
output.base_model_path = base_model_path
return output
def __repr__(self):
return "{}({!r}, base_model_path={})".format(self.__class__.__name__, self.as_posix(), self.base_model_path)