Skip to content

Commit efb86e0

Browse files
authored
Chronos-2: Add after_batch callback (#436)
*Issue #, if available:* *Description of changes:* Adds support for custom callbacks after each batch is processed during prediction. This allows for keeping track of the time limit in AutoGluon. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 2896499 commit efb86e0

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

src/chronos/chronos2/pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from copy import deepcopy
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
12+
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
1313

1414
import numpy as np
1515
import torch
@@ -577,6 +577,8 @@ def predict(
577577
# effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions,
578578
# by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
579579
unrolled_quantiles = kwargs.pop("unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
580+
# A callback which is called after each batch has been processed
581+
after_batch_callback: Callable = kwargs.pop("after_batch", lambda: None)
580582

581583
if len(kwargs) > 0:
582584
raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}.")
@@ -641,6 +643,7 @@ def predict(
641643
target_idx_ranges=batch_target_idx_ranges,
642644
)
643645
all_predictions.extend(batch_prediction)
646+
after_batch_callback()
644647

645648
return all_predictions
646649

test/test_chronos2.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from chronos.chronos2.config import Chronos2CoreConfig
1717
from chronos.chronos2.layers import MHA
1818
from chronos.df_utils import convert_df_input_to_list_of_dicts_input
19-
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor
19+
from test.util import create_df, create_future_df, get_forecast_start_times, validate_tensor, timeout_callback
2020

2121
DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
2222

@@ -1008,6 +1008,17 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
10081008
assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
10091009

10101010

1011+
def test_when_predict_df_called_with_timeout_callback_then_timeout_error_is_raised(pipeline):
1012+
num_series = 1000
1013+
large_df = create_df(series_ids=[j for j in range(num_series)], n_points=[2048] * num_series)
1014+
with pytest.raises(TimeoutError, match="time limit exceeded"):
1015+
pipeline.predict_df(
1016+
large_df,
1017+
prediction_length=48,
1018+
after_batch=timeout_callback(0.1),
1019+
)
1020+
1021+
10111022
@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
10121023
def test_pipeline_works_with_different_attention_implementations(attn_implementation):
10131024
"""Test that the pipeline works with different attention implementations."""

test/util.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional, Tuple
1+
import time
2+
from typing import Callable, Optional, Tuple
23

34
import numpy as np
45
import pandas as pd
@@ -13,7 +14,6 @@ def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype: Optional[tor
1314
assert a.dtype == dtype
1415

1516

16-
1717
def create_df(series_ids=["A", "B"], n_points=[10, 10], target_cols=["target"], covariates=None, freq="h"):
1818
"""Helper to create test context DataFrames."""
1919
series_dfs = []
@@ -44,4 +44,15 @@ def get_forecast_start_times(df, freq="h"):
4444
context_end_times = df.groupby("item_id")["timestamp"].max()
4545
forecast_start_times = [pd.date_range(end_time, periods=2, freq=freq)[-1] for end_time in context_end_times]
4646

47-
return forecast_start_times
47+
return forecast_start_times
48+
49+
50+
def timeout_callback(seconds: float | None) -> Callable:
51+
"""Return a callback object that raises an exception if time limit is exceeded."""
52+
start_time = time.monotonic()
53+
54+
def callback() -> None:
55+
if seconds is not None and time.monotonic() - start_time > seconds:
56+
raise TimeoutError("time limit exceeded")
57+
58+
return callback

0 commit comments

Comments
 (0)