Skip to content

Commit 086e660

Browse files
authored
Add unittests for df_utils (#414)
*Issue #, if available:* *Description of changes:* This PR improves test coverage by adding unit tests for `df_utils`. Previously these methods were only being tested as part of Chronos-2 integration tests. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent d608e0d commit 086e660

File tree

2 files changed

+364
-0
lines changed

2 files changed

+364
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,6 @@ cython_debug/
163163
.DS_store
164164

165165
chronos-2-finetuned
166+
167+
# Kiro IDE
168+
.kiro

test/test_df_utils.py

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from unittest.mock import patch
5+
6+
import numpy as np
7+
import pandas as pd
8+
import pytest
9+
10+
from chronos.df_utils import (
11+
convert_df_input_to_list_of_dicts_input,
12+
validate_df_inputs,
13+
)
14+
from test.util import create_df, create_future_df, get_forecast_start_times
15+
16+
17+
# Tests for validate_df_inputs function
18+
19+
20+
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
def test_validate_df_inputs_returns_correct_metadata_for_valid_inputs(freq):
    """validate_df_inputs should hand back sorted data plus metadata: inferred freq, per-series lengths, and original series order."""
    # Two series of different lengths sharing the same frequency.
    input_df = create_df(series_ids=["A", "B"], n_points=[10, 15], target_cols=["target"], freq=freq)

    result = validate_df_inputs(
        df=input_df,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
        id_column="item_id",
        timestamp_column="timestamp",
    )
    out_df, out_future_df, freq_guess, lengths, order = result

    # No future_df was supplied, so none should come back.
    assert out_future_df is None
    # A frequency must have been inferred from the timestamps.
    assert freq_guess is not None
    # Per-series lengths and the original ordering are reported.
    assert lengths == [10, 15]
    assert list(order) == ["A", "B"]
    # The validated frame is sorted by series id: A occupies the first 10 rows.
    assert out_df["item_id"].iloc[0] == "A"
    assert out_df["item_id"].iloc[10] == "B"
44+
45+
46+
def test_validate_df_inputs_casts_mixed_dtypes_correctly():
    """Numeric and boolean columns are normalized to float32; string columns become pandas categoricals."""
    frame = pd.DataFrame({
        "item_id": ["A"] * 10,
        "timestamp": pd.date_range(end="2001-10-01", periods=10, freq="h"),
        "target": np.random.randn(10),                # float column
        "numeric_cov": np.random.randint(0, 10, 10),  # int column
        "string_cov": ["cat1"] * 5 + ["cat2"] * 5,    # string column
        "bool_cov": [True, False] * 5,                # boolean column
    })

    out_df, _, _, _, _ = validate_df_inputs(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
    )

    # Floats, ints, and booleans all end up as float32.
    for col in ("target", "numeric_cov", "bool_cov"):
        assert out_df[col].dtype == np.float32
    # String/object columns are cast to category.
    assert out_df["string_cov"].dtype.name == "category"
71+
72+
73+
def test_validate_df_inputs_raises_error_when_series_has_insufficient_data():
    """A series with fewer than 3 observations must be rejected, naming the offending series."""
    # Series B has only 2 points, below the minimum of 3.
    short_df = create_df(series_ids=["A", "B"], n_points=[10, 2], target_cols=["target"], freq="h")

    expected_msg = r"Every time series must have at least 3 data points.*series B"
    with pytest.raises(ValueError, match=expected_msg):
        validate_df_inputs(
            df=short_df,
            future_df=None,
            target_columns=["target"],
            prediction_length=5,
        )
)
86+
87+
88+
def test_validate_df_inputs_raises_error_when_future_df_has_mismatched_series_ids():
    """A future_df covering only a subset of df's series IDs must be rejected."""
    history = create_df(series_ids=["A", "B"], n_points=[10, 15], target_cols=["target"], freq="h")

    # Build a future frame for series A only, leaving B uncovered.
    start_times = get_forecast_start_times(history, freq="h")
    partial_future = create_future_df(
        forecast_start_times=[start_times[0]],
        series_ids=["A"],
        n_points=[5],
        covariates=None,
        freq="h"
    )

    with pytest.raises(ValueError, match=r"future_df must contain the same time series IDs as df"):
        validate_df_inputs(
            df=history,
            future_df=partial_future,
            target_columns=["target"],
            prediction_length=5,
        )
)
111+
112+
113+
def test_validate_df_inputs_raises_error_when_future_df_has_incorrect_lengths():
    """Each series in future_df must contain exactly prediction_length rows."""
    history = create_df(
        series_ids=["A", "B"], n_points=[10, 13], target_cols=["target"], covariates=["cov1"], freq="h"
    )

    # Deliberately wrong horizon lengths: 3 and 7 instead of the required 5.
    start_times = get_forecast_start_times(history, freq="h")
    bad_future = create_future_df(
        forecast_start_times=start_times,
        series_ids=["A", "B"],
        n_points=[3, 7],
        covariates=["cov1"],
        freq="h"
    )

    pattern = r"future_df must contain prediction_length=5 values for each series.*different lengths"
    with pytest.raises(ValueError, match=pattern):
        validate_df_inputs(
            df=history,
            future_df=bad_future,
            target_columns=["target"],
            prediction_length=5,
        )
)
136+
137+
138+
# Tests for convert_df_input_to_list_of_dicts_input function
139+
140+
141+
def test_convert_df_with_single_target_preserves_values():
    """A single-target dataframe converts to one dict per series holding a (1, T) target array."""
    frame = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")

    entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
    )

    # One entry per series.
    assert len(entries) == 2
    # Target arrays are shaped (n_targets, n_timesteps).
    assert entries[0]["target"].shape == (1, 10)
    assert entries[1]["target"].shape == (1, 12)

    # Values must reproduce the (sorted) source dataframe exactly.
    ordered = frame.sort_values(["item_id", "timestamp"])
    for idx, sid in enumerate(["A", "B"]):
        expected = ordered.loc[ordered["item_id"] == sid, "target"].values
        np.testing.assert_array_almost_equal(entries[idx]["target"][0], expected)
