Skip to content

Commit c50fed9

Browse files
authored
Speed up predict_df (#437)
*Issue #, if available:* *Description of changes:* - Replace the per-series for-loop with vectorized numpy operations and a single pd.DataFrame construction. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent efb86e0 commit c50fed9

File tree

2 files changed

+44
-38
lines changed

2 files changed

+44
-38
lines changed

src/chronos/base.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -218,22 +218,26 @@ def predict_df(
218218
quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles]
219219
mean_np = mean.numpy() # [n_series, horizon]
220220

221-
results_dfs = []
222-
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
223-
q_pred = quantiles_np[i] # (horizon, num_quantiles)
224-
point_pred = mean_np[i] # (horizon)
225-
226-
series_forecast_data = {id_column: series_id, timestamp_column: future_ts, "target_name": target}
227-
series_forecast_data["predictions"] = point_pred
228-
for q_idx, q_level in enumerate(quantile_levels):
229-
series_forecast_data[str(q_level)] = q_pred[:, q_idx]
230-
231-
results_dfs.append(pd.DataFrame(series_forecast_data))
232-
233-
predictions_df = pd.concat(results_dfs, ignore_index=True)
234-
predictions_df.set_index(id_column, inplace=True)
235-
predictions_df = predictions_df.loc[original_order]
236-
predictions_df.reset_index(inplace=True)
221+
series_ids = list(prediction_timestamps.keys())
222+
future_ts = list(prediction_timestamps.values())
223+
224+
data = {
225+
id_column: np.repeat(series_ids, prediction_length),
226+
timestamp_column: np.concatenate(future_ts),
227+
"target_name": target,
228+
"predictions": mean_np.ravel(),
229+
}
230+
231+
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
232+
for q_idx, q_level in enumerate(quantile_levels):
233+
data[str(q_level)] = quantiles_flat[:, q_idx]
234+
235+
predictions_df = pd.DataFrame(data)
236+
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
237+
if validate_inputs:
238+
predictions_df.set_index(id_column, inplace=True)
239+
predictions_df = predictions_df.loc[original_order]
240+
predictions_df.reset_index(inplace=True)
237241

238242
return predictions_df
239243

src/chronos/chronos2/pipeline.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from copy import deepcopy
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable
12+
from typing import TYPE_CHECKING, Callable, Literal, Mapping, Sequence
1313

1414
import numpy as np
1515
import torch
@@ -914,27 +914,29 @@ def predict_df(
914914
quantiles_np = torch.stack(quantiles).numpy() # [n_tasks, n_variates, horizon, num_quantiles]
915915
mean_np = torch.stack(mean).numpy() # [n_tasks, n_variates, horizon]
916916

917-
results_dfs = []
918-
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
919-
q_pred = quantiles_np[i] # (n_variates, prediction_length, len(quantile_levels))
920-
point_pred = mean_np[i] # (n_variates, prediction_length)
921-
922-
for target_idx, target_col in enumerate(target):
923-
series_forecast_data: dict[str | tuple[str, str], Any] = {
924-
id_column: series_id,
925-
timestamp_column: future_ts,
926-
"target_name": target_col,
927-
}
928-
series_forecast_data["predictions"] = point_pred[target_idx]
929-
for q_idx, q_level in enumerate(quantile_levels):
930-
series_forecast_data[str(q_level)] = q_pred[target_idx, :, q_idx]
931-
932-
results_dfs.append(pd.DataFrame(series_forecast_data))
933-
934-
predictions_df = pd.concat(results_dfs, ignore_index=True)
935-
predictions_df.set_index(id_column, inplace=True)
936-
predictions_df = predictions_df.loc[original_order]
937-
predictions_df.reset_index(inplace=True)
917+
n_tasks = len(prediction_timestamps)
918+
n_variates = len(target)
919+
920+
series_ids = list(prediction_timestamps.keys())
921+
future_ts = list(prediction_timestamps.values())
922+
923+
data = {
924+
id_column: np.repeat(series_ids, n_variates * prediction_length),
925+
timestamp_column: np.concatenate([np.tile(ts, n_variates) for ts in future_ts]),
926+
"target_name": np.tile(np.repeat(target, prediction_length), n_tasks),
927+
"predictions": mean_np.ravel(),
928+
}
929+
930+
quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels))
931+
for q_idx, q_level in enumerate(quantile_levels):
932+
data[str(q_level)] = quantiles_flat[:, q_idx]
933+
934+
predictions_df = pd.DataFrame(data)
935+
# If validate_inputs=False, the df is used as-is without sorting by item_id, no reordering required
936+
if validate_inputs:
937+
predictions_df.set_index(id_column, inplace=True)
938+
predictions_df = predictions_df.loc[original_order]
939+
predictions_df.reset_index(inplace=True)
938940

939941
return predictions_df
940942

0 commit comments

Comments
 (0)