|
9 | 9 | import warnings |
10 | 10 | from copy import deepcopy |
11 | 11 | from pathlib import Path |
12 | | -from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Callable |
| 12 | +from typing import TYPE_CHECKING, Callable, Literal, Mapping, Sequence |
13 | 13 |
|
14 | 14 | import numpy as np |
15 | 15 | import torch |
@@ -914,27 +914,29 @@ def predict_df( |
914 | 914 | quantiles_np = torch.stack(quantiles).numpy() # [n_tasks, n_variates, horizon, num_quantiles] |
915 | 915 | mean_np = torch.stack(mean).numpy() # [n_tasks, n_variates, horizon] |
916 | 916 |
|
917 | | - results_dfs = [] |
918 | | - for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()): |
919 | | - q_pred = quantiles_np[i] # (n_variates, prediction_length, len(quantile_levels)) |
920 | | - point_pred = mean_np[i] # (n_variates, prediction_length) |
921 | | - |
922 | | - for target_idx, target_col in enumerate(target): |
923 | | - series_forecast_data: dict[str | tuple[str, str], Any] = { |
924 | | - id_column: series_id, |
925 | | - timestamp_column: future_ts, |
926 | | - "target_name": target_col, |
927 | | - } |
928 | | - series_forecast_data["predictions"] = point_pred[target_idx] |
929 | | - for q_idx, q_level in enumerate(quantile_levels): |
930 | | - series_forecast_data[str(q_level)] = q_pred[target_idx, :, q_idx] |
931 | | - |
932 | | - results_dfs.append(pd.DataFrame(series_forecast_data)) |
933 | | - |
934 | | - predictions_df = pd.concat(results_dfs, ignore_index=True) |
935 | | - predictions_df.set_index(id_column, inplace=True) |
936 | | - predictions_df = predictions_df.loc[original_order] |
937 | | - predictions_df.reset_index(inplace=True) |
| 917 | + n_tasks = len(prediction_timestamps) |
| 918 | + n_variates = len(target) |
| 919 | + |
| 920 | + series_ids = list(prediction_timestamps.keys()) |
| 921 | + future_ts = list(prediction_timestamps.values()) |
| 922 | + |
| 923 | + data = { |
| 924 | + id_column: np.repeat(series_ids, n_variates * prediction_length), |
| 925 | + timestamp_column: np.concatenate([np.tile(ts, n_variates) for ts in future_ts]), |
| 926 | + "target_name": np.tile(np.repeat(target, prediction_length), n_tasks), |
| 927 | + "predictions": mean_np.ravel(), |
| 928 | + } |
| 929 | + |
| 930 | + quantiles_flat = quantiles_np.reshape(-1, len(quantile_levels)) |
| 931 | + for q_idx, q_level in enumerate(quantile_levels): |
| 932 | + data[str(q_level)] = quantiles_flat[:, q_idx] |
| 933 | + |
| 934 | + predictions_df = pd.DataFrame(data) |
| 935 | + # If validate_inputs=False, the input df was used as-is (not sorted by item_id), so no reordering is needed |
| 936 | + if validate_inputs: |
| 937 | + predictions_df.set_index(id_column, inplace=True) |
| 938 | + predictions_df = predictions_df.loc[original_order] |
| 939 | + predictions_df.reset_index(inplace=True) |
938 | 940 |
|
939 | 941 | return predictions_df |
940 | 942 |
|
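A minimal sanity-check sketch of the flattening order the vectorized build relies on: `mean_np.ravel()` flattens `[n_tasks, n_variates, horizon]` in C order (task, then variate, then horizon step), so the `repeat`/`tile` pattern for the id, timestamp, and target columns must follow the same order. The shapes, column names, and series/target names below are illustrative stand-ins, not values from the codebase.

```python
import numpy as np
import pandas as pd

# Illustrative shapes and names (assumptions for this sketch only).
n_tasks, n_variates, horizon = 2, 3, 4
series_ids = [f"s{t}" for t in range(n_tasks)]
target = [f"y{v}" for v in range(n_variates)]
future_ts = [pd.date_range("2024-01-01", periods=horizon, freq="D") for _ in range(n_tasks)]

# Stand-in for mean_np with shape [n_tasks, n_variates, horizon].
mean_np = np.arange(n_tasks * n_variates * horizon).reshape(n_tasks, n_variates, horizon)

df = pd.DataFrame({
    # One block of n_variates * horizon rows per series, in task order.
    "item_id": np.repeat(series_ids, n_variates * horizon),
    # Each series' timestamps repeated once per variate.
    "timestamp": np.concatenate([np.tile(ts, n_variates) for ts in future_ts]),
    # Within each task block: variate 0 for `horizon` rows, then variate 1, ...
    "target_name": np.tile(np.repeat(target, horizon), n_tasks),
    # C-order ravel: task-major, then variate, then horizon step.
    "predictions": mean_np.ravel(),
})

# Every (item_id, target_name) block should hold exactly that variate's horizon values, in order.
block = df.loc[(df.item_id == "s1") & (df.target_name == "y2"), "predictions"].to_numpy()
assert (block == mean_np[1, 2]).all()
```

Building the columns as flat arrays and constructing the DataFrame once also avoids the per-series `pd.DataFrame` + `pd.concat` of the previous version, which tends to dominate runtime when there are many series.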