163+
164+
165+
def test_convert_df_with_multiple_targets_preserves_values_and_shape():
    """With several target columns, the target array stacks one row per column along axis 0."""
    frame = create_df(series_ids=["A", "B"], n_points=[10, 14], target_cols=["target1", "target2"], freq="h")

    entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=None,
        target_columns=["target1", "target2"],
        prediction_length=5,
    )

    # Shape is (n_targets, n_timesteps) for each series.
    assert entries[0]["target"].shape == (2, 10)
    assert entries[1]["target"].shape == (2, 14)

    # Each target row must reproduce the corresponding column of the sorted input.
    ordered = frame.sort_values(["item_id", "timestamp"])
    for idx, sid in enumerate(["A", "B"]):
        per_series = ordered[ordered["item_id"] == sid]
        np.testing.assert_array_almost_equal(entries[idx]["target"][0], per_series["target1"].values)
        np.testing.assert_array_almost_equal(entries[idx]["target"][1], per_series["target2"].values)
186+
187+
188+
def test_convert_df_with_past_covariates_includes_them_in_output():
    """Covariates present only in df appear under past_covariates and never under future_covariates."""
    frame = create_df(
        series_ids=["A", "B"], n_points=[10, 16], target_cols=["target"], covariates=["cov1", "cov2"], freq="h"
    )

    entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
    )

    # Both covariates are exposed in the past_covariates dict.
    assert "past_covariates" in entries[0]
    assert "cov1" in entries[0]["past_covariates"]
    assert "cov2" in entries[0]["past_covariates"]

    # Covariate arrays keep each series' full context length.
    assert entries[0]["past_covariates"]["cov1"].shape == (10,)
    assert entries[0]["past_covariates"]["cov2"].shape == (10,)
    assert entries[1]["past_covariates"]["cov1"].shape == (16,)
    assert entries[1]["past_covariates"]["cov2"].shape == (16,)

    # Without a future_df there is nothing to expose as future covariates.
    assert "future_covariates" not in entries[0]
212+
213+
214+
def test_convert_df_with_past_and_future_covariates_includes_both():
    """When future_df is supplied, the shared covariate appears in both the past and future dicts."""
    frame = create_df(series_ids=["A", "B"], n_points=[10, 18], target_cols=["target"], covariates=["cov1"], freq="h")

    start_times = get_forecast_start_times(frame, freq="h")
    future_frame = create_future_df(
        forecast_start_times=start_times,
        series_ids=["A", "B"],
        n_points=[5, 5],
        covariates=["cov1"],
        freq="h"
    )

    entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=future_frame,
        target_columns=["target"],
        prediction_length=5,
    )

    # Every series carries both covariate dictionaries.
    for entry in entries:
        assert "past_covariates" in entry
        assert "future_covariates" in entry

    # Past arrays match each context length; future arrays match the horizon.
    assert entries[0]["past_covariates"]["cov1"].shape == (10,)
    assert entries[0]["future_covariates"]["cov1"].shape == (5,)
    assert entries[1]["past_covariates"]["cov1"].shape == (18,)
    assert entries[1]["future_covariates"]["cov1"].shape == (5,)
