Skip to content

Commit 0706b0f

Browse files
authored
Merge pull request #178 from google-research/rajat_dev
Setting median outputs as default. Also minor changes to finetuning.
2 parents 9b302ae + be37aba commit 0706b0f

File tree

6 files changed

+491
-420
lines changed

6 files changed

+491
-420
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
# e.g. poetry version 0.1.${{ github.run_number }}
2424
- name: Set Version number
2525
run: |
26-
poetry version 1.2.1
26+
poetry version 1.2.2
2727
- name: Build and Publish to PyPI
2828
run: |
2929
poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }}

notebooks/finetuning.ipynb

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,14 @@
6262
"outputs": [],
6363
"source": [
6464
"tfm = timesfm.TimesFm(\n",
65-
" context_len=512,\n",
66-
" horizon_len=128,\n",
67-
" input_patch_len=32,\n",
68-
" output_patch_len=128,\n",
69-
" num_layers=20,\n",
70-
" model_dims=1280,\n",
71-
" backend=\"gpu\",\n",
72-
")\n",
73-
"tfm.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")"
65+
" hparams=timesfm.TimesFmHparams(\n",
66+
" backend=\"gpu\",\n",
67+
" per_core_batch_size=32,\n",
68+
" horizon_len=128,\n",
69+
" ),\n",
70+
" checkpoint=timesfm.TimesFmCheckpoint(\n",
71+
" huggingface_repo_id=\"google/timesfm-1.0-200m\"),\n",
72+
" )"
7473
]
7574
},
7675
{
@@ -209,8 +208,8 @@
209208
"for batch in tqdm(test_batches.as_numpy_iterator()):\n",
210209
" past = batch[0]\n",
211210
" actuals = batch[3]\n",
212-
" _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
213-
" forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
211+
" forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)\n",
212+
" forecasts = forecasts[:, 0 : actuals.shape[1]]\n",
214213
" mae_losses.append(np.abs(forecasts - actuals).mean())\n",
215214
"\n",
216215
"print(f\"MAE: {np.mean(mae_losses)}\")"

0 commit comments

Comments
 (0)