Skip to content

Commit

Permalink
Merge pull request #178 from google-research/rajat_dev
Browse files Browse the repository at this point in the history
Setting median outputs as default. Also minor changes to finetuning.
  • Loading branch information
siriuz42 authored Nov 6, 2024
2 parents 9b302ae + be37aba commit 0706b0f
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 420 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
# e.g. poetry version 0.1.${{ github.run_number }}
- name: Set Version number
run: |
poetry version 1.2.1
poetry version 1.2.2
- name: Build and Publish to PyPI
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }}
Expand Down
21 changes: 10 additions & 11 deletions notebooks/finetuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,14 @@
"outputs": [],
"source": [
"tfm = timesfm.TimesFm(\n",
" context_len=512,\n",
" horizon_len=128,\n",
" input_patch_len=32,\n",
" output_patch_len=128,\n",
" num_layers=20,\n",
" model_dims=1280,\n",
" backend=\"gpu\",\n",
")\n",
"tfm.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")"
" hparams=timesfm.TimesFmHparams(\n",
" backend=\"gpu\",\n",
" per_core_batch_size=32,\n",
" horizon_len=128,\n",
" ),\n",
" checkpoint=timesfm.TimesFmCheckpoint(\n",
" huggingface_repo_id=\"google/timesfm-1.0-200m\"),\n",
" )"
]
},
{
Expand Down Expand Up @@ -209,8 +208,8 @@
"for batch in tqdm(test_batches.as_numpy_iterator()):\n",
" past = batch[0]\n",
" actuals = batch[3]\n",
" _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
" forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
" forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)\n",
" forecasts = forecasts[:, 0 : actuals.shape[1]]\n",
" mae_losses.append(np.abs(forecasts - actuals).mean())\n",
"\n",
"print(f\"MAE: {np.mean(mae_losses)}\")"
Expand Down
Loading

0 comments on commit 0706b0f

Please sign in to comment.