245+
246+
247+
@pytest.mark.parametrize("freq", ["s", "min", "30min", "h", "D", "W", "ME", "QE", "YE"])
def test_convert_df_generates_prediction_timestamps_with_correct_frequency(freq):
    """Prediction timestamps must start after each series' context and be regularly spaced."""
    # Three series with different context lengths.
    frame = create_df(series_ids=["A", "B", "C"], n_points=[10, 15, 12], target_cols=["target"], freq=freq)

    entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=frame,
        future_df=None,
        target_columns=["target"],
        prediction_length=5,
    )

    for sid in ["A", "B", "C"]:
        stamps = prediction_timestamps[sid]
        # Forecast timestamps must begin strictly after the last observed timestamp.
        context_end = frame[frame["item_id"] == sid]["timestamp"].max()
        assert stamps[0] > context_end
        # Horizon length matches prediction_length.
        assert len(stamps) == 5
        # Spacing is regular enough for pandas to infer a frequency.
        assert pd.infer_freq(stamps) is not None
272+
273+
274+
def test_convert_df_skips_validation_when_disabled():
    """Passing validate_inputs=False must bypass validate_df_inputs entirely."""
    frame = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], freq="h")

    # Spy on the validator so we can prove it was never invoked.
    with patch("chronos.df_utils.validate_df_inputs") as validate_spy:
        entries, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
            df=frame,
            future_df=None,
            target_columns=["target"],
            prediction_length=5,
            validate_inputs=False,
        )

        # The validator must not have run...
        validate_spy.assert_not_called()

        # ...and the conversion must still produce one entry per series.
        assert len(entries) == 2
293+
294+
295+
def test_convert_df_preserves_all_values_with_random_inputs():
    """Round-trip a randomized dataframe through the conversion and verify every
    target and covariate value is preserved exactly.

    Fix: the original test drew random dimensions and data without fixing a
    seed, so any failure was non-reproducible (flaky CI, undebuggable repros).
    Seeding the global NumPy RNG makes both the dimensions drawn here and any
    random data generated inside the create_* helpers deterministic, while
    still exercising a "random-looking" configuration.
    """
    # Seed the global RNG so the randomized fixture is reproducible run-to-run.
    np.random.seed(42)

    # Randomized problem dimensions.
    n_series = np.random.randint(2, 5)
    n_targets = np.random.randint(1, 4)
    n_past_only_covariates = np.random.randint(1, 3)
    n_future_covariates = np.random.randint(1, 3)
    prediction_length = 5

    series_ids = [f"series_{i}" for i in range(n_series)]
    n_points = [np.random.randint(10, 20) for _ in range(n_series)]
    target_cols = [f"target_{i}" for i in range(n_targets)]
    past_only_covariates = [f"past_cov_{i}" for i in range(n_past_only_covariates)]
    future_covariates = [f"future_cov_{i}" for i in range(n_future_covariates)]
    all_covariates = past_only_covariates + future_covariates

    # Context dataframe carries every covariate (past-only and future-known).
    df = create_df(series_ids=series_ids, n_points=n_points, target_cols=target_cols, covariates=all_covariates, freq="h")

    # future_df carries only the future-known covariates.
    forecast_start_times = get_forecast_start_times(df, freq="h")
    future_df = create_future_df(
        forecast_start_times=forecast_start_times,
        series_ids=series_ids,
        n_points=[prediction_length] * n_series,
        covariates=future_covariates,
        freq="h"
    )

    # Convert to list-of-dicts format.
    inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=df,
        future_df=future_df,
        target_columns=target_cols,
        prediction_length=prediction_length,
    )

    # All target values must survive the conversion exactly, per series and per column.
    df_sorted = df.sort_values(["item_id", "timestamp"])
    for i, series_id in enumerate(series_ids):
        series_data = df_sorted[df_sorted["item_id"] == series_id]
        assert inputs[i]["target"].shape == (n_targets, n_points[i])

        for j, target_col in enumerate(target_cols):
            np.testing.assert_array_almost_equal(inputs[i]["target"][j], series_data[target_col].values)

    # Every covariate (past-only AND future-known) appears in past_covariates with its values intact.
    for i, series_id in enumerate(series_ids):
        series_data = df_sorted[df_sorted["item_id"] == series_id]
        assert "past_covariates" in inputs[i]
        for cov in all_covariates:
            np.testing.assert_array_almost_equal(inputs[i]["past_covariates"][cov], series_data[cov].values)

    # future_covariates must contain exactly the future-known covariates — never the past-only ones.
    future_df_sorted = future_df.sort_values(["item_id", "timestamp"])
    for i, series_id in enumerate(series_ids):
        series_future_data = future_df_sorted[future_df_sorted["item_id"] == series_id]
        assert "future_covariates" in inputs[i]
        assert set(inputs[i]["future_covariates"].keys()) == set(future_covariates)
        for cov in future_covariates:
            np.testing.assert_array_almost_equal(inputs[i]["future_covariates"][cov], series_future_data[cov].values)

    # Overall output structure: one entry per series, original order preserved,
    # and one set of prediction timestamps per series.
    assert len(inputs) == n_series
    assert list(original_order) == series_ids
    assert len(prediction_timestamps) == n_series

0 commit comments

Comments
 (0)