Skip to content

Commit f4a622c

Browse files
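Add more detailed validation of the downloaded data in download_data.py. Because the data is now validated at download time, the separate data validation in backtest_signals.py is no longer required and has been removed; only the price column check remains there. The cumbersome column MultiIndex is also removed early on, directly after the download, instead of later in task_backtest.py.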
1 parent 33796c1 commit f4a622c

File tree

6 files changed

+305
-122
lines changed

6 files changed

+305
-122
lines changed

src/backtest_bay/backtest/backtest_signals.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Clo
2828
- 'cash': Cash.
2929
- 'assets': Portfolio value (cash + holdings).
3030
"""
31+
# Note that the inputs 'data' and 'signals' are already validated in
32+
# 'download_data.py' and 'generate_signals.py'
3133
_validate_backtest_signals_input(data, initial_cash, tac, trade_pct, price_col)
3234

3335
prices = data[price_col].squeeze()
@@ -159,36 +161,10 @@ def _update_portfolio(cash, shares, price):
159161

160162
def _validate_backtest_signals_input(data, initial_cash, tac, trade_pct, price_col):
161163
"""Validates input for backtesting signals."""
162-
# Since the variable 'signals' is generated using 'generate_signals', we already
163-
# validated the input in the corresponding test script of 'generate_signals'.
164-
_validate_data(data, price_col)
165164
_validate_initial_cash(initial_cash)
166165
_validate_tac(tac)
167166
_validate_trade_pct(trade_pct)
168-
169-
170-
def _validate_data(data, price_col):
171-
"""Validate the input data for backtesting.
172-
173-
Args:
174-
data (pd.DataFrame): DataFrame containing stock data.
175-
price_col (str): Column name for the stock price.
176-
177-
Raises:
178-
TypeError: If data is not a pandas DataFrame.
179-
ValueError: If the price column is missing or contains non-numeric values.
180-
"""
181-
if not isinstance(data, pd.DataFrame):
182-
error_msg = f"data must be a pandas DataFrame, got {type(data).__name__}."
183-
raise TypeError(error_msg)
184-
185-
if price_col not in data.columns:
186-
error_msg = f"data must contain a '{price_col}' column."
187-
raise ValueError(error_msg)
188-
189-
if not pd.api.types.is_numeric_dtype(data[price_col].squeeze()):
190-
error_msg = f"The '{price_col}' column must contain numeric values."
191-
raise ValueError(error_msg)
167+
_validate_price_col(data, price_col)
192168

193169

194170
def _validate_initial_cash(initial_cash):
@@ -245,3 +221,18 @@ def _validate_trade_pct(trade_pct):
245221
if not (0 < trade_pct <= 1):
246222
error_msg = "trade_pct must be between 0 and 1. Zero is not possible."
247223
raise ValueError(error_msg)
224+
225+
226+
def _validate_price_col(data, price_col):
227+
"""Validate the input price_col for backtesting.
228+
229+
Args:
230+
data (pd.DataFrame): DataFrame containing stock data.
231+
price_col (str): Column name for the stock price.
232+
233+
Raises:
234+
ValueError: If the price column is missing.
235+
"""
236+
if price_col not in data.columns:
237+
error_msg = f"data must contain a '{price_col}' column."
238+
raise ValueError(error_msg)

src/backtest_bay/backtest/generate_signals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def generate_signals(data, method, **kwargs):
2929
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
3030
"""
3131
_validate_input_method(method)
32-
closing_prices = data["Close"].squeeze()
32+
closing_prices = data["Close"]
3333

3434
if method == "bollinger":
3535
signal = _bollinger_signals(prices=closing_prices, **kwargs)

src/backtest_bay/backtest/task_backtest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,4 @@ def _merge_stock_data_with_portfolio(data, portfolio):
5555
Returns:
5656
pd.DataFrame: Merged DataFrame.
5757
"""
58-
data.columns = data.columns.droplevel(1)
5958
return data.merge(portfolio, how="left", left_index=True, right_index=True)

src/backtest_bay/data/download_data.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
from datetime import datetime
22

3+
import pandas as pd
34
import yfinance as yf
45

56

67
def download_data(symbol, interval, start_date, end_date):
8+
"""Download historical stock data and validate it.
9+
10+
This function downloads historical stock data for a given symbol, date range,
11+
and interval using the yfinance library. The input parameters and the downloaded
12+
data are validated to ensure they meet the required criteria.
13+
14+
Args:
15+
symbol (str): Stock symbol to download data for (e.g., 'AAPL' for Apple).
16+
interval (str): Data interval (e.g., '1d' for daily, '1h' for hourly).
17+
start_date (str): Start date for the data in 'YYYY-MM-DD' format.
18+
end_date (str): End date for the data in 'YYYY-MM-DD' format.
19+
20+
Returns:
21+
pd.DataFrame: DataFrame containing the downloaded stock data.
22+
"""
723
_validate_input(symbol, interval, start_date, end_date)
824
data = yf.download(symbol, start=start_date, end=end_date, interval=interval)
9-
_validate_output(data, symbol, start_date, end_date, interval)
25+
_validate_data(data, symbol, start_date, end_date, interval)
26+
data.columns = _remove_multiindex_from_cols(data.columns)
1027
return data
1128

1229

1330
def _validate_input(symbol, interval, start_date, end_date):
31+
"""Validate symbol, interval, and date inputs."""
1432
_validate_symbol(symbol)
1533
_validate_interval(interval)
1634
_validate_date_format(start_date)
@@ -19,13 +37,15 @@ def _validate_input(symbol, interval, start_date, end_date):
1937

2038

2139
def _validate_symbol(symbol):
40+
"""Check if symbol is a non-empty string."""
2241
is_symbol_string = isinstance(symbol, str)
2342
if not is_symbol_string:
2443
error_msg = "Symbol must be a non-empty string."
2544
raise TypeError(error_msg)
2645

2746

2847
def _validate_interval(interval):
48+
"""Validate if interval is within the allowed set."""
2949
valid_intervals = {
3050
"1m",
3151
"2m",
@@ -47,6 +67,7 @@ def _validate_interval(interval):
4767

4868

4969
def _validate_date_format(date_str):
70+
"""Check if date string is in 'YYYY-MM-DD' format."""
5071
if not isinstance(date_str, str):
5172
error_msg = "Date must be a string in 'YYYY-MM-DD' format."
5273
raise TypeError(error_msg)
@@ -58,15 +79,111 @@ def _validate_date_format(date_str):
5879

5980

6081
def _validate_date_range(start_date, end_date):
82+
"""Ensure start date is before end date."""
6183
if start_date > end_date:
6284
error_msg = "Start date must be before end date."
6385
raise ValueError(error_msg)
6486

6587

66-
def _validate_output(data, symbol, start_date, end_date, interval):
88+
def _validate_data(data, symbol, start_date, end_date, interval):
89+
"""Validate the downloaded data.
90+
91+
Args:
92+
data (pd.DataFrame): DataFrame containing stock data.
93+
symbol (str): Stock symbol for the data.
94+
start_date (str): Start date for the data.
95+
end_date (str): End date for the data.
96+
interval (str): Interval for the data.
97+
98+
Raises:
99+
TypeError: If data is not a pandas DataFrame or index is not a DatetimeIndex.
100+
ValueError: If required columns are missing or contain non-numeric values.
101+
ValueError: If the DataFrame is empty.
102+
"""
103+
_validate_data_type_dataframe(data)
104+
_validate_data_empty(data, symbol, start_date, end_date, interval)
105+
_validate_data_index_datetime(data.index)
106+
_validate_data_multiindex(data.columns)
107+
_validate_data_numeric(data)
108+
109+
110+
def _validate_data_type_dataframe(data):
111+
"""Check if the input is a pandas DataFrame."""
112+
if not isinstance(data, pd.DataFrame):
113+
error_msg = f"data must be a pandas DataFrame, got {type(data).__name__}."
114+
raise TypeError(error_msg)
115+
116+
117+
def _validate_data_empty(data, symbol, start_date, end_date, interval):
118+
"""Check if the DataFrame is empty."""
67119
if data.empty:
68120
error_msg = (
69121
f"No data found for {symbol} between {start_date} and {end_date} "
70122
f"with interval '{interval}'."
71123
)
72124
raise ValueError(error_msg)
125+
126+
127+
def _validate_data_index_datetime(index):
128+
"""Check if the index is of type DatetimeIndex."""
129+
if not isinstance(index, pd.DatetimeIndex):
130+
error_msg = (
131+
f"data index must be a pandas DatetimeIndex, got {type(index).__name__}."
132+
)
133+
raise TypeError(error_msg)
134+
135+
136+
def _validate_data_multiindex(columns):
137+
"""Check if the columns have the required MultiIndex.
138+
139+
Args:
140+
columns (pd.MultiIndex): MultiIndex of DataFrame columns.
141+
142+
Raises:
143+
ValueError: If the MultiIndex is not present, does not have exactly two levels,
144+
or if required columns are missing from level 0.
145+
"""
146+
required_cols = {"Close", "Open", "High", "Low"}
147+
148+
if not isinstance(columns, pd.MultiIndex):
149+
error_msg = "DataFrame must have a MultiIndex for columns."
150+
raise TypeError(error_msg)
151+
152+
yfinance_index_levels = 2
153+
if columns.nlevels != yfinance_index_levels:
154+
error_msg = (
155+
f"MultiIndex must have exactly 2 levels, got {columns.nlevels} levels."
156+
)
157+
raise ValueError(error_msg)
158+
159+
level_0_values = set(columns.get_level_values(0))
160+
missing_cols = required_cols - level_0_values
161+
if missing_cols:
162+
error_msg = (
163+
"Level 0 of MultiIndex must contain the following columns: "
164+
f"{', '.join(missing_cols)}."
165+
)
166+
raise ValueError(error_msg)
167+
168+
169+
def _validate_data_numeric(data):
170+
"""Check if the required columns contain numeric values."""
171+
required_cols = ["Close", "Open", "High", "Low"]
172+
non_numeric_cols = [
173+
col
174+
for col in required_cols
175+
if not pd.api.types.is_numeric_dtype(data[col].squeeze())
176+
]
177+
178+
if non_numeric_cols:
179+
error_msg = (
180+
"The following columns must contain numeric values: "
181+
f"{', '.join(non_numeric_cols)}."
182+
)
183+
raise ValueError(error_msg)
184+
185+
186+
def _remove_multiindex_from_cols(cols):
187+
"""Remove MultiIndex from columns but retain level 0 as column names."""
188+
cols = cols.get_level_values(0)
189+
return cols

tests/backtest/test_backtest_signals.py

Lines changed: 15 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
_is_buy_trade_affordable,
99
_is_sell_trade_affordable,
1010
_update_portfolio,
11-
_validate_data,
1211
_validate_initial_cash,
12+
_validate_price_col,
1313
_validate_tac,
1414
_validate_trade_pct,
1515
backtest_signals,
@@ -151,67 +151,6 @@ def test_update_portfolio_correct_calculation():
151151
assert holdings == expected_holdings
152152

153153

154-
# Tests for _validate_data
155-
@pytest.mark.parametrize(
156-
("data", "price_col"),
157-
[
158-
(pd.DataFrame({"Close": [100, 101, 102]}), "Close"),
159-
(pd.DataFrame({"Open": [99, 100, 101], "Close": [100, 101, 102]}), "Close"),
160-
(pd.DataFrame({"Close": [100.5, 101.2, 102.1, 103.8]}), "Close"),
161-
],
162-
)
163-
def test_validate_data_valid_input(data, price_col):
164-
"""Test valid data input for _validate_data."""
165-
_validate_data(data, price_col)
166-
167-
168-
@pytest.mark.parametrize(
169-
("data", "price_col", "expected_error"),
170-
[
171-
([100, 101], "Close", "data must be a pandas DataFrame, got list."),
172-
({"Close": [100, 101]}, "Close", "data must be a pandas DataFrame, got dict."),
173-
],
174-
)
175-
def test_validate_data_invalid_type(data, price_col, expected_error):
176-
"""Test invalid data types for _validate_data."""
177-
with pytest.raises(TypeError, match=expected_error):
178-
_validate_data(data, price_col)
179-
180-
181-
@pytest.mark.parametrize(
182-
("data", "price_col", "expected_error"),
183-
[
184-
(pd.DataFrame({"Open": [100]}), "Close", "data must contain a 'Close' column."),
185-
(pd.DataFrame({"Close": [100]}), "Open", "data must contain a 'Open' column."),
186-
],
187-
)
188-
def test_validate_data_missing_price_column(data, price_col, expected_error):
189-
"""Test missing price column for _validate_data."""
190-
with pytest.raises(ValueError, match=expected_error):
191-
_validate_data(data, price_col)
192-
193-
194-
@pytest.mark.parametrize(
195-
("data", "price_col", "expected_error"),
196-
[
197-
(
198-
pd.DataFrame({"Close": ["100", 100]}),
199-
"Close",
200-
"The 'Close' column must contain numeric values.",
201-
),
202-
(
203-
pd.DataFrame({"Close": ["a", "b", "c", "d"]}),
204-
"Close",
205-
"The 'Close' column must contain numeric values.",
206-
),
207-
],
208-
)
209-
def test_validate_data_non_numeric_price_column(data, price_col, expected_error):
210-
"""Test non-numeric price column for _validate_data."""
211-
with pytest.raises(ValueError, match=expected_error):
212-
_validate_data(data, price_col)
213-
214-
215154
# Tests for _validate_initial_cash
216155
@pytest.mark.parametrize("initial_cash", [1000, 1000.50, 0.01])
217156
def test_validate_initial_cash_valid_input(initial_cash):
@@ -272,3 +211,17 @@ def test_validate_trade_pct_invalid_input(trade_pct, expected_error):
272211
"""Test invalid trade_pct values."""
273212
with pytest.raises((TypeError, ValueError), match=expected_error):
274213
_validate_trade_pct(trade_pct)
214+
215+
216+
# Tests for _validate_price_col
217+
@pytest.mark.parametrize(
218+
("data", "price_col"),
219+
[
220+
(pd.DataFrame({"Close": [100, 101, 102]}), "Close"),
221+
(pd.DataFrame({"Open": [99, 100, 101], "Close": [100, 101, 102]}), "Close"),
222+
(pd.DataFrame({"Close": [100.5, 101.2, 102.1, 103.8]}), "Close"),
223+
],
224+
)
225+
def test_validate_data_valid_input(data, price_col):
226+
"""Test valid data input for _validate_data."""
227+
_validate_price_col(data, price_col)

0 commit comments

Comments
 (0)