diff --git a/fortuna/calib_model/base.py b/fortuna/calib_model/base.py index 37055efa..f3176442 100644 --- a/fortuna/calib_model/base.py +++ b/fortuna/calib_model/base.py @@ -137,11 +137,11 @@ def _calibrate( rng=self.rng.get(), state=state, loss_fun=loss, - training_dataloader=calib_data_loader, + training_data_loader=calib_data_loader, training_dataset_size=n_calib_data, n_epochs=config.optimizer.n_epochs, metrics=config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=n_val_data, verbose=config.monitor.verbose, callbacks=config.callbacks, diff --git a/fortuna/calib_model/calib_mixin.py b/fortuna/calib_model/calib_mixin.py index a63f45ac..1d79b876 100644 --- a/fortuna/calib_model/calib_mixin.py +++ b/fortuna/calib_model/calib_mixin.py @@ -1,7 +1,7 @@ import os from typing import Optional -from flax.training import checkpoints +# from flax.training import checkpoints from fortuna.calib_model.state import CalibState from fortuna.training.mixins.checkpointing import WithCheckpointingMixin @@ -12,29 +12,30 @@ class WithCalibCheckpointingMixin(WithCheckpointingMixin): - def restore_checkpoint( - self, - restore_checkpoint_dir: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "", - **kwargs, - ) -> CalibState: - if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile( - restore_checkpoint_dir - ): - raise ValueError( - f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_dir), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`." - ) - - return CalibState.init_from_dict(d, optimizer, **kwargs) + pass + # def restore_checkpoint( + # self, + # restore_checkpoint_dir: Path, + # optimizer: Optional[OptaxOptimizer] = None, + # prefix: str = "", + # **kwargs, + # ) -> CalibState: + # if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile( + # restore_checkpoint_dir + # ): + # raise ValueError( + # f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found." + # ) + # d = checkpoints.restore_checkpoint( + # ckpt_dir=str(restore_checkpoint_dir), + # target=None, + # step=None, + # prefix=prefix, + # parallel=True, + # ) + # if d is None: + # raise ValueError( + # f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`." + # ) + # + # return CalibState.init_from_dict(d, optimizer, **kwargs) diff --git a/fortuna/data/dataset/huggingface_datasets.py b/fortuna/data/dataset/huggingface_datasets.py index 7d556459..60521779 100644 --- a/fortuna/data/dataset/huggingface_datasets.py +++ b/fortuna/data/dataset/huggingface_datasets.py @@ -112,12 +112,12 @@ def get_data_loader( drop_last: bool if True, the last batch (which potentially is smaller then the default batch size) is dropped. verbose: bool - Whether to show a progress bar while iterating over the dataloader or not. + Whether to show a progress bar while iterating over the data_loader or not. 
Returns ------- HuggingFaceDataLoader - The dataloader + The data_loader """ iterable = IterableData.from_callable( lambda *args, **kwargs: self._get_data_loader( diff --git a/fortuna/data/loader/base.py b/fortuna/data/loader/base.py index 2f7aed43..4376f149 100644 --- a/fortuna/data/loader/base.py +++ b/fortuna/data/loader/base.py @@ -9,6 +9,7 @@ Tuple, Type, TypeVar, + Union ) from flax import jax_utils @@ -24,6 +25,10 @@ Status, Targets, ) +from fortuna.utils.prefetch import prefetch_to_mesh +from fortuna.partitioner.partition_manager.base import PartitionManager +from jax import device_put +from jax.sharding import NamedSharding, PartitionSpec T = TypeVar("T") @@ -185,7 +190,7 @@ def from_tensorflow_data_loader(cls: Type[T], tf_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_tf_dataloader(tf_data_loader)) + return cls(iterable=IterableData.from_tf_data_loader(tf_data_loader)) @classmethod def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: @@ -203,7 +208,7 @@ def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T: T A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`. """ - return cls(iterable=IterableData.from_torch_dataloader(torch_data_loader)) + return cls(iterable=IterableData.from_torch_data_loader(torch_data_loader)) @classmethod def from_inputs_loaders( @@ -545,3 +550,30 @@ def __iter__(self, *args, **kwargs): loader = map(lambda batch: tree_map(self._reshape_inputs, batch), self._loader) loader = jax_utils.prefetch_to_device(loader, 2) yield from loader + + +class ShardedPrefetchedLoader: + def __init__( + self, + loader, + partition_manager: Optional[PartitionManager] = None, + shard: bool = True, + partition_spec: Optional[PartitionSpec] = None + ): + self._loader = loader + self.partition_manager = partition_manager + self.shard = shard + self.partition_spec = partition_spec + if partition_manager is None and shard: + raise ValueError("`partition_manager` cannot be None when `shard` is set to True.") + + def _shard(self, data: Union[Batch, InputData, Targets]): + return device_put(data, NamedSharding(self.partition_manager.partitioner.mesh, self.partition_spec)) + + def __iter__(self, *args, **kwargs): + if self.shard: + loader = map(lambda data: tree_map(self._shard, data), self._loader) + loader = prefetch_to_mesh(loader, 2, self.partition_manager.partitioner.mesh, self.partition_spec) + else: + loader = jax_utils.prefetch_to_device(self._loader, 2) + yield from loader diff --git a/fortuna/data/loader/huggingface_loaders.py b/fortuna/data/loader/huggingface_loaders.py index 3a89801c..39e44461 100644 --- a/fortuna/data/loader/huggingface_loaders.py +++ b/fortuna/data/loader/huggingface_loaders.py @@ -35,7 +35,7 @@ def __init__( Parameters ---------- iterable : Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array],Array]]] - A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_dataloader`. + A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_data_loader`. 
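# --- Illustrative aside (not part of the patch) ---------------------------------
# A minimal, self-contained sketch of the device placement that the new
# ShardedPrefetchedLoader._shard performs on each batch: a jax.device_put onto a
# NamedSharding built from the partitioner's mesh and a PartitionSpec. The mesh
# layout, axis sizes and batch shapes below are assumptions for the example;
# Fortuna builds the mesh through Partitioner/PartitionManager instead.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.tree_util import tree_map

# One mesh axis per entry of the partitioner's axis_dims ("dp", "fsdp", "mp");
# here all local devices are assigned to "dp".
devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "mp"))

# Shard the leading (batch) dimension across the "dp" and "fsdp" axes, matching the
# PartitionSpec(("dp", "fsdp")) that the training mixin passes for input batches.
sharding = NamedSharding(mesh, PartitionSpec(("dp", "fsdp")))

batch = (np.ones((8, 16), dtype=np.float32), np.ones((8,), dtype=np.float32))
# The batch size (8 here) must divide evenly across the "dp" and "fsdp" axis sizes.
sharded_batch = tree_map(lambda x: jax.device_put(x, sharding), batch)
# --------------------------------------------------------------------------------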
num_unique_labels: int Number of unique target labels in the task (classification only) num_inputs: Optional[int] diff --git a/fortuna/data/loader/utils.py b/fortuna/data/loader/utils.py index 6bba7211..6bca8c1c 100644 --- a/fortuna/data/loader/utils.py +++ b/fortuna/data/loader/utils.py @@ -44,9 +44,9 @@ def _inner(): return cls(_inner) @classmethod - def from_tf_dataloader(cls, tf_dataloader) -> IterableData: + def from_tf_data_loader(cls, tf_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in tf_dataloader: + for batch_inputs, batch_targets in tf_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: @@ -57,9 +57,9 @@ def _inner(): return cls(_inner) @classmethod - def from_torch_dataloader(cls, torch_dataloader) -> IterableData: + def from_torch_data_loader(cls, torch_data_loader) -> IterableData: def _inner(): - for batch_inputs, batch_targets in torch_dataloader: + for batch_inputs, batch_targets in torch_data_loader: if not isinstance(batch_inputs, dict): batch_inputs = batch_inputs.numpy() else: diff --git a/fortuna/likelihood/base.py b/fortuna/likelihood/base.py index fd32b49a..aecf7d4d 100644 --- a/fortuna/likelihood/base.py +++ b/fortuna/likelihood/base.py @@ -215,9 +215,10 @@ def _batched_log_joint_prob( mutable=mutable, rng=rng, ) - if "mutable" in return_aux: + if mutable is not None: outputs, aux = outs - mutable = aux["mutable"] + if mutable in return_aux: + mutable = aux["mutable"] else: outputs = outs diff --git a/fortuna/model/model_manager/base.py b/fortuna/model/model_manager/base.py index f8a8cf63..88c21156 100755 --- a/fortuna/model/model_manager/base.py +++ b/fortuna/model/model_manager/base.py @@ -9,7 +9,7 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import PyTree +from optax._src.base import PyTree from jax._src.prng import PRNGKeyArray import jax.numpy as jnp diff --git a/fortuna/model/model_manager/classification.py b/fortuna/model/model_manager/classification.py index df4c85fe..dd7b2a5e 100644 --- a/fortuna/model/model_manager/classification.py +++ b/fortuna/model/model_manager/classification.py @@ -10,7 +10,7 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree +from optax._src.base import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray diff --git a/fortuna/model/model_manager/regression.py b/fortuna/model/model_manager/regression.py index 2458c6d6..2756a5f3 100644 --- a/fortuna/model/model_manager/regression.py +++ b/fortuna/model/model_manager/regression.py @@ -7,7 +7,7 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree +from optax._src.base import PyTree import jax from jax import random from jax._src.prng import PRNGKeyArray @@ -65,6 +65,7 @@ def apply( lik_log_var_rngs = None if mutable is not None: + mutable = mutable.unfreeze() mutable["model"] = mutable.get("model") mutable["lik_log_var"] = mutable.get("lik_log_var") diff --git a/fortuna/model/model_manager/transformers/classification.py b/fortuna/model/model_manager/transformers/classification.py index 6b793433..d61f93de 100644 --- a/fortuna/model/model_manager/transformers/classification.py +++ b/fortuna/model/model_manager/transformers/classification.py @@ -8,7 +8,7 @@ from flax import linen as nn from flax.core import FrozenDict -from flax.training.checkpoints import PyTree +from optax._src.base import PyTree import jax from jax import ( numpy 
as jnp, diff --git a/fortuna/output_calib_model/base.py b/fortuna/output_calib_model/base.py index d5dd4ec5..e5d88eee 100644 --- a/fortuna/output_calib_model/base.py +++ b/fortuna/output_calib_model/base.py @@ -10,9 +10,7 @@ from fortuna.output_calib_model.config.base import Config from fortuna.output_calib_model.loss import Loss -from fortuna.output_calib_model.output_calib_mixin import ( - WithOutputCalibCheckpointingMixin, -) +from fortuna.training.mixins.checkpointing import WithCheckpointingMixin from fortuna.output_calib_model.output_calib_model_calibrator import ( JittedOutputCalibModelCalibrator, MultiDeviceOutputCalibModelCalibrator, @@ -34,7 +32,7 @@ from fortuna.utils.random import RandomNumberGenerator -class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC): +class OutputCalibModel(WithCheckpointingMixin, abc.ABC): """ Abstract calibration model class. """ diff --git a/fortuna/output_calib_model/output_calib_mixin.py b/fortuna/output_calib_model/output_calib_mixin.py deleted file mode 100644 index 11938c06..00000000 --- a/fortuna/output_calib_model/output_calib_mixin.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -from typing import Optional - -from flax.training import checkpoints - -from fortuna.output_calib_model.state import OutputCalibState -from fortuna.training.mixins.checkpointing import WithCheckpointingMixin -from fortuna.typing import ( - OptaxOptimizer, - Path, -) - - -class WithOutputCalibCheckpointingMixin(WithCheckpointingMixin): - def restore_checkpoint( - self, - restore_checkpoint_dir: Path, - optimizer: Optional[OptaxOptimizer] = None, - prefix: str = "", - **kwargs, - ) -> OutputCalibState: - if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile( - restore_checkpoint_dir - ): - raise ValueError( - f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found." - ) - d = checkpoints.restore_checkpoint( - ckpt_dir=str(restore_checkpoint_dir), - target=None, - step=None, - prefix=prefix, - parallel=True, - ) - if d is None: - raise ValueError( - f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`." 
- ) - - return OutputCalibState.init_from_dict(d, optimizer, **kwargs) diff --git a/fortuna/output_calibrator/output_calib_manager/base.py b/fortuna/output_calibrator/output_calib_manager/base.py index 146473e4..c7f63228 100644 --- a/fortuna/output_calibrator/output_calib_manager/base.py +++ b/fortuna/output_calibrator/output_calib_manager/base.py @@ -6,7 +6,7 @@ from flax.core import FrozenDict import flax.linen as nn -from flax.training.checkpoints import PyTree +from optax._src.base import PyTree from jax import random from jax._src.prng import PRNGKeyArray import jax.numpy as jnp diff --git a/fortuna/partitioner/base.py b/fortuna/partitioner/base.py index 55fe1338..1bf90b46 100644 --- a/fortuna/partitioner/base.py +++ b/fortuna/partitioner/base.py @@ -19,7 +19,7 @@ def __init__( n_devices: Optional[int] = None, ): if axis_dims is None: - axis_dims = {"dp": 1, "fsdp": 1, "mp": 1} + axis_dims = {"dp": 1, "fsdp": 1, "mp": -1} if rules is None: rules = {} self.specs = { diff --git a/fortuna/prob_model/base.py b/fortuna/prob_model/base.py index 1dd5d256..61bebf41 100644 --- a/fortuna/prob_model/base.py +++ b/fortuna/prob_model/base.py @@ -14,8 +14,7 @@ from fortuna.prob_model.calib_config.base import CalibConfig from fortuna.prob_model.fit_config.base import FitConfig from fortuna.prob_model.prob_model_calibrator import ( - JittedProbModelOutputCalibrator, - MultiDeviceProbModelOutputCalibrator, + ShardedProbModelOutputCalibrator, ProbModelOutputCalibrator, ) from fortuna.typing import ( @@ -137,7 +136,7 @@ def _calibrate( "Pre-compute ensemble of outputs on the calibration data loader." ) - distribute = jax.local_devices()[0].platform != "cpu" + shard = not calib_config.processor.disable_jit ( calib_ensemble_outputs_loader, @@ -146,7 +145,7 @@ def _calibrate( inputs_loader=calib_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if calib_config.monitor.verbose: logging.info( @@ -157,19 +156,20 @@ def _calibrate( inputs_loader=val_data_loader.to_inputs_loader(), n_output_samples=calib_config.processor.n_posterior_samples, return_size=True, - distribute=distribute, + shard=shard, ) if val_data_loader is not None else (None, None) ) - trainer_cls = select_trainer_given_devices( - devices=calib_config.processor.devices, - base_trainer_cls=ProbModelOutputCalibrator, - jitted_trainer_cls=JittedProbModelOutputCalibrator, - multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator, - disable_jit=calib_config.processor.disable_jit, - ) + # trainer_cls = select_trainer_given_devices( + # devices=calib_config.processor.devices, + # base_trainer_cls=ProbModelOutputCalibrator, + # jitted_trainer_cls=JittedProbModelOutputCalibrator, + # multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator, + # disable_jit=calib_config.processor.disable_jit, + # ) + trainer_cls = ShardedProbModelOutputCalibrator calibrator = trainer_cls( calib_outputs_loader=calib_ensemble_outputs_loader, diff --git a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py index f27ef58b..7d230ce0 100755 --- a/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py +++ b/fortuna/prob_model/posterior/deep_ensemble/deep_ensemble_posterior.py @@ -149,11 +149,11 @@ def _fit(i): rng=self.rng.get(), state=_state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + 
training_data_loader=train_data_loader, training_dataset_size=train_data_size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_size, verbose=fit_config.monitor.verbose, callbacks=fit_config.callbacks, diff --git a/fortuna/prob_model/posterior/map/map_posterior.py b/fortuna/prob_model/posterior/map/map_posterior.py index 2c02e76b..08b22a53 100755 --- a/fortuna/prob_model/posterior/map/map_posterior.py +++ b/fortuna/prob_model/posterior/map/map_posterior.py @@ -17,6 +17,7 @@ JittedMAPTrainer, MAPTrainer, MultiDeviceMAPTrainer, + ShardedMAPTrainer ) from fortuna.prob_model.posterior.posterior_state_repository import ( PosteriorStateRepository, @@ -67,13 +68,14 @@ def fit( ) -> Status: super()._checks_on_fit_start(fit_config, map_fit_config) - trainer_cls = select_trainer_given_devices( - devices=fit_config.processor.devices, - base_trainer_cls=MAPTrainer, - jitted_trainer_cls=JittedMAPTrainer, - multi_device_trainer_cls=MultiDeviceMAPTrainer, - disable_jit=fit_config.processor.disable_jit, - ) + # trainer_cls = select_trainer_given_devices( + # devices=fit_config.processor.devices, + # base_trainer_cls=MAPTrainer, + # jitted_trainer_cls=JittedMAPTrainer, + # multi_device_trainer_cls=MultiDeviceMAPTrainer, + # disable_jit=fit_config.processor.disable_jit, + # ) + trainer_cls = ShardedMAPTrainer trainer = trainer_cls( predict_fn=self.joint.likelihood.prob_output_layer.predict, @@ -128,11 +130,11 @@ def init_state_fn(rng): rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/map/map_trainer.py b/fortuna/prob_model/posterior/map/map_trainer.py index c8494dbb..a9d0d686 100644 --- a/fortuna/prob_model/posterior/map/map_trainer.py +++ b/fortuna/prob_model/posterior/map/map_trainer.py @@ -19,6 +19,7 @@ from fortuna.prob_model.posterior.posterior_trainer import PosteriorTrainerABC from fortuna.training.mixins.jitted import JittedMixin from fortuna.training.mixins.multi_device import MultiDeviceMixin +from fortuna.training.mixins.sharding import ShardingMixin from fortuna.typing import ( Array, Batch, @@ -109,3 +110,7 @@ class JittedMAPTrainer(JittedMixin, MAPTrainer): class MultiDeviceMAPTrainer(MultiDeviceMixin, MAPTrainer): pass + + +class ShardedMAPTrainer(ShardingMixin, MAPTrainer): + pass diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py index 8f06738e..80e4e780 100755 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py @@ -172,11 +172,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + 
validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py index 14020b83..bca4fc25 100644 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_trainer.py @@ -94,7 +94,7 @@ def _unravel_params( def on_train_start( self, state: NormalizingFlowState, - dataloaders: List[DataLoader], + data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[NormalizingFlowState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -150,7 +150,7 @@ def on_train_start( ), ) - return state, dataloaders, rng + return state, data_loaders, rng class JittedADVITrainer(JittedMixin, ADVITrainer): diff --git a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py index d8f25150..7ad79090 100644 --- a/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/cyclical_sgld/cyclical_sgld_posterior.py @@ -176,11 +176,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py index 4641a71d..1f69dda0 100644 --- a/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py +++ b/fortuna/prob_model/posterior/sgmcmc/sghmc/sghmc_posterior.py @@ -172,11 +172,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/posterior/swag/swag_posterior.py b/fortuna/prob_model/posterior/swag/swag_posterior.py index eda4443f..19bf40a8 100755 --- a/fortuna/prob_model/posterior/swag/swag_posterior.py +++ b/fortuna/prob_model/posterior/swag/swag_posterior.py @@ -155,11 +155,11 @@ def fit( rng=self.rng.get(), state=state, loss_fun=self.joint._batched_negative_log_joint_prob, - training_dataloader=train_data_loader, + training_data_loader=train_data_loader, training_dataset_size=train_data_loader.size, n_epochs=fit_config.optimizer.n_epochs, metrics=fit_config.monitor.metrics, - validation_dataloader=val_data_loader, + validation_data_loader=val_data_loader, validation_dataset_size=val_data_loader.size if val_data_loader is not None else None, diff --git a/fortuna/prob_model/predictive/base.py b/fortuna/prob_model/predictive/base.py index c22be74b..199fa1e4 100755 --- a/fortuna/prob_model/predictive/base.py +++ b/fortuna/prob_model/predictive/base.py @@ -24,10 +24,10 @@ from fortuna.data.loader 
import ( DataLoader, - DeviceDimensionAugmentedLoader, InputsLoader, TargetsLoader, ) +from fortuna.data.loader.base import ShardedPrefetchedLoader from fortuna.prob_model.posterior.base import Posterior from fortuna.typing import ( Array, @@ -56,7 +56,7 @@ def log_prob( data_loader: DataLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> jnp.ndarray: r""" @@ -79,8 +79,8 @@ def log_prob( that would be produced using the posterior distribution state. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -95,7 +95,7 @@ def log_prob( data_loader, n_posterior_samples, rng, - distribute, + shard, **kwargs, ) @@ -240,7 +240,7 @@ def sample( n_target_samples: int = 1, return_aux: Optional[List[str]] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]]: r""" @@ -264,8 +264,8 @@ def sample( Return auxiliary objects. We currently support 'outputs'. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -280,7 +280,7 @@ def fun(_inputs): _inputs, n_target_samples, return_aux, rng, **kwargs ) - if distribute: + if shard: inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) fun = pmap(fun) if return_aux is None or len(return_aux) == 0: @@ -363,7 +363,7 @@ def sample_calibrated_outputs( inputs_loader: InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Sample parameters from the posterior distribution state and compute calibrated outputs. @@ -376,8 +376,8 @@ def sample_calibrated_outputs( Number of output samples to draw for each input. rng: Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. 
Returns ------- @@ -392,7 +392,7 @@ def sample_calibrated_outputs( inputs_loader, n_output_samples, rng, - distribute, + shard, ) def _sample_batched_calibrated_outputs( @@ -422,7 +422,7 @@ def _sample_outputs( inputs_loader: InputsLoader, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: if rng is None: rng = self.rng.get() @@ -432,7 +432,7 @@ def _sample_outputs( inputs_loader, n_output_samples, rng, - distribute, + shard, ) def _sample_batched_outputs( @@ -440,16 +440,37 @@ def _sample_batched_outputs( inputs: Array, n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, + shard: bool = True ) -> jnp.ndarray: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) + def _apply_fn(params, mutable): + outputs = self.likelihood.model_manager.apply( + params=params, inputs=inputs, mutable=mutable + ) + if mutable is not None: + return outputs[0] + return outputs + + if shard: + _apply_fn = pjit( + _apply_fn, + in_shardings=( + self.posterior.partition_manager.shardings.params, + self.posterior.partition_manager.shardings.mutable + ) + ) + else: + _apply_fn = jit(_apply_fn) + def _sample(key): sample = self.posterior.sample(inputs=inputs, rng=key) - return self.likelihood.model_manager.apply( - params=sample.params, inputs=inputs, mutable=sample.mutable - ) + if shard: + with self.posterior.partition_manager.partitioner.mesh: + return _apply_fn(sample.params, sample.mutable) + return return lax.map(_sample, keys) @@ -459,23 +480,13 @@ def _sample_outputs_loader( n_output_samples: int = 1, rng: Optional[PRNGKeyArray] = None, return_size: bool = False, - distribute: bool = True, + shard: bool = True, ) -> Union[TargetsLoader, Tuple[TargetsLoader, int]]: if rng is None: rng = self.rng.get() keys = random.split(rng, n_output_samples) - if distribute: - inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) - - def _sample(key, _inputs): - sample = self.posterior.sample(inputs=_inputs, rng=key) - return self.likelihood.model_manager.apply( - params=sample.params, inputs=_inputs, mutable=sample.mutable - ) - - _sample = pmap(_sample) if distribute else jit(_sample) - # _sample = pjit(_sample, in_shardings=(PartitionSpec(), PartitionSpec())) + inputs_loader = ShardedPrefetchedLoader(inputs_loader, self.posterior.partition_manager, shard=shard) iterable = [] size = 0 @@ -485,14 +496,18 @@ def _sample(key, _inputs): if not isinstance(inputs, dict) else inputs[list(inputs.keys())[0]].shape[0] ) - if distribute: - outputs = jnp.stack( - list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)) + outputs = jnp.stack( + list( + map( + lambda key: self._sample_batched_outputs( + inputs=inputs, + rng=key, + shard=shard + )[0], + keys + ) ) - outputs = self._unshard_ensemble_arrays(outputs) - else: - # with self.posterior.partition_manager.partitioner.mesh: - outputs = lax.map(lambda key: _sample(key, inputs), keys) + ) iterable.append(outputs) iterable = TargetsLoader.from_iterable(iterable=iterable) if return_size: @@ -504,7 +519,7 @@ def mean( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mean of the target variable, that is @@ -526,8 +541,8 @@ def mean( Number of samples to draw from the posterior distribution for each input. rng: Optional[PRNGKeyArray] A random number generator. 
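# --- Illustrative aside (not part of the patch) ---------------------------------
# The pjit pattern used above in `_sample_batched_outputs`, reduced to plain JAX.
# The real code shards `params` and `mutable` with the shardings held by the
# posterior's PartitionManager and closes over the inputs; the toy model, mesh
# layout and partition specs below are stand-in assumptions, not Fortuna's API.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1, 1), axis_names=("dp", "fsdp", "mp"))

def apply_fn(params, inputs):
    # stand-in for self.likelihood.model_manager.apply(params=..., inputs=..., mutable=...)
    return inputs @ params["w"]

# Replicate the parameters and split the batch dimension across the data axes.
sharded_apply = pjit(
    apply_fn,
    in_shardings=(PartitionSpec(), PartitionSpec(("dp", "fsdp"))),
)

params = {"w": jnp.ones((4, 3))}
inputs = jnp.ones((8, 4))
with mesh:  # as in `with self.posterior.partition_manager.partitioner.mesh:`
    outputs = sharded_apply(params, inputs)  # shape (8, 3)
# --------------------------------------------------------------------------------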
If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -538,7 +553,7 @@ def mean( rng = self.rng.get() return self._loop_fun_through_inputs_loader( - self._batched_mean, inputs_loader, n_posterior_samples, rng, distribute + self._batched_mean, inputs_loader, n_posterior_samples, rng, shard ) def _batched_mean( @@ -573,7 +588,7 @@ def mode( n_posterior_samples: int = 30, means: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive mode of the target variable, that is @@ -596,8 +611,8 @@ def mode( An estimate of the predictive mean. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -611,7 +626,7 @@ def aleatoric_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive aleatoric variance of the target variable, that is @@ -633,8 +648,8 @@ def aleatoric_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -649,7 +664,7 @@ def aleatoric_variance( inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_aleatoric_variance( @@ -682,7 +697,7 @@ def epistemic_variance( inputs_loader: InputsLoader, n_posterior_samples: int = 30, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive epistemic variance of the one-hot encoded target variable, that is @@ -704,8 +719,8 @@ def epistemic_variance( Number of samples to draw from the posterior distribution for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -720,7 +735,7 @@ def epistemic_variance( inputs_loader, n_posterior_samples, rng, - distribute, + shard, ) def _batched_epistemic_variance( @@ -762,7 +777,7 @@ def variance( aleatoric_variances: Optional[jnp.ndarray] = None, epistemic_variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive variance of the target variable, that is @@ -789,8 +804,8 @@ def variance( An estimate of the epistemic predictive variance for each input. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. 
+ shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -805,7 +820,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) if epistemic_variances is None: rng, key = random.split(rng) @@ -813,7 +828,7 @@ def variance( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=key, - distribute=distribute, + shard=shard, ) return aleatoric_variances + epistemic_variances @@ -823,7 +838,7 @@ def std( n_posterior_samples: int = 30, variances: Optional[jnp.ndarray] = None, rng: Optional[PRNGKeyArray] = None, - distribute: bool = True, + shard: bool = True, ) -> jnp.ndarray: r""" Estimate the predictive standard deviation of the target variable, that is @@ -846,8 +861,8 @@ def std( An estimate of the predictive variance. rng : Optional[PRNGKeyArray] A random number generator. If not passed, this will be taken from the attributes of this class. - distribute: bool - Whether to distribute computation over multiple devices, if available. + shard: bool + Whether to shard computation over multiple devices, if available. Returns ------- @@ -859,13 +874,13 @@ def std( inputs_loader=inputs_loader, n_posterior_samples=n_posterior_samples, rng=rng, - distribute=distribute, + shard=shard, ) return jnp.sqrt(variances) @staticmethod def _unshard_ensemble_arrays(arr: Array) -> Array: - arr = arr.swapaxes(1, 2) + arr = jnp.moveaxis(arr, 0, 2) arr = arr.reshape((arr.shape[0] * arr.shape[1],) + arr.shape[2:]) return arr.swapaxes(0, 1) @@ -875,13 +890,13 @@ def _loop_fun_through_inputs_loader( inputs_loader: InputsLoader, n_posterior_samples: int, rng: PRNGKeyArray, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> Array: def fun2(_inputs): return fun(_inputs, n_posterior_samples, rng, **kwargs) - if distribute: + if shard: inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) fun2 = pmap(fun2) return jnp.concatenate( @@ -900,13 +915,13 @@ def _loop_fun_through_data_loader( data_loader: DataLoader, n_posterior_samples: int, rng: PRNGKeyArray, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> Array: def fun2(_batch): return fun(_batch, n_posterior_samples, rng, **kwargs) - if distribute: + if shard: data_loader = DeviceDimensionAugmentedLoader(data_loader) fun2 = pmap(fun2) return jnp.concatenate( @@ -922,13 +937,13 @@ def _loop_ensemble_fun_through_inputs_loader( inputs_loader: InputsLoader, n_posterior_samples: int, rng: PRNGKeyArray, - distribute: bool = True, + shard: bool = True, **kwargs, ) -> Array: def fun2(_inputs): return fun(_inputs, n_posterior_samples, rng, **kwargs) - if distribute: + if shard: inputs_loader = DeviceDimensionAugmentedLoader(inputs_loader) fun2 = pmap(fun2) return jnp.concatenate( diff --git a/fortuna/prob_model/prob_model_calibrator.py b/fortuna/prob_model/prob_model_calibrator.py index 24cc1692..6201b2b2 100644 --- a/fortuna/prob_model/prob_model_calibrator.py +++ b/fortuna/prob_model/prob_model_calibrator.py @@ -15,8 +15,7 @@ from fortuna.data import TargetsLoader from fortuna.output_calib_model.state import OutputCalibState from fortuna.training.output_calibrator import ( - JittedMixin, - MultiDeviceMixin, + ShardingMixin, OutputCalibratorABC, ) from fortuna.typing import ( @@ -84,49 +83,5 @@ def __str__(self): return "calibration" -class ProbModelMultiDeviceMixin(MultiDeviceMixin): - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def 
_reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[1] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[1]} was found. " - f"Please set an appropriate batch size." - ) - shape = batch.shape - return ( - batch.swapaxes(0, 1) - .reshape(n_devices, shape[1] // n_devices, shape[0], shape[2]) - .swapaxes(1, 2) - ) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - -class JittedProbModelOutputCalibrator(JittedMixin, ProbModelOutputCalibrator): - pass - - -class MultiDeviceProbModelOutputCalibrator( - ProbModelMultiDeviceMixin, ProbModelOutputCalibrator -): +class ShardedProbModelOutputCalibrator(ShardingMixin, ProbModelOutputCalibrator): pass diff --git a/fortuna/prob_model/regression.py b/fortuna/prob_model/regression.py index e121ea4c..6c8b26fe 100755 --- a/fortuna/prob_model/regression.py +++ b/fortuna/prob_model/regression.py @@ -150,6 +150,8 @@ def _check_output_dim(self, data_loader: DataLoader): outputs = self.model_manager.apply( params=s.params, inputs=np.zeros((1,) + input_shape), mutable=s.mutable ) + if s.mutable is not None: + outputs, _ = outputs if outputs.shape[1] != 2 * output_dim: raise ValueError( f"""The outputs dimension of both `model` and `likelihood_log_variance_model` must be the same as diff --git a/fortuna/training/mixins/checkpointing.py b/fortuna/training/mixins/checkpointing.py index 75e261f2..9d65a7e6 100644 --- a/fortuna/training/mixins/checkpointing.py +++ b/fortuna/training/mixins/checkpointing.py @@ -64,7 +64,7 @@ def save_checkpoint( if save_checkpoint_dir is not None else self.checkpoint_manager ) - if self.checkpoint_manager: + if checkpoint_manager is not None: save_args = save_args_from_target(state) def save_ckpt_fn(_state): @@ -96,15 +96,12 @@ def restore_checkpoint( else: ref = self._get_ref_without_shardings() - restored = pure_callback( - lambda: self.checkpoint_manager.restore( + restored = self.checkpoint_manager.restore( self.checkpoint_manager.latest_step(), items=ref, restore_kwargs={"restore_args": ref}, directory=restore_checkpoint_dir, - ), - ref, - ) + ) if optimizer is not None: restored = restored.replace( @@ -129,7 +126,10 @@ def get_shapes_dtypes_checkpoint( def _get_ref_from_shardings(self): return tree_map_with_path( lambda p, sharding, shape_dtype: ArrayRestoreArgsWithShape( - sharding=sharding, dtype=shape_dtype.dtype, shape=shape_dtype.shape + mesh=self.partition_manager.partitioner.mesh, + sharding=sharding, + dtype=shape_dtype.dtype, + shape=shape_dtype.shape ), self.partition_manager.shardings, self.partition_manager.shapes_dtypes, diff --git a/fortuna/training/mixins/multi_device.py b/fortuna/training/mixins/multi_device.py index ec4ad367..fae73e03 100644 --- a/fortuna/training/mixins/multi_device.py +++ b/fortuna/training/mixins/multi_device.py @@ -37,7 +37,7 @@ def __init__(self, *args, **kwargs): self.multi_device = True @staticmethod - def _add_device_dim_to_input_dataloader(dataloader: DataLoader) -> DataLoader: + def 
_add_device_dim_to_input_data_loader(data_loader: DataLoader) -> DataLoader: def _reshape_input_batch(batch): n_devices = jax.local_device_count() if batch.shape[0] % n_devices != 0: @@ -50,17 +50,17 @@ def _reshape_input_batch(batch): return batch.reshape((n_devices, -1) + single_input_shape) class DataLoaderWrapper: - def __init__(self, dataloader): - self.dataloader = dataloader + def __init__(self, data_loader): + self.data_loader = data_loader def __iter__(self): - dataloader = map( - lambda batch: tree_map(_reshape_input_batch, batch), self.dataloader + data_loader = map( + lambda batch: tree_map(_reshape_input_batch, batch), self.data_loader ) - dataloader = jax_utils.prefetch_to_device(dataloader, 2) - yield from dataloader + data_loader = jax_utils.prefetch_to_device(data_loader, 2) + yield from data_loader - return DataLoaderWrapper(dataloader) if dataloader is not None else dataloader + return DataLoaderWrapper(data_loader) if data_loader is not None else data_loader @staticmethod def _sync_mutable(state: TrainState) -> TrainState: @@ -80,17 +80,17 @@ def _sync_state(self, state: TrainState) -> TrainState: return jax.device_get(tree_map(lambda x: x[0], state)) def on_train_start( - self, state: TrainState, dataloaders: List[DataLoader], rng: PRNGKeyArray + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: - state, dataloaders, rng = super(MultiDeviceMixin, self).on_train_start( - state, dataloaders, rng + state, data_loaders, rng = super(MultiDeviceMixin, self).on_train_start( + state, data_loaders, rng ) state = jax_utils.replicate(state) - dataloaders = [ - self._add_device_dim_to_input_dataloader(dl) for dl in dataloaders + data_loaders = [ + self._add_device_dim_to_input_data_loader(dl) for dl in data_loaders ] model_key = random.split(rng, jax.local_device_count()) - return state, dataloaders, model_key + return state, data_loaders, model_key def on_train_end(self, state: TrainState) -> TrainState: state = super(MultiDeviceMixin, self).on_train_end(state) diff --git a/fortuna/training/mixins/sharded.py b/fortuna/training/mixins/sharded.py deleted file mode 100644 index 42511ce8..00000000 --- a/fortuna/training/mixins/sharded.py +++ /dev/null @@ -1,153 +0,0 @@ -from functools import partial -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, -) - -from flax import jax_utils -from flax.core import FrozenDict -import jax -from jax import ( - lax, - random, -) -from jax._src.prng import PRNGKeyArray -import jax.numpy as jnp -from jax.tree_util import tree_map -from optax._src.base import PyTree - -from fortuna.data.loader import DataLoader -from fortuna.training.callback import Callback -from fortuna.training.train_state import TrainState -from fortuna.typing import ( - Array, - Batch, -) - - -class MeshedMixin: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.meshed = True - - @staticmethod - def _add_device_dim_to_input_dataloader(dataloader: DataLoader) -> DataLoader: - def _reshape_input_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. Please set an appropriate batch size in the data loader." 
- ) - single_input_shape = batch.shape[1:] - # reshape to (local_devices, device_batch_size, *single_input_shape) - return batch.reshape((n_devices, -1) + single_input_shape) - - class DataLoaderWrapper: - def __init__(self, dataloader): - self.dataloader = dataloader - - def __iter__(self): - dataloader = map( - lambda batch: tree_map(_reshape_input_batch, batch), self.dataloader - ) - dataloader = jax_utils.prefetch_to_device(dataloader, 2) - yield from dataloader - - return DataLoaderWrapper(dataloader) if dataloader is not None else dataloader - - @staticmethod - def _sync_mutable(state: TrainState) -> TrainState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable is not None - else state - ) - - @staticmethod - def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: - arr = lax.pmean(arr, axis_name="batch") - return arr - - def _sync_state(self, state: TrainState) -> TrainState: - state = self._sync_mutable(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def on_train_start( - self, state: TrainState, dataloaders: List[DataLoader], rng: PRNGKeyArray - ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: - state, dataloaders, rng = super(MultiDeviceMixin, self).on_train_start( - state, dataloaders, rng - ) - state = jax_utils.replicate(state) - dataloaders = [ - self._add_device_dim_to_input_dataloader(dl) for dl in dataloaders - ] - model_key = random.split(rng, jax.local_device_count()) - return state, dataloaders, model_key - - def on_train_end(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - def training_step_start(self, rng: PRNGKeyArray, step: int) -> PRNGKeyArray: - step = step if isinstance(step, int) or step.ndim == 0 else step[0] - return jax.vmap(lambda r: random.fold_in(r, step))(rng) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7)) - def training_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, Any]]: - return super().training_step( - state, batch, loss_fun, rng, n_data, unravel, kwargs - ) - - def training_step_end( - self, - current_epoch: int, - state: TrainState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - callbacks: Optional[List[Callback]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Tuple[TrainState, Dict[str, jnp.ndarray]]: - state, training_losses_and_metrics = super( - MultiDeviceMixin, self - ).training_step_end( - current_epoch, state, aux, batch, metrics, callbacks, kwargs - ) - return state, tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_validation_start(self, state: TrainState) -> TrainState: - state = super(MultiDeviceMixin, self).on_validation_start(state) - state = self._sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 3, 5, 6, 7, 8)) - def validation_step( - self, - state: TrainState, - batch: Batch, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - unravel: Optional[Callable[[any], PyTree]] = None, - kwargs: FrozenDict[str, Any] = FrozenDict(), - ) -> Dict[str, jnp.ndarray]: - 
validation_losses_and_metrics = super().validation_step( - state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs - ) - return lax.pmean(validation_losses_and_metrics, axis_name="batch") diff --git a/fortuna/training/mixins/sharding.py b/fortuna/training/mixins/sharding.py new file mode 100644 index 00000000..c92a6d4a --- /dev/null +++ b/fortuna/training/mixins/sharding.py @@ -0,0 +1,109 @@ +from functools import partial +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, Union, +) + +from flax import jax_utils +from flax.core import FrozenDict +import jax +from jax import ( + lax, + random, + Array, + device_put +) +from jax.sharding import NamedSharding +import numpy as np +from jax.experimental.pjit import pjit +from jax._src.prng import PRNGKeyArray +import jax.numpy as jnp +from jax.tree_util import tree_map +from numpy import ndarray +from optax._src.base import PyTree + +from fortuna.data.loader import DataLoader +from fortuna.data.loader.base import ShardedPrefetchedLoader +from fortuna.training.callback import Callback +from fortuna.training.train_state import TrainState +from fortuna.typing import ( + Array, + Batch, +) +from jax.sharding import PartitionSpec +from fortuna.partitioner.partition_manager.base import PartitionManager + + +class ShardingMixin: + def __init__(self, *, partition_manager: PartitionManager, **kwargs): + super().__init__(partition_manager=partition_manager, **kwargs) + self.partition_manager = partition_manager + + def training_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Tuple[TrainState, Dict[str, Any]]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().training_step, + static_argnums=(2, 4, 5, 6), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + out_shardings=( + self.partition_manager.shardings, + PartitionSpec(), + ), + )( + state, batch, loss_fun, rng, n_data, unravel, kwargs + ) + + def validation_step( + self, + state: TrainState, + batch: Batch, + loss_fun: Callable, + rng: PRNGKeyArray, + n_data: int, + metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, + unravel: Optional[Callable[[any], PyTree]] = None, + kwargs: FrozenDict[str, Any] = FrozenDict(), + ) -> Dict[str, jnp.ndarray]: + with self.partition_manager.partitioner.mesh: + return pjit( + super().validation_step, + static_argnums=(2, 4, 5, 6, 7), + in_shardings=( + self.partition_manager.shardings, + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + )( + state, batch, loss_fun, rng, n_data, metrics, unravel, kwargs + ) + + def on_train_start( + self, state: TrainState, data_loaders: List[DataLoader], rng: PRNGKeyArray + ) -> Tuple[TrainState, List[ShardedPrefetchedLoader], PRNGKeyArray]: + state, data_loaders, rng = super(ShardingMixin, self).on_train_start( + state, data_loaders, rng + ) + data_loaders = [ + ShardedPrefetchedLoader( + loader=dl, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp"))) for dl in data_loaders + ] + return state, data_loaders, rng diff --git a/fortuna/training/output_calibrator.py b/fortuna/training/output_calibrator.py index 8a598528..f3f00d31 100644 --- a/fortuna/training/output_calibrator.py +++ b/fortuna/training/output_calibrator.py @@ -12,7 +12,8 @@ Union, ) -from flax import jax_utils +from 
jax.sharding import PartitionSpec +from jax.experimental.pjit import pjit from flax.training.common_utils import stack_forest import jax from jax import ( @@ -20,6 +21,7 @@ random, value_and_grad, ) +from fortuna.data.loader.base import ShardedPrefetchedLoader from jax._src.prng import PRNGKeyArray import jax.numpy as jnp from jax.tree_util import tree_map @@ -80,7 +82,6 @@ def __init__( self.keep_top_n_checkpoints = keep_top_n_checkpoints self.disable_training_metrics_computation = disable_training_metrics_computation self.eval_every_n_epochs = eval_every_n_epochs - self.multi_device = False def train( self, @@ -98,7 +99,7 @@ def train( verbose: bool = True, ) -> Tuple[OutputCalibState, Status]: training_losses_and_metrics = collections.defaultdict(list) - val_losses_and_metrics = collections.defaultdict(list) + validation_losses_and_metrics = collections.defaultdict(list) state, data_loaders, outputs_loaders, rng = self.on_train_start( state, @@ -137,11 +138,11 @@ def train( # validation loop if self.should_perform_validation(val_data_loader, epoch): # performance evaluation on the whole validation dataset - state = self.on_val_start(state) + state = self.on_validation_start(state) ( - val_losses_and_metrics_current_epoch, - val_epoch_metrics_str, - ) = self._val_loop( + validation_losses_and_metrics_current_epoch, + validation_epoch_metrics_str, + ) = self._validation_loop( loss_fun=loss_fun, metrics=metrics, rng=rng, @@ -152,11 +153,11 @@ def train( verbose=verbose, ) if verbose: - logging.info(f"Epoch: {epoch + 1} | " + val_epoch_metrics_str) + logging.info(f"Epoch: {epoch + 1} | " + validation_epoch_metrics_str) # keep track of training losses and metrics [granularity=epoch] and check for early stopping - for k in val_losses_and_metrics_current_epoch.keys(): - val_losses_and_metrics[k].append( - val_losses_and_metrics_current_epoch[k] + for k in validation_losses_and_metrics_current_epoch.keys(): + validation_losses_and_metrics[k].append( + validation_losses_and_metrics_current_epoch[k] ) # check for early stopping if self.is_early_stopping_active and self._early_stopping.should_stop: @@ -167,8 +168,8 @@ def train( training_status = { k: jnp.array(v) for k, v in training_losses_and_metrics.items() } - val_status = {k: jnp.array(v) for k, v in val_losses_and_metrics.items()} - status = dict(**training_status, **val_status) + validation_status = {k: jnp.array(v) for k, v in validation_losses_and_metrics.items()} + status = dict(**training_status, **validation_status) state = self.on_train_end(state) return state, status @@ -304,27 +305,14 @@ def training_step_end( if not self.disable_training_metrics_computation and metrics is not None: preds = self.predict_fn(aux["outputs"]) uncertainties = self.uncertainty_fn(aux["outputs"]) - if self.multi_device: - training_batch_metrics = self.compute_metrics( - preds.reshape((preds.shape[0] * preds.shape[1],) + preds.shape[2:]), - uncertainties.reshape( - (uncertainties.shape[0] * uncertainties.shape[1],) - + uncertainties.shape[2:] - ), - batch[1].reshape( - (batch[1].shape[0] * batch[1].shape[1],) + batch[1].shape[2:] - ), - metrics, - ) - else: - training_batch_metrics = self.compute_metrics( - preds, uncertainties, batch[1], metrics - ) + training_batch_metrics = self.compute_metrics( + preds, uncertainties, batch[1], metrics + ) for k, v in training_batch_metrics.items(): training_losses_and_metrics[k] = v return training_losses_and_metrics - def _val_loop( + def _validation_loop( self, loss_fun: Callable, metrics: Optional[ @@ -337,10 
+325,10 @@ def _val_loop( val_dataset_size: int, verbose: bool = True, ) -> Tuple[Dict[str, float], str]: - val_losses_and_metrics_epoch_all_steps = [] - val_epoch_metrics_str = "" + validation_losses_and_metrics_epoch_all_steps = [] + validation_epoch_metrics_str = "" for batch, outputs in zip(val_data_loader, val_outputs_loader): - val_losses_and_metrics_current_batch = self.val_step( + validation_losses_and_metrics_current_batch = self.validation_step( state, batch, outputs, @@ -349,24 +337,24 @@ def _val_loop( val_dataset_size, metrics, ) - val_losses_and_metrics_epoch_all_steps.append( - val_losses_and_metrics_current_batch + validation_losses_and_metrics_epoch_all_steps.append( + validation_losses_and_metrics_current_batch ) # compute validation losses and metrics for the current epoch - val_losses_and_metrics_current_epoch = self.val_epoch_end( - val_losses_and_metrics_epoch_all_steps, state + validation_losses_and_metrics_current_epoch = self.validation_epoch_end( + validation_losses_and_metrics_epoch_all_steps, state ) # logging if verbose: - val_epoch_metrics_str = " | ".join( + validation_epoch_metrics_str = " | ".join( [ f"{m}: {round(float(v), 5)}" - for m, v in val_losses_and_metrics_current_epoch.items() + for m, v in validation_losses_and_metrics_current_epoch.items() ] ) - return val_losses_and_metrics_current_epoch, val_epoch_metrics_str + return validation_losses_and_metrics_current_epoch, validation_epoch_metrics_str - def val_step( + def validation_step( self, state: OutputCalibState, batch: Batch, @@ -378,12 +366,12 @@ def val_step( Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] ] = None, ) -> Dict[str, jnp.ndarray]: - val_loss, aux = self.val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - val_metrics = self.val_metrics_step(aux, batch, metrics) - return {"val_loss": val_loss, **val_metrics} + validation_loss, aux = self.validation_loss_step(state, batch, outputs, loss_fun, rng, n_data) + validation_metrics = self.validation_metrics_step(aux, batch, metrics) + return {"validation_loss": validation_loss, **validation_metrics} @abc.abstractmethod - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -394,7 +382,7 @@ def val_loss_step( ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]: pass - def val_metrics_step( + def validation_metrics_step( self, aux: Dict[str, jnp.ndarray], batch: Batch, @@ -403,13 +391,13 @@ def val_metrics_step( ] = None, ) -> Dict[str, jnp.ndarray]: if metrics is not None: - val_metrics = self.compute_metrics( + validation_metrics = self.compute_metrics( self.predict_fn(aux["outputs"]), self.uncertainty_fn(aux["outputs"]), batch[1], metrics, ) - return {f"val_{m}": v for m, v in val_metrics.items()} + return {f"validation_{m}": v for m, v in validation_metrics.items()} else: return {} @@ -420,19 +408,19 @@ def training_epoch_end( training_losses_and_metrics_current_epoch ) - def val_epoch_end( + def validation_epoch_end( self, - val_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], + validation_losses_and_metrics_current_epoch: List[Dict[str, jnp.ndarray]], state: OutputCalibState, ) -> Dict[str, float]: - val_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( - val_losses_and_metrics_current_epoch + validation_losses_and_metrics_current_epoch = self._get_mean_losses_and_metrics( + validation_losses_and_metrics_current_epoch ) # early stopping - improved = self.early_stopping_update(val_losses_and_metrics_current_epoch) + improved = 
self.early_stopping_update(validation_losses_and_metrics_current_epoch) if improved and self.save_checkpoint_dir is not None: self.save_checkpoint(state, self.save_checkpoint_dir, force_save=True) - return val_losses_and_metrics_current_epoch + return validation_losses_and_metrics_current_epoch def _get_mean_losses_and_metrics( self, losses_and_metrics: List[Dict[str, jnp.ndarray]] @@ -474,7 +462,7 @@ def on_train_end(self, state: OutputCalibState) -> OutputCalibState: ) return state - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: + def on_validation_start(self, state: OutputCalibState) -> OutputCalibState: return state def compute_metrics( @@ -492,8 +480,7 @@ def compute_metrics( return metrics_vals -class JittedMixin: - @partial(jax.jit, static_argnums=(0, 4, 6)) +class ShardingMixin: def training_step( self, state: OutputCalibState, @@ -503,10 +490,21 @@ def training_step( rng: PRNGKeyArray, n_data: int, ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, outputs, loss_fun, rng, n_data) + with self.partition_manager.partitioner.mesh: + return pjit( + super().training_step, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.output_calib_shardings, + PartitionSpec("dp"), + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + )( + state, batch, outputs, loss_fun, rng, n_data + ) - @partial(jax.jit, static_argnums=(0, 4, 6)) - def val_loss_step( + def validation_loss_step( self, state: OutputCalibState, batch: Batch, @@ -515,103 +513,19 @@ def val_loss_step( rng: PRNGKeyArray, n_data: int, ) -> Dict[str, jnp.ndarray]: - return super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - - -class MultiDeviceMixin: - all_reduce_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.multi_device = True - - @staticmethod - def _add_device_dim_to_data_loader(data_loader: DataLoader) -> DataLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." - ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class DataLoaderWrapper: - def __init__(self, data_loader: DataLoader): - self._data_loader = data_loader - - def __iter__(self): - data_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._data_loader - ) - data_loader = jax_utils.prefetch_to_device(data_loader, 2) - yield from data_loader - - return ( - DataLoaderWrapper(data_loader) if data_loader is not None else data_loader - ) - - @staticmethod - def _add_device_dim_to_outputs_loader( - outputs_loader: TargetsLoader, - ) -> TargetsLoader: - def _reshape_batch(batch): - n_devices = jax.local_device_count() - if batch.shape[0] % n_devices != 0: - raise ValueError( - f"The size of all output batches must be a multiple of {n_devices}, that is the number of " - f"available devices. However, a batch of outputs with shape {batch.shape[0]} was found. " - f"Please set an appropriate batch size." 
- ) - return batch.reshape((n_devices, -1) + batch.shape[1:]) - - class TargetsLoaderWrapper: - def __init__(self, outputs_loader: TargetsLoader): - self._outputs_loader = outputs_loader - - def __iter__(self): - outputs_loader = map( - lambda batch: tree_map(_reshape_batch, batch), self._outputs_loader - ) - outputs_loader = jax_utils.prefetch_to_device(outputs_loader, 2) - yield from outputs_loader - - return ( - TargetsLoaderWrapper(outputs_loader) - if outputs_loader is not None - else outputs_loader - ) - - @staticmethod - def sync_mutable(state: OutputCalibState) -> OutputCalibState: - return ( - state.replace(mutable=MultiDeviceMixin.all_reduce_mean(state.mutable)) - if state.mutable["output_calibrator"] is not None - else state - ) - - @staticmethod - def sync_gradients_and_loss( - grads: jnp.ndarray, loss: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - grad = lax.pmean(grads, axis_name="batch") - loss = lax.pmean(loss, axis_name="batch") - return grad, loss - - def save_checkpoint( - self, - state: OutputCalibState, - save_checkpoint_dir: Path, - keep: int = 1, - force_save: bool = False, - prefix: str = "", - ) -> None: - state = self.sync_mutable(state) - state = jax.device_get(tree_map(lambda x: x[0], state)) - return super(MultiDeviceMixin, self).save_checkpoint( - state, save_checkpoint_dir, keep, force_save, prefix - ) + with self.partition_manager.partitioner.mesh: + return pjit( + super().validation_loss_step, + static_argnums=(3, 5), + in_shardings=( + self.partition_manager.output_calib_shardings, + PartitionSpec("dp"), + PartitionSpec(("dp", "fsdp")), + PartitionSpec(), + ), + )( + state, batch, outputs, loss_fun, rng, n_data + ) def on_train_start( self, @@ -619,85 +533,22 @@ def on_train_start( data_loaders: List[DataLoader], outputs_loaders: List[TargetsLoader], rng: PRNGKeyArray, - ) -> Tuple[OutputCalibState, List[DataLoader], List[TargetsLoader], PRNGKeyArray]: - state, data_loaders, outputs_loaders, rng = super( - MultiDeviceMixin, self - ).on_train_start(state, data_loaders, outputs_loaders, rng) - state = jax_utils.replicate(state) + ) -> Tuple[OutputCalibState, List[ShardedPrefetchedLoader], List[ShardedPrefetchedLoader], PRNGKeyArray]: + state, data_loaders, output_loaders, rng = super(ShardingMixin, self).on_train_start( + state, data_loaders, outputs_loaders, rng + ) data_loaders = [ - self._add_device_dim_to_data_loader(dl) if dl is not None else dl - for dl in data_loaders + ShardedPrefetchedLoader( + loader=data_loader, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec(("dp", "fsdp")) + ) for data_loader in data_loaders ] outputs_loaders = [ - self._add_device_dim_to_outputs_loader(ol) if ol is not None else ol - for ol in outputs_loaders + ShardedPrefetchedLoader( + loader=outputs_loader, + partition_manager=self.partition_manager, + partition_spec=PartitionSpec() + ) for outputs_loader in outputs_loaders ] - model_key = random.split(rng, jax.local_device_count()) - return state, data_loaders, outputs_loaders, model_key - - def on_train_end(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_train_end(state) - return jax.device_get(tree_map(lambda x: x[0], state)) - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def training_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Tuple[OutputCalibState, Dict[str, Any]]: - return super().training_step(state, batch, 
outputs, loss_fun, rng, n_data) - - def training_step_end( - self, - current_epoch: int, - state: OutputCalibState, - aux: Dict[str, Any], - batch: Batch, - metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]], - ) -> Dict[str, jnp.ndarray]: - training_losses_and_metrics = super(MultiDeviceMixin, self).training_step_end( - current_epoch, state, aux, batch, metrics - ) - return tree_map(lambda x: x.mean(), training_losses_and_metrics) - - def on_val_start(self, state: OutputCalibState) -> OutputCalibState: - state = super(MultiDeviceMixin, self).on_val_start(state) - if state.mutable["output_calibrator"] is not None: - state = self.sync_mutable(state) - return state - - @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0, 4, 6)) - def val_loss_step( - self, - state: OutputCalibState, - batch: Batch, - outputs: Array, - loss_fun: Callable, - rng: PRNGKeyArray, - n_data: int, - ) -> Dict[str, jnp.ndarray]: - val_losses = super().val_loss_step(state, batch, outputs, loss_fun, rng, n_data) - return lax.pmean(val_losses, axis_name="batch") - - def val_metrics_step( - self, - aux: Dict[str, jnp.ndarray], - batch: Batch, - metrics: Optional[ - Tuple[Callable[[jnp.ndarray, jnp.ndarray, Array], Array], ...] - ] = None, - ) -> Dict[str, jnp.ndarray]: - outputs = aux["outputs"] - outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], -1) - targets = batch[1].reshape(batch[1].shape[0] * batch[1].shape[1], -1) - if metrics is not None: - val_metrics = self.compute_metrics( - self.predict_fn(outputs), self.uncertainty_fn(outputs), targets, metrics - ) - return {f"val_{m}": v for m, v in val_metrics.items()} - else: - return {} + return state, data_loaders, outputs_loaders, rng diff --git a/fortuna/training/trainer.py b/fortuna/training/trainer.py index 7010edf5..8b79949a 100755 --- a/fortuna/training/trainer.py +++ b/fortuna/training/trainer.py @@ -13,14 +13,13 @@ Union, ) -from flax import jax_utils from flax.core import FrozenDict from flax.training.common_utils import stack_forest import jax from jax import ( - lax, random, value_and_grad, + vmap ) from jax._src.prng import PRNGKeyArray import jax.numpy as jnp @@ -370,11 +369,11 @@ def train( rng: PRNGKeyArray, state: TrainState, loss_fun: Callable, - training_dataloader: DataLoader, + training_data_loader: DataLoader, training_dataset_size: int, n_epochs: int = 1, metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], float], ...]] = None, - validation_dataloader: Optional[DataLoader] = None, + validation_data_loader: Optional[DataLoader] = None, validation_dataset_size: Optional[int] = None, verbose: bool = True, unravel: Optional[Callable[[any], PyTree]] = None, @@ -382,18 +381,18 @@ def train( **kwargs, ) -> Tuple[TrainState, Status]: training_kwargs = FrozenDict(kwargs) - if validation_dataloader: + if validation_data_loader: assert ( validation_dataset_size is not None - ), "`validation_dataset_size` is required when `validation_dataloader` is provided." + ), "`validation_dataset_size` is required when `validation_data_loader` is provided." 
training_losses_and_metrics = collections.defaultdict(list) validation_losses_and_metrics = collections.defaultdict(list) - state, dataloaders, rng = self.on_train_start( - state, [training_dataloader, validation_dataloader], rng + state, data_loaders, rng = self.on_train_start( + state, [training_data_loader, validation_data_loader], rng ) - training_dataloader, validation_dataloader = dataloaders + training_data_loader, validation_data_loader = data_loaders progress_bar = trange(n_epochs, desc="Epoch") for epoch in progress_bar: @@ -408,7 +407,7 @@ def train( metrics, rng, state, - training_dataloader, + training_data_loader, training_dataset_size, training_kwargs, verbose, @@ -423,7 +422,7 @@ def train( ) # validation loop - if self.should_perform_validation(validation_dataloader, epoch): + if self.should_perform_validation(validation_data_loader, epoch): # performance evaluation on the whole validation dataset state = self.on_validation_start(state) ( @@ -435,7 +434,7 @@ def train( rng=rng, state=state, training_kwargs=training_kwargs, - validation_dataloader=validation_dataloader, + validation_data_loader=validation_data_loader, validation_dataset_size=validation_dataset_size, verbose=verbose, unravel=unravel, @@ -473,7 +472,7 @@ def _training_loop( metrics: Optional[Tuple[Callable[[jnp.ndarray, Array], jnp.ndarray], ...]], rng: PRNGKeyArray, state: TrainState, - training_dataloader: DataLoader, + training_data_loader: DataLoader, training_dataset_size: int, training_kwargs: FrozenDict[str, Any], verbose: bool, @@ -489,7 +488,7 @@ def _training_loop( state = self.training_epoch_start(state, callbacks) # ensure to use a different key at each step model_key = self.training_step_start(rng, state.step) - for step, batch in enumerate(training_dataloader): + for step, batch in enumerate(training_data_loader): # forward and backward pass state, aux = self.training_step( state, @@ -552,14 +551,14 @@ def _validation_loop( rng: PRNGKeyArray, state: TrainState, training_kwargs: FrozenDict[str, Any], - validation_dataloader: DataLoader, + validation_data_loader: DataLoader, validation_dataset_size: int, verbose: bool = True, unravel: Optional[Callable[[any], PyTree]] = None, ) -> Tuple[Dict[str, float], str]: validation_losses_and_metrics_epoch_all_steps = [] validation_epoch_metrics_str = "" - for batch in validation_dataloader: + for batch in validation_data_loader: validation_losses_and_metrics_current_batch = self.validation_step( state, batch, @@ -603,10 +602,10 @@ def _get_mean_losses_and_metrics( return losses_and_metrics def should_perform_validation( - self, validation_dataloader: Optional[DataLoader], epoch: int + self, validation_data_loader: Optional[DataLoader], epoch: int ) -> bool: return ( - validation_dataloader is not None + validation_data_loader is not None and self.eval_every_n_epochs > 0 and epoch % self.eval_every_n_epochs == 0 ) @@ -618,7 +617,7 @@ def _sync_array(arr: jnp.ndarray) -> jnp.ndarray: def on_train_start( self, state: TrainState, - dataloaders: List[DataLoader], + data_loaders: List[DataLoader], rng: PRNGKeyArray, ) -> Tuple[TrainState, List[DataLoader], PRNGKeyArray]: if self.freeze_fun is not None: @@ -652,7 +651,7 @@ def on_train_start( ) ), ) - return state, dataloaders, rng + return state, data_loaders, rng def on_train_end(self, state: TrainState) -> TrainState: self.save_checkpoint( @@ -693,7 +692,7 @@ def compute_metrics( def training_step_start( self, rng: PRNGKeyArray, step: Union[int, jax.Array] ) -> PRNGKeyArray: - step = step if isinstance(step, 
int) or step.ndim == 0 else step[0] + # step = step if isinstance(step, int) or step.ndim == 0 else step[0] return random.fold_in(rng, step) def _sync_state(self, state: TrainState) -> TrainState: diff --git a/fortuna/utils/mesh.py b/fortuna/utils/mesh.py index 905e6697..af502e57 100644 --- a/fortuna/utils/mesh.py +++ b/fortuna/utils/mesh.py @@ -1,6 +1,6 @@ from typing import Dict -from jax import device_count +from jax import local_device_count from jax.experimental.mesh_utils import create_device_mesh from jax.interpreters import pxla from jax.lax import with_sharding_constraint @@ -28,7 +28,7 @@ def get_mesh(axis_dims: Dict[str, int]): if len(np.where(np.array(dims) == -1)[0]) > 1: raise ValueError("At most one axis dimension can be `-1`.") - n_devices = device_count() + n_devices = local_device_count() fixed_prod = np.prod([v for v in dims if v != -1]) reminder = n_devices % fixed_prod diff --git a/fortuna/utils/prefetch.py b/fortuna/utils/prefetch.py new file mode 100644 index 00000000..983a09cf --- /dev/null +++ b/fortuna/utils/prefetch.py @@ -0,0 +1,20 @@ +import collections +import itertools +import jax +from jax.sharding import NamedSharding, Mesh + + +def prefetch_to_mesh(iterator, size: int, mesh: Mesh, xs_spec): + queue = collections.deque() + + def _prefetch(xs): + return jax.device_put(xs, NamedSharding(mesh, xs_spec)) + + def enqueue(n): # Enqueues *up to* `n` elements from the iterator. + for data in itertools.islice(iterator, n): + queue.append(jax.tree_util.tree_map(_prefetch, data)) + + enqueue(size) # Fill up the buffer. + while queue: + yield queue.popleft() + enqueue(1) diff --git a/poetry.lock b/poetry.lock index 3b60dc74..cc9ffade 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2080,16 +2080,17 @@ arrow = ">=0.15.0" [[package]] name = "jax" -version = "0.4.10" +version = "0.4.13" description = "Differentiate, compile, and transform Numpy code." 
category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jax-0.4.10.tar.gz", hash = "sha256:1bf0f2720f778f2937301a16a4d5cd3497f13a4d6c970c24a88918a81816a888"}, + {file = "jax-0.4.13.tar.gz", hash = "sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa"}, ] [package.dependencies] +importlib_metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} ml_dtypes = ">=0.1.0" numpy = ">=1.21" opt_einsum = "*" @@ -2097,38 +2098,40 @@ scipy = ">=1.7" [package.extras] australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.9)"] -cpu = ["jaxlib (==0.4.10)"] -cuda = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-cudnn82 = ["jaxlib (==0.4.10+cuda11.cudnn82)"] -cuda11-cudnn86 = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-local = ["jaxlib (==0.4.10+cuda11.cudnn86)"] -cuda11-pip = ["jaxlib (==0.4.10+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.6)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] -cuda12-local = ["jaxlib (==0.4.10+cuda12.cudnn88)"] -cuda12-pip = ["jaxlib (==0.4.10+cuda12.cudnn88)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] -minimum-jaxlib = ["jaxlib (==0.4.7)"] -tpu = ["jaxlib (==0.4.10)", "libtpu-nightly (==0.1.dev20230511)", "requests"] +ci = ["jaxlib (==0.4.12)"] +cpu = ["jaxlib (==0.4.13)"] +cuda = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-cudnn86 = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-local = ["jaxlib (==0.4.13+cuda11.cudnn86)"] +cuda11-pip = ["jaxlib (==0.4.13+cuda11.cudnn86)", "nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-local = ["jaxlib (==0.4.13+cuda12.cudnn89)"] +cuda12-pip = ["jaxlib (==0.4.13+cuda12.cudnn89)", "nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", "nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] +minimum-jaxlib = ["jaxlib (==0.4.11)"] +tpu = ["jaxlib (==0.4.13)", "libtpu-nightly (==0.1.dev20230622)"] [[package]] name = "jaxlib" -version = "0.4.10" +version = "0.4.13" description = "XLA library for JAX" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "jaxlib-0.4.10-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:0814c478382e82b0f90aacd820fb898c4a9caa705a1f515c5fd0928198c814f3"}, - {file = "jaxlib-0.4.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:97e7b2b0f32debbb011556cb2dc82cdfb0087b618e302f92319475727408a64e"}, - {file = "jaxlib-0.4.10-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:f46edb93332285ab6f57b2843869183cbd495b4f35bea0fba25a3766a7429306"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1b705e495149945defe478781865d403bd3994c11e326829aea7aafda0dfa639"}, - {file = "jaxlib-0.4.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:377757745d5e2097fccce71c31292973d544a36329b7ed85bf9c41837e107f74"}, - {file = "jaxlib-0.4.10-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a6349c98c3ffd879b390a3532390e8e49f084aa523c1553aa5c21374ca8b4ea9"}, - {file = 
"jaxlib-0.4.10-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:7dc9c89b2b07cf8c576d5fca433181f324fed52e51db60873d2b6d3e496588e2"}, - {file = "jaxlib-0.4.10-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62f3d2bad0476bb6728d1be813894cf3421a3d31706a0208b1f57eec86d310d5"}, - {file = "jaxlib-0.4.10-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:6ea7ad6b520732994e25429768d6bd731d55c59c75ef6f9faa2f59e419fb0ada"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:749a1135a452db1afb4e5de7770fc5dafebb310c35d9db077ed925fcab028471"}, - {file = "jaxlib-0.4.10-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c966b13467c41ff44ba1b3b7cdceb37a76a75f0420f454a8a51543f8bbaabe4a"}, - {file = "jaxlib-0.4.10-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:07557fcf1e4c7c60bbb48c4f4f426909fcf610a7bfa56cbb139719ba3900722d"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0"}, + {file = "jaxlib-0.4.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac"}, + {file = "jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8"}, + {file = "jaxlib-0.4.13-cp310-cp310-win_amd64.whl", hash = "sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb"}, + {file = "jaxlib-0.4.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649"}, + {file = "jaxlib-0.4.13-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89"}, + {file = "jaxlib-0.4.13-cp311-cp311-win_amd64.whl", hash = "sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab"}, + {file = "jaxlib-0.4.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0"}, + {file = "jaxlib-0.4.13-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f"}, + {file = "jaxlib-0.4.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc"}, + {file = "jaxlib-0.4.13-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5"}, + {file = "jaxlib-0.4.13-cp39-cp39-win_amd64.whl", hash = "sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e"}, ] [package.dependencies] @@ -2136,6 +2139,10 @@ ml-dtypes = ">=0.1.0" numpy = ">=1.21" scipy = ">=1.7" +[package.extras] +cuda11-pip = ["nvidia-cublas-cu11 (>=11.11)", "nvidia-cuda-cupti-cu11 (>=11.8)", "nvidia-cuda-nvcc-cu11 (>=11.8)", "nvidia-cuda-runtime-cu11 (>=11.8)", "nvidia-cudnn-cu11 (>=8.8)", "nvidia-cufft-cu11 (>=10.9)", "nvidia-cusolver-cu11 (>=11.4)", "nvidia-cusparse-cu11 (>=11.7)"] +cuda12-pip = ["nvidia-cublas-cu12", "nvidia-cuda-cupti-cu12", "nvidia-cuda-nvcc-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cudnn-cu12 (>=8.9)", 
"nvidia-cufft-cu12", "nvidia-cusolver-cu12", "nvidia-cusparse-cu12"] + [[package]] name = "jedi" version = "0.18.2" diff --git a/tests/fortuna/prob_model/test_train.py b/tests/fortuna/prob_model/test_train.py index 5f18e81d..120bc53b 100755 --- a/tests/fortuna/prob_model/test_train.py +++ b/tests/fortuna/prob_model/test_train.py @@ -15,6 +15,7 @@ FitMonitor, SNGPPosteriorApproximator, ) +from flax import linen as nn from fortuna.prob_model.classification import ProbClassifier from fortuna.prob_model.fit_config.checkpointer import FitCheckpointer from fortuna.prob_model.fit_config.optimizer import FitOptimizer @@ -41,23 +42,24 @@ MyModel, MyModelWithSpectralNorm, ) +from fortuna.partitioner.base import Partitioner OUTPUT_DIM = 2 BATCH_SIZE = 8 INPUT_SHAPE = (3,) -N_DATA = 10 +N_DATA = 16 METHODS = { "map": MAPPosteriorApproximator(), - "advi": ADVIPosteriorApproximator(), - "laplace": LaplacePosteriorApproximator(), - "swag": SWAGPosteriorApproximator(rank=2), - "deep_ensemble": DeepEnsemblePosteriorApproximator(ensemble_size=2), - "sngp": SNGPPosteriorApproximator(output_dim=OUTPUT_DIM, gp_hidden_features=2), - "sghmc": SGHMCPosteriorApproximator(n_samples=3, n_thinning=1, burnin_length=1), - "cyclical_sgld": CyclicalSGLDPosteriorApproximator( - n_samples=3, n_thinning=1, cycle_length=4 - ), + # "advi": ADVIPosteriorApproximator(), + # "laplace": LaplacePosteriorApproximator(), + # "swag": SWAGPosteriorApproximator(rank=2), + # "deep_ensemble": DeepEnsemblePosteriorApproximator(ensemble_size=2), + # "sngp": SNGPPosteriorApproximator(output_dim=OUTPUT_DIM, gp_hidden_features=2), + # "sghmc": SGHMCPosteriorApproximator(n_samples=3, n_thinning=1, burnin_length=1), + # "cyclical_sgld": CyclicalSGLDPosteriorApproximator( + # n_samples=3, n_thinning=1, cycle_length=4 + # ), } @@ -173,12 +175,21 @@ def train_and_sample( def define_prob_model(task, method, model_editor=None): + partitioner = Partitioner( + axis_dims={"mp": 1, "fsdp": -1, "dp": 1}, + rules={ + "l1/kernel": ("mp", "fsdp"), + "bn1": ("fsdp",) + } + ) + if task == "regression": return ProbRegressor( model=MyModel(OUTPUT_DIM), likelihood_log_variance_model=MyModel(OUTPUT_DIM), posterior_approximator=METHODS[method], model_editor=model_editor, + partitioner=partitioner ) else: return ProbClassifier( @@ -187,6 +198,7 @@ def define_prob_model(task, method, model_editor=None): else MyModelWithSpectralNorm(OUTPUT_DIM), posterior_approximator=METHODS[method], model_editor=model_editor, + partitioner=partitioner ) diff --git a/tests/fortuna/test_trainer.py b/tests/fortuna/test_trainer.py index fd243915..a6b81fa1 100755 --- a/tests/fortuna/test_trainer.py +++ b/tests/fortuna/test_trainer.py @@ -568,7 +568,7 @@ def test_should_perform_validation(self): self.assertTrue(trainer.should_perform_validation({}, 10)) def test__validation_loop(self): - validation_dataloader = [ + validation_data_loader = [ [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])], [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])], ] @@ -580,7 +580,7 @@ def test__validation_loop(self): observed_validation_epoch_metrics_str, ) = trainer._validation_loop( state=None, - validation_dataloader=validation_dataloader, + validation_data_loader=validation_data_loader, validation_dataset_size=2, loss_fun=lambda x: x, rng=jax.random.PRNGKey(0), @@ -600,7 +600,7 @@ def test__validation_loop(self): observed_validation_epoch_metrics_str, ) = trainer._validation_loop( state=None, - validation_dataloader=validation_dataloader, + 
validation_data_loader=validation_data_loader, validation_dataset_size=2, loss_fun=lambda x: x, rng=jax.random.PRNGKey(0), @@ -618,7 +618,7 @@ def test__validation_loop(self): ) def test__training_loop(self): - training_dataloader = [ + training_data_loader = [ [jnp.array([[0, 0.0, 0.0], [0, 0.0, 0]]), jnp.array([0.0, 0.0])], [jnp.array([[0.1, 0.0, 10], [0, 0.0, 0]]), jnp.array([1.0, 0.0])], ] @@ -635,7 +635,7 @@ def test__training_loop(self): metrics=(accuracy,), rng=jax.random.PRNGKey(0), state=FakeTrainState(), - training_dataloader=training_dataloader, + training_data_loader=training_data_loader, training_dataset_size=2, training_kwargs=FrozenDict({}), unravel=None, diff --git a/tests/make_model.py b/tests/make_model.py index 1469ea63..1149a396 100644 --- a/tests/make_model.py +++ b/tests/make_model.py @@ -1,12 +1,13 @@ import flax.linen as nn import jax.numpy as jnp - +from functools import partial from fortuna.model.utils.spectral_norm import WithSpectralNorm class MyModel(nn.Module): output_dim: int dense: nn.Module = nn.Dense + dtype: str = "float32" @nn.compact def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: @@ -14,8 +15,16 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray: dense = self.spectral_norm(self.dense, train=train) else: dense = self.dense + norm = partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = x.reshape(x.shape[0], -1) - x = dense(2, name="l1")(x) + x = dense(4, name="l1")(x) + x = norm(name="bn1")(x) x = nn.Dropout(rate=0.9)(x, deterministic=not train) x = dense(self.output_dim, name="l2")(x) return x
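
Illustrative usage note (not part of the patch): the sketch below shows one way the new fortuna.utils.prefetch.prefetch_to_mesh helper added in this diff could be exercised on its own. The toy batch iterator, the batch shapes, and the standalone one-axis "dp" mesh are assumptions made purely for illustration; the only API relied on is the prefetch_to_mesh(iterator, size, mesh, xs_spec) function defined above, which device_puts every array leaf with NamedSharding(mesh, xs_spec) and keeps up to `size` batches queued.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec

from fortuna.utils.prefetch import prefetch_to_mesh

# A one-dimensional mesh over all available devices; "dp" matches the
# data-parallel axis name used elsewhere in this patch.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

def toy_batches(n_batches: int = 4, batch_size: int = 8, input_dim: int = 3):
    # Toy (inputs, targets) iterator; batch_size should be divisible by the
    # number of devices along "dp" so that dimension 0 can be sharded evenly.
    for _ in range(n_batches):
        yield (
            np.ones((batch_size, input_dim), dtype=np.float32),
            np.zeros(batch_size, dtype=np.float32),
        )

# Shard each batch over "dp" and keep two batches prefetched to the mesh.
for inputs, targets in prefetch_to_mesh(toy_batches(), 2, mesh, PartitionSpec("dp")):
    # Every leaf is now a jax.Array laid out over the "dp" axis of the mesh.
    loss = jnp.mean((inputs.sum(axis=-1) - targets) ** 2)

Keeping two batches in flight mirrors the jax_utils.prefetch_to_device(loader, 2) behaviour that the removed MultiDeviceMixin relied on, but data is placed directly according to a mesh sharding instead of being reshaped with a leading device dimension.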