diff --git a/nni/compression/pytorch/utils/evaluator.py b/nni/compression/pytorch/utils/evaluator.py index e55307bf1b..375619392c 100644 --- a/nni/compression/pytorch/utils/evaluator.py +++ b/nni/compression/pytorch/utils/evaluator.py @@ -14,10 +14,16 @@ from torch.utils.hooks import RemovableHandle try: - import pytorch_lightning as pl - from pytorch_lightning.callbacks import Callback + import lightning as pl + from lightning.callbacks import Callback except ImportError: - LIGHTNING_INSTALLED = False + try: + import pytorch_lightning as pl + from pytorch_lightning.callbacks import Callback + except ImportError: + LIGHTNING_INSTALLED = False + else: + LIGHTNING_INSTALLED = True else: LIGHTNING_INSTALLED = True @@ -957,4 +963,4 @@ def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, return self.trainer.evaluate() def get_dummy_input(self) -> Any: - return self.dummy_input + return self.dummy_input \ No newline at end of file diff --git a/nni/contrib/compression/utils/evaluator.py b/nni/contrib/compression/utils/evaluator.py index 0ba7bd645b..38c5e5f7b0 100644 --- a/nni/contrib/compression/utils/evaluator.py +++ b/nni/contrib/compression/utils/evaluator.py @@ -16,9 +16,16 @@ from torch.utils.hooks import RemovableHandle try: - import pytorch_lightning as pl + import lightning as pl + from lightning.callbacks import Callback except ImportError: - LIGHTNING_INSTALLED = False + try: + import pytorch_lightning as pl + from pytorch_lightning.callbacks import Callback + except ImportError: + LIGHTNING_INSTALLED = False + else: + LIGHTNING_INSTALLED = True else: LIGHTNING_INSTALLED = True @@ -1099,4 +1106,4 @@ def evaluate(self) -> Tuple[float | None, Dict[str, Any]]: return nni_used_metric, metric def get_dummy_input(self) -> Any: - return self.dummy_input + return self.dummy_input \ No newline at end of file