|
27 | 27 | from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
|
28 | 28 | from lightning.pytorch.plugins.precision import PrecisionPlugin
|
29 | 29 | from lightning.pytorch.strategies.ddp import DDPStrategy
|
| 30 | + from lightning.pytorch.utilities.types import STEP_OUTPUT |
30 | 31 | elif module_available("pytorch_lightning"):
|
31 | 32 | from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
|
32 | 33 | from lightning_fabric.utilities.distributed import group as _group
|
|
36 | 37 | from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
|
37 | 38 | from pytorch_lightning.plugins.precision import PrecisionPlugin
|
38 | 39 | from pytorch_lightning.strategies.ddp import DDPStrategy
|
| 40 | + from pytorch_lightning.utilities.types import STEP_OUTPUT |
39 | 41 | else:
|
40 | 42 | raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")
|
41 | 43 | from torch import Tensor
|
@@ -138,20 +140,20 @@ def optimizer_step(
|
138 | 140 | htcore.mark_step()
|
139 | 141 | return optimizer_output
|
140 | 142 |
|
141 |
| - def validation_step(self, batch: Any, batch_idx: int) -> Any: |
| 143 | + def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: |
142 | 144 | # Break lazy accumulation of graph after every step
|
143 | 145 | htcore.mark_step()
|
144 |
| - return super().validation_step(batch, batch_idx) |
| 146 | + return super().validation_step(*args, **kwargs) |
145 | 147 |
|
146 |
| - def test_step(self, batch: Any, batch_idx: int) -> Any: |
| 148 | + def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: |
147 | 149 | # Break lazy accumulation of graph after every step
|
148 | 150 | htcore.mark_step()
|
149 |
| - return super().test_step(batch, batch_idx) |
| 151 | + return super().test_step(*args, **kwargs) |
150 | 152 |
|
151 |
| - def predict_step(self, batch: Any, batch_idx: int) -> Any: |
| 153 | + def predict_step(self, *args: Any, **kwargs: Any) -> Any: |
152 | 154 | # Break lazy accumulation of graph after every step
|
153 | 155 | htcore.mark_step()
|
154 |
| - return super().predict_step(batch, batch_idx) |
| 156 | + return super().predict_step(*args, **kwargs) |
155 | 157 |
|
156 | 158 | def reduce(
|
157 | 159 | self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
|
|
0 commit comments