Skip to content

Commit

Permalink
Merge pull request #204 from google-research/rajat_dev
Browse files Browse the repository at this point in the history
Adding v2.0 support
  • Loading branch information
rajatsen91 authored Dec 31, 2024
2 parents 5a69171 + 4aec27e commit 028188b
Show file tree
Hide file tree
Showing 12 changed files with 1,478 additions and 1,423 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
# e.g. poetry version 0.1.${{ github.run_number }}
- name: Set Version number
run: |
poetry version 1.2.4
poetry version 1.2.6
- name: Build and Publish to PyPI
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }}
Expand Down
57 changes: 46 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,35 @@ Research for time-series forecasting.

* Paper: [A decoder-only foundation model for time-series forecasting](https://arxiv.org/abs/2310.10688), to appear in ICML 2024.
* [Google Research blog](https://research.google/blog/a-decoder-only-foundation-model-for-time-series-forecasting/)
* [Hugging Face checkpoint repo](https://huggingface.co/google/timesfm-1.0-200m)
* [Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)

This repo contains the code to load public TimesFM checkpoints and run model
inference. Please visit our
[Hugging Face checkpoint repo](https://huggingface.co/google/timesfm-1.0-200m)
[Hugging Face release](https://huggingface.co/collections/google/timesfm-release-66e4be5fdb56e960c1e482a6)
to download model checkpoints.

This is not an officially supported Google product.

We recommend at least 16GB RAM to load TimesFM dependencies.
We recommend at least 32GB RAM to load TimesFM dependencies.

## Update - Sep. 12, 2024
- We have released full pytorch support (excluding PEFT parts).
- Shoutout to @tanmayshishodia for checking in PEFT methods like LoRA and DoRA.
- To install TimesFM, you can now simply do: `pip install timesfm`.
## Update - Dec. 30, 2024
- We are launching a 500m checkpoint as a part of TimesFM-2.0 release.
- Launched [finetuning support](https://github.com/google-research/timesfm/blob/master/notebooks/finetuning.ipynb) that lets you finetune the weights of the pretrained TimesFM model on your own data.
- Launched [~zero-shot covariate support](https://github.com/google-research/timesfm/blob/master/notebooks/covariates.ipynb) with external regressors. More details [here](https://github.com/google-research/timesfm?tab=readme-ov-file#covariates-support).

## Checkpoint timesfm-1.0-200m (-pytorch)

timesfm-1.0-200m is the first open model checkpoint:
timesfm-1.0-200m is our first open model checkpoint:

- It performs univariate time series forecasting for context lengths up to 512 timepoints and any horizon lengths, with an optional frequency indicator.
- It focuses on point forecasts, and does not support probabilistic forecasts. We experimentally offer quantile heads but they have not been calibrated after pretraining.
- It requires the context to be contiguous (i.e. no "holes"), and the context and the horizon to be of the same frequency.

## Checkpoint timesfm-2.0-500m (-jax/-pytorch)

timesfm-2.0-500m is our second open model checkpoint:

- It performs univariate time series forecasting for context lengths up to 2048 timepoints and any horizon lengths, with an optional frequency indicator.
- It focuses on point forecasts. We experimentally offer 10 quantile heads but they have not been calibrated after pretraining.

## Benchmarks

Expand Down Expand Up @@ -103,6 +107,37 @@ Then the base class can be loaded as,
```python
import timesfm

# Loading the timesfm-2.0 checkpoint:
# For PAX
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
context_len=2048,

use_positional_embedding=False,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-2.0-500m-jax"),
)

# For Torch
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
backend="gpu",
per_core_batch_size=32,
horizon_len=128,
num_layers=50,
use_positional_embedding=False,
context_len=2048,
),
checkpoint=timesfm.TimesFmCheckpoint(
huggingface_repo_id="google/timesfm-2.0-500m-pytorch"),
)

# Loading the timesfm-1.0 checkpoint:
# For PAX
tfm = timesfm.TimesFm(
hparams=timesfm.TimesFmHparams(
Expand All @@ -126,9 +161,9 @@ tfm = timesfm.TimesFm(
)
```

Note some of the parameters are fixed to load the 200m model
Note some of the parameters are fixed to load the 200m and 500m models

1. The `context_len` in `hparams` here can be set as the max context length **of the model**. **It needs to be a multiplier of `input_patch_len`, i.e. a multiplier of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. Currently, the model handles a max context length of 512, which can be increased in later releases. The input time series can have **any context length**. Padding / truncation will be handled by the inference code if needed.
1. The `context_len` in `hparams` here can be set as the max context length **of the model** (a maximum of 2048 for 2.0 models and 512 for 1.0 models). **It needs to be a multiplier of `input_patch_len`, i.e. a multiplier of 32.** You can provide a shorter series to the `tfm.forecast()` function and the model will handle it. The input time series can have **any context length**. Padding / truncation will be handled by the inference code if needed.

2. The horizon length can be set to anything. We recommend setting it to the largest horizon length you would need in the forecasting tasks for your application. We generally recommend horizon length <= context length but it is not a requirement in the function call.

Expand Down
2 changes: 1 addition & 1 deletion experiments/extended_benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ poetry run python3 -m experiments.extended_benchmarks.run_timesfm --model_path=g

Note: In the current version of TimesFM we focus on point forecasts and therefore the mase, smape have been calculated using the quantile head corresponding to the median i.e 0.5 quantile. We do offer 10 quantile heads but they have not been calibrated after pretraining. We recommend using them with caution or calibrate/conformalize them on a hold out for your applications. More to follow on later versions.

## Benchmark Results
## Benchmark Results for TimesFM-1.0

![Benchmark Results Table](./tfm_extended_new.png)

Expand Down
2 changes: 1 addition & 1 deletion experiments/long_horizon_benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ poetry run python3 -m experiments.long_horizon_benchmarks.run_eval \

You can change the model size from "mini" to "large" as required. The datasets we benchmark on are etth1, etth2, ettm1 and ettm2.

## Benchmark Results
## Benchmark Results for TimesFM-1.0

![Benchmark Results Table](./tfm_long_horizon.png)

Expand Down
86 changes: 33 additions & 53 deletions notebooks/covariates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"This toturial notebook demonstrates how to utilize exogenous covariates with TimesFM when making forecasts. Before running this notebook, make sure:\n",
"\n",
"- You've read through the README of TimesFM.\n",
"- A local kernel with Python 3.10 is up and running."
"- A local kernel with Python 3.10 is up and running, for the jax version.\n",
"- Install the JAX version following the installation instructions."
]
},
{
Expand All @@ -19,36 +20,15 @@
"## Setup the environment and install TimesFM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
"os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install timesfm\n",
"import timesfm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the checkpoint\n",
"\n",
"**Notice:** Please set up the backend as per your machine (\"cpu\", \"gpu\" or \"tpu\"). This notebook will run by default on CPU.\n",
"**Notice:** Please set up the backend as per your machine (\"cpu\", \"gpu\" or \"tpu\"). This notebook will run by default on GPU.\n",
"\n",
"We load the 1.0-200m model checkpoint from HuggingFace."
"We load the 2.0-500m model checkpoint from HuggingFace."
]
},
{
Expand All @@ -57,23 +37,21 @@
"metadata": {},
"outputs": [],
"source": [
"timesfm_backend = \"cpu\" # @param\n",
"\n",
"from jax._src import config\n",
"config.update(\n",
" \"jax_platforms\", {\"cpu\": \"cpu\", \"gpu\": \"cuda\", \"tpu\": \"\"}[timesfm_backend]\n",
")\n",
"import timesfm\n",
"timesfm_backend = \"gpu\" # @param\n",
"\n",
"model = timesfm.TimesFm(\n",
" context_len=512,\n",
" horizon_len=128,\n",
" input_patch_len=32,\n",
" output_patch_len=128,\n",
" num_layers=20,\n",
" model_dims=1280,\n",
" backend=timesfm_backend,\n",
")\n",
"model.load_from_checkpoint(repo_id=\"google/timesfm-1.0-200m\")"
" hparams=timesfm.TimesFmHparams(\n",
" backend=timesfm_backend,\n",
" per_core_batch_size=32,\n",
" horizon_len=128,\n",
" num_layers=50,\n",
" use_positional_embedding=False,\n",
" context_len=2048,\n",
" ),\n",
" checkpoint=timesfm.TimesFmCheckpoint(\n",
" huggingface_repo_id=\"google/timesfm-2.0-500m-jax\"),\n",
" )"
]
},
{
Expand Down Expand Up @@ -140,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -176,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -208,7 +186,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -239,13 +217,15 @@
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Benchmark\n",
"batch_size = 128\n",
"context_len = 120\n",
"horizon_len = 24\n",
"input_data = get_batched_data_fn(batch_size = 128)\n",
"metrics = defaultdict(list)\n",
"import time\n",
"\n",
"\n",
"for i, example in enumerate(input_data()):\n",
" raw_forecast, _ = model.forecast(\n",
Expand Down Expand Up @@ -307,15 +287,15 @@
"source": [
"You should see results close to \n",
"```\n",
"eval_mae_timesfm: 6.762283045916956\n",
"eval_mae_xreg_timesfm: 5.39219617611074\n",
"eval_mae_xreg: 37.15275842572484\n",
"eval_mse_timesfm: 166.7771466306823\n",
"eval_mse_xreg_timesfm: 120.64757721021306\n",
"eval_mse_xreg: 1672.2116821201796\n",
"eval_mae_timesfm: 6.729583250571446\n",
"eval_mae_xreg_timesfm: 5.3375301110158\n",
"eval_mae_xreg: 37.152760709266\n",
"eval_mse_timesfm: 162.3132151851567\n",
"eval_mse_xreg_timesfm: 120.9900627409689\n",
"eval_mse_xreg: 1672.208769045399\n",
"```\n",
"\n",
"With the covariates, the TimesFM forecast Mean Absolute Error improves from 6.76 to 5.39, and Mean Squred Error from 166.78 to 120.65. The results of purely fitting the linear model are also provided for reference."
"With the covariates, the TimesFM forecast Mean Absolute Error improves from 6.73 to 5.34, and Mean Squred Error from 162.31 to 120.99. The results of purely fitting the linear model are also provided for reference."
]
},
{
Expand Down Expand Up @@ -381,9 +361,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "cuda-gpt",
"display_name": "chronos-v2",
"language": "python",
"name": "cuda"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -395,7 +375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 028188b

Please sign in to comment.