Skip to content

Commit

Permalink
Speed up metrics computation by optimizing segment validation (#1338)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Aug 1, 2023
1 parent cceb500 commit ddc1711
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
-
- Add sorting by timestamp before the fit in `CatBoostPerSegmentModel` and `CatBoostMultiSegmentModel` ([#1337](https://github.com/tinkoff-ai/etna/pull/1337))
- Speed up metrics computation by optimizing segment validation, forbid NaNs during metrics computation ([#1338](https://github.com/tinkoff-ai/etna/pull/1338))
- Unify errors, warnings and checks in models ([#1312](https://github.com/tinkoff-ai/etna/pull/1312))
- Remove upper limitation on version of numba ([#1321](https://github.com/tinkoff-ai/etna/pull/1321))

Expand Down
97 changes: 71 additions & 26 deletions etna/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def name(self) -> str:
return self.__class__.__name__

@staticmethod
def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
"""
Check if all the segments from ``y_true`` are in ``y_pred`` and vice versa.
def _validate_segments(y_true: TSDataset, y_pred: TSDataset):
"""Check that segments in ``y_true`` and ``y_pred`` are the same.
Parameters
----------
Expand All @@ -125,9 +124,7 @@ def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
Raises
------
ValueError:
if there are mismatches in y_true and y_pred segments,
ValueError:
if one of segments in y_true or y_pred doesn't contain 'target' column.
if there are mismatches in y_true and y_pred segments
"""
segments_true = set(y_true.df.columns.get_level_values("segment"))
segments_pred = set(y_pred.df.columns.get_level_values("segment"))
Expand All @@ -144,33 +141,78 @@ def _validate_segment_columns(y_true: TSDataset, y_pred: TSDataset):
f"There are segments in y_true that are not in y_pred, for example: "
f"{', '.join(list(true_diff_pred)[:5])}"
)
for segment in segments_true:

@staticmethod
def _validate_target_columns(y_true: TSDataset, y_pred: TSDataset):
"""Check that all the segments from ``y_true`` and ``y_pred`` has 'target' column.
Parameters
----------
y_true:
y_true dataset
y_pred:
y_pred dataset
Raises
------
ValueError:
if one of segments in y_true or y_pred doesn't contain 'target' column.
"""
segments = set(y_true.df.columns.get_level_values("segment"))

for segment in segments:
for name, dataset in zip(("y_true", "y_pred"), (y_true, y_pred)):
if "target" not in dataset.loc[:, segment].columns:
if (segment, "target") not in dataset.columns:
raise ValueError(
f"All the segments in {name} should contain 'target' column. Segment {segment} doesn't."
)

@staticmethod
def _validate_timestamp_columns(timestamp_true: pd.Series, timestamp_pred: pd.Series):
"""
Check that ``y_true`` and ``y_pred`` have the same timestamp.
def _validate_index(y_true: TSDataset, y_pred: TSDataset):
"""Check that ``y_true`` and ``y_pred`` have the same timestamps.
Parameters
----------
timestamp_true:
y_true's timestamp column
timestamp_pred:
y_pred's timestamp column
y_true:
y_true dataset
y_pred:
y_pred dataset
Raises
------
ValueError:
If there are mismatches in ``y_true`` and ``y_pred`` timestamps
"""
if set(timestamp_pred) != set(timestamp_true):
if not y_true.index.equals(y_pred.index):
raise ValueError("y_true and y_pred have different timestamps")

@staticmethod
def _validate_nans(y_true: TSDataset, y_pred: TSDataset):
"""Check that ``y_true`` and ``y_pred`` doesn't have NaNs.
Parameters
----------
y_true:
y_true dataset
y_pred:
y_pred dataset
Raises
------
ValueError:
If there are NaNs in ``y_true`` or ``y_pred``
"""
df_true = y_true.df.loc[:, pd.IndexSlice[:, "target"]]
df_pred = y_pred.df.loc[:, pd.IndexSlice[:, "target"]]

df_true_isna = df_true.isna().any().any()
if df_true_isna > 0:
raise ValueError("There are NaNs in y_true")

df_pred_isna = df_pred.isna().any().any()
if df_pred_isna > 0:
raise ValueError("There are NaNs in y_pred")

@staticmethod
def _macro_average(metrics_per_segments: Dict[str, float]) -> Union[float, Dict[str, float]]:
"""
Expand Down Expand Up @@ -226,18 +268,21 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
metric's value aggregated over segments or not (depends on mode)
"""
self._log_start()
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)

df_true = y_true[:, :, "target"].sort_index(axis=1)
df_pred = y_pred[:, :, "target"].sort_index(axis=1)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
metrics_per_segment[segment] = self.metric_fn(
y_true=y_true[:, segment, "target"].values, y_pred=y_pred[:, segment, "target"].values, **self.kwargs
)
segments = df_true.columns.get_level_values("segment").unique()

for i, segment in enumerate(segments):
cur_y_true = df_true.iloc[:, i]
cur_y_pred = df_pred.iloc[:, i]
metrics_per_segment[segment] = self.metric_fn(y_true=cur_y_true, y_pred=cur_y_pred, **self.kwargs)
metrics = self._aggregate_metrics(metrics_per_segment)
return metrics

Expand Down
18 changes: 8 additions & 10 deletions etna/metrics/intervals_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,15 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
-------
metric's value aggregated over segments or not (depends on mode)
"""
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
upper_quantile_flag = y_true[:, segment, "target"] <= y_pred[:, segment, f"target_{self.quantiles[1]:.4g}"]
lower_quantile_flag = y_true[:, segment, "target"] >= y_pred[:, segment, f"target_{self.quantiles[0]:.4g}"]

Expand Down Expand Up @@ -135,16 +134,15 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
-------
metric's value aggregated over segments or not (depends on mode)
"""
self._validate_segment_columns(y_true=y_true, y_pred=y_pred)
self._validate_segments(y_true=y_true, y_pred=y_pred)
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)

segments = set(y_true.df.columns.get_level_values("segment"))
metrics_per_segment = {}
for segment in segments:
self._validate_timestamp_columns(
timestamp_true=y_true[:, segment, "target"].dropna().index,
timestamp_pred=y_pred[:, segment, "target"].dropna().index,
)
upper_quantile = y_pred[:, segment, f"target_{self.quantiles[1]:.4g}"]
lower_quantile = y_pred[:, segment, f"target_{self.quantiles[0]:.4g}"]

Expand Down
71 changes: 60 additions & 11 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -117,34 +120,60 @@ def test_metrics_invalid_aggregation(metric_class):
@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid timeranges"""
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
metric = metric_class()
with pytest.raises(ValueError, match="There are segments in .* that are not in .*"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_target_columns(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
forecast_df, true_df = train_test_dfs
columns = forecast_df.df.columns.to_list()
columns[0] = ("segment_1", "not_target")
forecast_df.df.columns = pd.MultiIndex.from_tuples(columns, names=["segment", "feature"])
metric = metric_class()
with pytest.raises(ValueError, match="All the segments in .* should contain 'target' column"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_index(metric_class, two_dfs_with_different_timestamps):
"""Check metrics behavior in case of invalid index"""
forecast_df, true_df = two_dfs_with_different_timestamps
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="y_true and y_pred have different timestamps"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets):
"""Check metrics behavior in case of invalid segments sets"""
forecast_df, true_df = two_dfs_with_different_segments_sets
def test_invalid_nans_pred(metric_class, train_test_dfs):
"""Check metrics behavior in case of nans in prediction."""
forecast_df, true_df = train_test_dfs
forecast_df.df.iloc[0, 0] = np.NaN
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="There are NaNs in y_pred"):
_ = metric(y_true=true_df, y_pred=forecast_df)


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_invalid_segments_target(metric_class, train_test_dfs):
"""Check metrics behavior in case of no target column in segment"""
def test_invalid_nans_true(metric_class, train_test_dfs):
"""Check metrics behavior in case of nans in true values."""
forecast_df, true_df = train_test_dfs
forecast_df.df.drop(columns=[("segment_1", "target")], inplace=True)
true_df.df.iloc[0, 0] = np.NaN
metric = metric_class()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="There are NaNs in y_true"):
_ = metric(y_true=true_df, y_pred=forecast_df)


Expand Down Expand Up @@ -181,6 +210,26 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs):
assert value == true_metric_value


@pytest.mark.parametrize(
"metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE)
)
def test_metric_values_with_changed_segment_order(metric_class, train_test_dfs):
forecast_df, true_df = train_test_dfs
forecast_df_new, true_df_new = deepcopy(train_test_dfs)
segments = np.array(forecast_df.segments)

forecast_segment_order = segments[[3, 2, 0, 1, 4]]
forecast_df_new.df = forecast_df_new.df.loc[:, pd.IndexSlice[forecast_segment_order, :]]
true_segment_order = segments[[4, 1, 3, 2, 0]]
true_df_new.df = true_df_new.df.loc[:, pd.IndexSlice[true_segment_order, :]]

metric = metric_class(mode="per-segment")
metrics_initial = metric(y_pred=forecast_df, y_true=true_df)
metrics_changed_order = metric(y_pred=forecast_df_new, y_true=true_df_new)

assert metrics_initial == metrics_changed_order


@pytest.mark.parametrize(
"metric, greater_is_better",
(
Expand Down

1 comment on commit ddc1711

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.