Skip to content

Commit f565636

Browse files
committed
Modify and complete test_forecast_ar.py and test_clean_apple.py so that all functions are tested and the number of assert statements is reduced.
1 parent 1d3a3cf commit f565636

File tree

3 files changed

+78
-11
lines changed

3 files changed

+78
-11
lines changed

src/lennart_epp/data_management/clean_apple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22

33

4-
def _select_and_rename_column(df: pd.DataFrame) -> pd.DataFrame:
4+
def _rename_column(df: pd.DataFrame) -> pd.DataFrame:
55
"""Select the 'Close' column and rename it to 'close_price'.
66
77
Args:
@@ -92,7 +92,7 @@ def clean_apple_data(df: pd.DataFrame) -> pd.DataFrame:
9292
_validate_dataframe(df)
9393

9494
df = _convert_to_datetime(df)
95-
df = _select_and_rename_column(df)
95+
df = _rename_column(df)
9696
df = _handle_missing_values(df)
9797
df = _remove_duplicates(df)
9898
df = _convert_to_numeric(df)

tests/analysis/test_forecast_ar.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,45 @@ def test_integrated_coefficients():
1919
return pd.DataFrame({"coefficient": coeffs})
2020

2121

22-
def test_forecast_ar_multi_step(test_dataframe, test_integrated_coefficients):
22+
def test_forecast_ar_multi_step_output_type(
23+
test_dataframe, test_integrated_coefficients
24+
):
25+
"""Test that forecast_ar_multi_step returns a pandas Series."""
2326
forecast_steps = 10
2427
forecasts = forecast_ar_multi_step(
2528
test_dataframe, test_integrated_coefficients, forecast_steps
2629
)
2730

2831
assert isinstance(forecasts, pd.Series)
32+
33+
34+
def test_forecast_ar_multi_step_length(test_dataframe, test_integrated_coefficients):
35+
"""Test that the forecast length matches the requested number of steps."""
36+
forecast_steps = 10
37+
forecasts = forecast_ar_multi_step(
38+
test_dataframe, test_integrated_coefficients, forecast_steps
39+
)
40+
2941
assert len(forecasts) == forecast_steps
42+
43+
44+
def test_forecast_ar_multi_step_no_nans(test_dataframe, test_integrated_coefficients):
45+
"""Test that the forecast does not contain NaN values."""
46+
forecast_steps = 10
47+
forecasts = forecast_ar_multi_step(
48+
test_dataframe, test_integrated_coefficients, forecast_steps
49+
)
50+
3051
assert not forecasts.isna().any()
3152

53+
54+
def test_forecast_ar_multi_step_index(test_dataframe, test_integrated_coefficients):
55+
"""Test that the forecast index matches the expected date range."""
56+
forecast_steps = 10
57+
forecasts = forecast_ar_multi_step(
58+
test_dataframe, test_integrated_coefficients, forecast_steps
59+
)
60+
3261
expected_index_start = test_dataframe.index[-344]
3362
expected_index = pd.date_range(
3463
start=expected_index_start, periods=forecast_steps, freq="D"

tests/data_management/test_clean_apple.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
_convert_to_numeric,
88
_handle_missing_values,
99
_remove_duplicates,
10-
_select_and_rename_column,
10+
_rename_column,
1111
_validate_dataframe,
12+
clean_apple_data,
1213
)
1314

1415

@@ -29,7 +30,8 @@ def raw_data():
2930

3031

3132
def test_select_and_rename_column(raw_data):
32-
result = _select_and_rename_column(raw_data)
33+
"""Test that the 'Close' column is correctly selected and renamed 'close_price'."""
34+
result = _rename_column(raw_data)
3335
assert list(result.columns) == ["close_price"]
3436
pd.testing.assert_series_equal(
3537
result["close_price"], raw_data["Close"], check_names=False
@@ -40,22 +42,30 @@ def test_select_and_rename_column(raw_data):
4042

4143

4244
def test_handle_missing_values(raw_data):
45+
"""Test that missing values are handled correctly."""
4346
result = _handle_missing_values(raw_data)
44-
assert result.isna().sum().sum() == 0
45-
assert result.loc[1, "Close"] == expected_value_missing
47+
assert all(
48+
[
49+
result.isna().sum().sum() == 0,
50+
result.loc[1, "Close"] == expected_value_missing,
51+
]
52+
)
4653

4754

4855
expected_length_index = 4
4956

5057

5158
def test_remove_duplicates(raw_data):
59+
"""Test that duplicate indices are removed, ensuring a unique datetime index."""
5260
df = raw_data.copy().set_index("Date")
5361
result = _remove_duplicates(df)
54-
assert result.index.duplicated().sum() == 0
55-
assert len(result) == expected_length_index
62+
assert all(
63+
[result.index.duplicated().sum() == 0, len(result) == expected_length_index]
64+
)
5665

5766

5867
def test_convert_to_numeric(raw_data):
68+
"""Test that all DataFrame columns are converted to numeric data types."""
5969
df = raw_data.copy().astype(str)
6070
result = _convert_to_numeric(df)
6171
for col in result.columns:
@@ -64,6 +74,7 @@ def test_convert_to_numeric(raw_data):
6474

6575

6676
def test_validate_dataframe(raw_data):
77+
"""Test that the DataFrame validation correctly checks for the 'Close' column."""
6778
_validate_dataframe(raw_data)
6879
df_missing = raw_data.drop(columns=["Close"])
6980
with pytest.raises(
@@ -73,6 +84,33 @@ def test_validate_dataframe(raw_data):
7384

7485

7586
def test_convert_to_datetime(raw_data):
87+
"""Test that the 'Date' column is correctly converted to a datetime index."""
7688
result = _convert_to_datetime(raw_data.copy())
77-
assert isinstance(result.index, pd.DatetimeIndex)
78-
assert result.index[0] == pd.Timestamp("2022-01-01")
89+
assert all(
90+
[
91+
isinstance(result.index, pd.DatetimeIndex),
92+
result.index[0] == pd.Timestamp("2022-01-01"),
93+
]
94+
)
95+
96+
97+
def test_clean_apple_data_structure(raw_data):
98+
"""Test that the cleaned DataFrame maintains the correct structure."""
99+
result = clean_apple_data(raw_data)
100+
assert isinstance(result, pd.DataFrame)
101+
assert "close_price" in result.columns
102+
103+
104+
def test_clean_apple_data_no_missing_values(raw_data):
105+
"""Test that the cleaned DataFrame contains no missing values."""
106+
result = clean_apple_data(raw_data)
107+
assert result.isna().sum().sum() == 0
108+
109+
110+
def test_clean_apple_data_validates_input():
111+
"""Test that a ValueError is raised if 'Close' column is missing."""
112+
df_missing = pd.DataFrame({"Date": ["2022-01-01"], "Open": [21]})
113+
with pytest.raises(
114+
ValueError, match="The DataFrame does not contain a 'Close' column."
115+
):
116+
clean_apple_data(df_missing)

0 commit comments

Comments
 (0)