From 2654bad302d0e92aa602990bfebe2dab5884e1b5 Mon Sep 17 00:00:00 2001
From: Mattie Tesfaldet
Date: Fri, 1 Sep 2023 16:20:17 -0400
Subject: [PATCH] Fix `torch.compile` on `nn.module` instead of on `LightningModule` (#587)

---
 configs/experiment/example.yaml |  1 +
 configs/model/mnist.yaml        |  3 +++
 configs/train.yaml              |  3 ---
 environment.yaml                |  1 +
 src/models/mnist_module.py      | 20 ++++++++++++++++----
 src/train.py                    |  4 ----
 6 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/configs/experiment/example.yaml b/configs/experiment/example.yaml
index 690a59fef..9a93b540c 100644
--- a/configs/experiment/example.yaml
+++ b/configs/experiment/example.yaml
@@ -28,6 +28,7 @@ model:
     lin1_size: 128
     lin2_size: 256
     lin3_size: 64
+  compile: false
 
 data:
   batch_size: 64
diff --git a/configs/model/mnist.yaml b/configs/model/mnist.yaml
index 3965ec23b..6f9c2fa1e 100644
--- a/configs/model/mnist.yaml
+++ b/configs/model/mnist.yaml
@@ -20,3 +20,6 @@ net:
   lin2_size: 128
   lin3_size: 64
   output_size: 10
+
+# compile model for faster training with pytorch 2.0
+compile: false
diff --git a/configs/train.yaml b/configs/train.yaml
index c24b20681..ef7bdab6e 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -42,9 +42,6 @@ train: True
 # lightning chooses best weights based on the metric specified in checkpoint callback
 test: True
 
-# compile model for faster training with pytorch 2.0
-compile: False
-
 # simply provide checkpoint path to resume training
 ckpt_path: null
 
diff --git a/environment.yaml b/environment.yaml
index 34a24631a..f74ee8c72 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -21,6 +21,7 @@ channels:
 # compatibility is usually guaranteed
 
 dependencies:
+  - python=3.10
   - pytorch=2.*
   - torchvision=0.*
   - lightning=2.*
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index 5c1049725..b10450f19 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -44,6 +44,7 @@ def __init__(
         net: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
         scheduler: torch.optim.lr_scheduler,
+        compile: bool,
     ) -> None:
         """Initialize a `MNISTLitModule`.
 
@@ -176,10 +177,21 @@ def on_test_epoch_end(self) -> None:
         """Lightning hook that is called when a test epoch ends."""
         pass
 
-    def configure_optimizers(self) -> Dict[str, Any]:
-        """Configures optimizers and learning-rate schedulers to be used for training.
+    def setup(self, stage: str) -> None:
+        """Lightning hook that is called at the beginning of fit (train + validate), validate,
+        test, or predict.
+
+        This is a good hook when you need to build models dynamically or adjust something about
+        them. This hook is called on every process when using DDP.
 
-        Normally you'd need one, but in the case of GANs or similar you might need multiple.
+        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+        """
+        if self.hparams.compile and stage == "fit":
+            self.net = torch.compile(self.net)
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Choose what optimizers and learning-rate schedulers to use in your optimization.
+        Normally you'd need one. But in the case of GANs or similar you might have multiple.
 
         Examples:
             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
@@ -202,4 +214,4 @@ def configure_optimizers(self) -> Dict[str, Any]:
 
 
 if __name__ == "__main__":
-    _ = MNISTLitModule(None, None, None)
+    _ = MNISTLitModule(None, None, None, None)
diff --git a/src/train.py b/src/train.py
index 2942d6acb..955eefaf4 100644
--- a/src/train.py
+++ b/src/train.py
@@ -74,10 +74,6 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         log.info("Logging hyperparameters!")
         utils.log_hyperparameters(object_dict)
 
-    if cfg.get("compile"):
-        log.info("Compiling model!")
-        model = torch.compile(model)
-
     if cfg.get("train"):
         log.info("Starting training!")
         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
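
Note (not part of the patch): the change above removes the old `torch.compile(model)` call in `src/train.py`, which wrapped the whole `LightningModule`, and instead compiles only the inner `nn.Module` inside the module's `setup` hook, gated by the new `compile` hyperparameter. A minimal sketch of that pattern, with a hypothetical `LitSketch` class standing in for `MNISTLitModule`:

import torch
from lightning import LightningModule


class LitSketch(LightningModule):
    def __init__(self, net: torch.nn.Module, compile: bool) -> None:
        super().__init__()
        # Store `compile` (and other init args) in hparams; keep the submodule itself out.
        self.save_hyperparameters(logger=False, ignore=["net"])
        self.net = net

    def setup(self, stage: str) -> None:
        # Compile only the wrapped nn.Module, and only once fitting begins.
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

With the new config key, compilation can be toggled per run, e.g. `python src/train.py model.compile=true` (assuming the template's usual Hydra layout).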