From c1e2094a1011cd75ad96b829d5c02180e633f1d8 Mon Sep 17 00:00:00 2001 From: andreramosfdc Date: Wed, 30 Oct 2024 18:08:17 -0300 Subject: [PATCH] Enable model snapshot downloads to be saved in a specified local directory. --- src/timesfm/timesfm_base.py | 1 + src/timesfm/timesfm_torch.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py index 0364089..4d60a4c 100644 --- a/src/timesfm/timesfm_base.py +++ b/src/timesfm/timesfm_base.py @@ -130,6 +130,7 @@ class TimesFmCheckpoint: huggingface_repo_id: str | None = None type: Any = None step: int | None = None + local_dir: str | None = None class TimesFmBase: diff --git a/src/timesfm/timesfm_torch.py b/src/timesfm/timesfm_torch.py index 5137e71..5775e8d 100644 --- a/src/timesfm/timesfm_torch.py +++ b/src/timesfm/timesfm_torch.py @@ -55,8 +55,9 @@ def load_from_checkpoint( checkpoint_path = checkpoint.path repo_id = checkpoint.huggingface_repo_id if checkpoint_path is None: - checkpoint_path = path.join(snapshot_download(repo_id), - "torch_model.ckpt") + checkpoint_path = path.join( + snapshot_download(repo_id, local_dir=checkpoint.local_dir), + "torch_model.ckpt") self._model = ppd.PatchedTimeSeriesDecoder(self._model_config) loaded_checkpoint = torch.load(checkpoint_path, weights_only=True) logging.info("Loading checkpoint from %s", checkpoint_path)