Conversation
Pull request overview
This PR adds support for sequential, multi-stage experiment runs (checkpoint-chained training stages) and extends checkpoint naming/configuration to better support recurrent/latent-chaining and MeZO training variants.
Changes:
- Added `sequential_runs` support to the experiment runner to orchestrate multi-stage scripts with checkpoint handoff.
- Introduced configurable checkpoint output filenames (`output_ckpt`, `mezo_output_ckpt`, `recurrent_output_ckpt`) and wired them into training scripts.
- Modularized the latent-chaining recurrent loss into a new `recurrent_block_variants` module and refactored `train_recurrent.py` accordingly.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Summary per file:

| File | Description |
|---|---|
| `optimization_and_search/run_experiments.py` | Adds sequential multi-stage orchestration (`sequential_runs`) and script selection/checkpoint chaining. |
| `demos/sequential_run_experiments_demo.yaml` | Demonstrates a multi-stage sequential run pipeline (train → finetune → recurrent → MeZO). |
| `train_args.py` | Adds new CLI args for checkpoint output filenames and groups MeZO/recurrent options. |
| `train.py` | Uses `--output_ckpt` for checkpoint saving instead of hard-coded `ckpt.pt`. |
| `train_mezo.py` | Saves checkpoints using `--mezo_output_ckpt`. |
| `train_recurrent.py` | Refactors recurrent training into modular components and adds recurrent checkpoint filename + variant selection. |
| `recurrent_variations/recurrent_block_variants.py` | New module implementing latent-chaining recurrent block loss + config. |
    if args.always_save_checkpoint:
        save_checkpoint(
            model=model,
            ckpt_model_args=ckpt_model_args,
            ckpt_path=best_ckpt_path,
            best_val_loss=state.best_val_loss,
            global_step=state.global_step,
            tag="latest",
        )
When always_save_checkpoint is enabled, save_checkpoint() is called with ckpt_path=best_ckpt_path, which overwrites the best checkpoint file with a “latest” checkpoint. This makes it impossible to reliably keep the best checkpoint. Use a separate filename/path for the “latest” checkpoint (or only write “latest” when it’s also best).
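A minimal sketch of that separation, reusing the names from the snippet above (the `latest_` prefix and the use of `out_dir`/`args.output_ckpt` are assumptions, not the PR's actual naming):

```python
import os

# Sketch only: give the "latest" checkpoint its own path so it never clobbers the best one.
latest_ckpt_path = os.path.join(out_dir, "latest_" + args.output_ckpt)  # hypothetical name

if args.always_save_checkpoint:
    save_checkpoint(
        model=model,
        ckpt_model_args=ckpt_model_args,
        ckpt_path=latest_ckpt_path,  # instead of best_ckpt_path
        best_val_loss=state.best_val_loss,
        global_step=state.global_step,
        tag="latest",
    )
```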
    ptr = 0
    data_view = data
    total_tokens = len(data_view) - 1
    save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint
Checkpoint saving here ignores never_save_checkpoint whenever always_save_checkpoint is true, so a run in "no-save" mode can still write checkpoints. In train.py, never_save_checkpoint always prevents saving regardless of always_save_checkpoint; aligning the two scripts avoids surprising writes in "no-save" mode.
Suggested change:

    - save_enabled = args.always_save_checkpoint or not args.never_save_checkpoint
    + save_enabled = not args.never_save_checkpoint
    parser.add_argument("--weight_end", type=float, default=1.0)
    parser.add_argument(
        "--reset_optim",
        action="store_true",
--reset_optim is defined as store_true here, so the command builder in run_experiments.py can emit --no-reset_optim (for false) and cause an unknown-arg failure. Since the rest of the codebase uses BooleanOptionalAction for these flags, consider switching --reset_optim to argparse.BooleanOptionalAction with an explicit default to keep CLI behavior consistent and compatible with sequential runs.
Suggested change:

    - action="store_true",
    + action=argparse.BooleanOptionalAction,
    + default=False,
    def _normalize_sequential_runs(sequential_runs: object) -> list[dict]:
        if sequential_runs is None:
            return []
        if isinstance(sequential_runs, dict):
            return [sequential_runs]
sequential_runs adds new multi-stage execution behavior, but there is no automated coverage for it. The existing tests/test_run_experiments.sh covers other config features; adding a small YAML under tests/run_optimization_tests/ that exercises sequential_runs (resume + input_ckpt chaining) would help prevent regressions.
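Alongside a shell/YAML test, the behavior visible in this snippet could be pinned down with a small unit check; the import path and stage keys below are hypothetical and should mirror the actual module layout and `demos/sequential_run_experiments_demo.yaml`:

```python
from optimization_and_search.run_experiments import _normalize_sequential_runs  # path assumed

def test_normalize_sequential_runs():
    # Stage keys are illustrative placeholders, not the PR's confirmed schema.
    stage = {"script": "train.py", "output_ckpt": "stage1.pt"}
    assert _normalize_sequential_runs(None) == []        # no sequential runs configured
    assert _normalize_sequential_runs(stage) == [stage]  # a single dict becomes a one-stage list
```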
    # Read metrics (use existing or nan on failure)
    try:
        metrics = read_metrics(str(combo['out_dir']))
    except Exception:
        metrics = {k: float("nan") for k in METRIC_KEYS}
run_experiment() still derives metrics only from best_val_loss_and_iter.txt (written by train.py). With sequential runs, stages like train_mezo.py / train_recurrent.py don't write this file, so successful runs can be logged with NaN metrics. Consider standardizing metric output across the training scripts, or else emitting a clear warning and falling back to metrics from an earlier stage that produces them.
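One low-friction option, sketched under the assumption that `read_metrics()` parses a single `best_val_loss,iter` line from `best_val_loss_and_iter.txt` (the real format should be copied from whatever `train.py` writes):

```python
import os

def write_best_metrics(out_dir: str, best_val_loss: float, best_iter: int) -> None:
    """Hypothetical helper: let train_mezo.py / train_recurrent.py emit the same
    metrics file run_experiments.py reads, so successful stages don't log NaN."""
    path = os.path.join(out_dir, "best_val_loss_and_iter.txt")
    with open(path, "w") as f:
        # Assumed layout; align with whatever read_metrics() actually expects.
        f.write(f"{best_val_loss},{best_iter}\n")
```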
    best_val_loss = ckpt["best_val_loss"].item()
    print("best_val_loss", best_val_loss)
    best_val_loss = 5.00  # TODO: allow configurable start threshold
best_val_loss is loaded from the resume checkpoint but then unconditionally overwritten to 5.00. This breaks true "best so far" tracking (it can block saving entirely when losses stay above that threshold, or wrongly treat checkpoints worse than the true best as improvements when they fall below it) and makes resumed recurrent runs non-reproducible. Prefer initializing best_val_loss from the checkpoint (or inf), and if a custom start threshold is desired, expose it as an explicit CLI flag instead of hard-coding it.
Suggested change:

    - best_val_loss = ckpt["best_val_loss"].item()
    - print("best_val_loss", best_val_loss)
    - best_val_loss = 5.00  # TODO: allow configurable start threshold
    + # Initialize best_val_loss from checkpoint if available; otherwise use +inf
    + best_val_loss = float("inf")
    + if "best_val_loss" in ckpt and ckpt["best_val_loss"] is not None:
    +     best_val_loss = ckpt["best_val_loss"].item()
    + print("best_val_loss", best_val_loss)
This pull request introduces support for sequential multi-stage experiment runs, improves flexibility for checkpoint naming and management, and adds infrastructure for recurrent/latent-chaining training variants. Key changes include enhancements to experiment configuration and orchestration, new argument options for checkpoint handling, and a new module for recurrent block variants.
Experiment orchestration and configuration:
- Added support for `sequential_runs` in experiment YAML files, allowing experiments to be defined as a sequence of training stages (e.g., base training, fine-tuning, recurrent, and MeZO stages) with explicit checkpoint handoff between stages. (`demos/sequential_run_experiments_demo.yaml`, `optimization_and_search/run_experiments.py`)
- Improved the experiment runner to handle per-stage script selection, argument passing, checkpoint chaining, and output directory management for sequential runs. (`optimization_and_search/run_experiments.py`)

Checkpoint and argument handling:
- Introduced configurable checkpoint output filenames (`output_ckpt`, `mezo_output_ckpt`, `recurrent_output_ckpt`) and updated training scripts to use these names when saving checkpoints. (`train_args.py`, `train.py`, `train_mezo.py`)

Recurrent/latent-chaining infrastructure:
- Added a new module `recurrent_block_variants.py` implementing the latent-chaining recurrent block loss and a configuration class, enabling flexible experimentation with recurrent training strategies. (`recurrent_variations/recurrent_block_variants.py`)

CLI and argument group improvements:
- Added new CLI arguments and argument groups for the checkpoint output filenames and the MeZO/recurrent options. (`train_args.py`)

These changes collectively enable more complex, reproducible experiment pipelines, more flexible checkpointing, and new research directions with recurrent block variants.