Skip to content

Commit deaf7f0

Browse files
Add test cases for backtest_signals.py.
1 parent 4180dae commit deaf7f0

File tree

2 files changed

+188
-93
lines changed

2 files changed

+188
-93
lines changed

src/backtest_bay/backtest/backtest_signals.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Clo
2828
- 'cash': Cash.
2929
- 'assets': Portfolio value (cash + holdings).
3030
"""
31-
_validate_backtest_signals_input(
32-
data, signals, initial_cash, tac, trade_pct, price_col
33-
)
31+
_validate_backtest_signals_input(data, initial_cash, tac, trade_pct, price_col)
3432

3533
prices = data[price_col].squeeze()
3634
cash, holdings, shares = initial_cash, 0.0, 0
@@ -159,12 +157,11 @@ def _update_portfolio(cash, shares, price):
159157
return assets, holdings
160158

161159

162-
def _validate_backtest_signals_input(
163-
data, signals, initial_cash, tac, trade_pct, price_col
164-
):
160+
def _validate_backtest_signals_input(data, initial_cash, tac, trade_pct, price_col):
165161
"""Validates input for backtesting signals."""
162+
# Since the variable 'signals' is generated using 'generate_signals', we already
163+
# validated the input in the corresponding test script of 'generate_signals'.
166164
_validate_data(data, price_col)
167-
_validate_signals(signals, data, price_col)
168165
_validate_initial_cash(initial_cash)
169166
_validate_tac(tac)
170167
_validate_trade_pct(trade_pct)
@@ -194,36 +191,6 @@ def _validate_data(data, price_col):
194191
raise ValueError(error_msg)
195192

196193

197-
def _validate_signals(signals, data, price_col):
198-
"""Validate the trading signals.
199-
200-
Args:
201-
signals (pd.Series): Trading signals for backtesting.
202-
data (pd.DataFrame): DataFrame containing stock data.
203-
price_col (str): Column name for the stock price.
204-
205-
Raises:
206-
TypeError: If signals is not a pandas Series.
207-
ValueError: If signals contain invalid values.
208-
ValueError: If signals do not match the length of the price column.
209-
"""
210-
if not isinstance(signals, pd.Series):
211-
error_msg = f"signals must be a pandas Series, got {type(signals).__name__}."
212-
raise TypeError(error_msg)
213-
214-
if not all(isinstance(signal, int) for signal in signals):
215-
error_msg = "signals must contain only integers."
216-
raise ValueError(error_msg)
217-
218-
if not all(signal in [0, 1, 2] for signal in signals):
219-
error_msg = "signals must contain only 0 (Hold), 1 (Sell), or 2 (Buy)."
220-
raise ValueError(error_msg)
221-
222-
if len(signals) != len(data[price_col]):
223-
error_msg = f"signals must have the same length as the '{price_col}' column."
224-
raise ValueError(error_msg)
225-
226-
227194
def _validate_initial_cash(initial_cash):
228195
"""Validate the initial cash value for backtesting.
229196

tests/backtest/test_backtest_signals.py

Lines changed: 184 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,82 @@
11
import pandas as pd
22
import pandas.testing as pdt
3+
import pytest
34

45
from backtest_bay.backtest.backtest_signals import (
56
_execute_buy,
67
_execute_sell,
78
_is_buy_trade_affordable,
89
_is_sell_trade_affordable,
910
_update_portfolio,
11+
_validate_data,
12+
_validate_initial_cash,
13+
_validate_tac,
14+
_validate_trade_pct,
1015
backtest_signals,
1116
)
1217

1318

19+
# tests for backtest_signals
20+
def test_backtest_portfolio_correct_calculation():
21+
"""Test backtest_portfolio for correct calculation."""
22+
index = pd.date_range("2023-01-01", periods=5, freq="D")
23+
data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index)
24+
25+
# Only hold
26+
signals = pd.Series([0, 0, 0, 0, 0])
27+
portfolio = backtest_signals(
28+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
29+
)
30+
expected_portfolio = pd.DataFrame(
31+
data={
32+
"price": [10, 5, 10, 8, 10],
33+
"signal": [0, 0, 0, 0, 0],
34+
"shares": [0, 0, 0, 0, 0],
35+
"holdings": [0, 0, 0, 0, 0],
36+
"cash": [100, 100, 100, 100, 100],
37+
"assets": [100, 100, 100, 100, 100],
38+
},
39+
index=index,
40+
)
41+
pdt.assert_frame_equal(portfolio, expected_portfolio)
42+
43+
# Buy once
44+
signals = pd.Series([0, 2, 0, 0, 0])
45+
portfolio = backtest_signals(
46+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
47+
)
48+
expected_portfolio = pd.DataFrame(
49+
data={
50+
"price": [10, 5, 10, 8, 10],
51+
"signal": [0, 2, 0, 0, 0],
52+
"shares": [0, 20, 20, 20, 20],
53+
"holdings": [0, 100, 200, 160, 200],
54+
"cash": [100, 0, 0, 0, 0],
55+
"assets": [100, 100, 200, 160, 200],
56+
},
57+
index=index,
58+
)
59+
pdt.assert_frame_equal(portfolio, expected_portfolio)
60+
61+
# Buy and sell once
62+
signals = pd.Series([0, 2, 0, 0, 1])
63+
portfolio = backtest_signals(
64+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
65+
)
66+
expected_portfolio = pd.DataFrame(
67+
data={
68+
"price": [10, 5, 10, 8, 10],
69+
"signal": [0, 2, 0, 0, 1],
70+
"shares": [0, 20, 20, 20, 4],
71+
"holdings": [0, 100, 200, 160, 40],
72+
"cash": [100, 0, 0, 0, 160],
73+
"assets": [100, 100, 200, 160, 200],
74+
},
75+
index=index,
76+
)
77+
pdt.assert_frame_equal(portfolio, expected_portfolio)
78+
79+
1480
# tests for _is_buy_affordable
1581
def test_is_buy_trade_affordable_enough_cash():
1682
"""Test buying when there is enough cash."""
@@ -91,62 +157,124 @@ def test_update_portfolio_correct_calculation():
91157
assert holdings == expected_holdings
92158

93159

94-
# tests for backtest_signals
95-
def test_backtest_portfolio_correct_calculation():
96-
"""Test backtest_portfolio for correct calculation."""
97-
index = pd.date_range("2023-01-01", periods=5, freq="D")
98-
data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index)
160+
# Tests for _validate_data
161+
@pytest.mark.parametrize(
162+
("data", "price_col"),
163+
[
164+
(pd.DataFrame({"Close": [100, 101, 102]}), "Close"),
165+
(pd.DataFrame({"Open": [99, 100, 101], "Close": [100, 101, 102]}), "Close"),
166+
(pd.DataFrame({"Close": [100.5, 101.2, 102.1, 103.8]}), "Close"),
167+
],
168+
)
169+
def test_validate_data_valid_input(data, price_col):
170+
"""Test valid data input for _validate_data."""
171+
_validate_data(data, price_col)
99172

100-
# Only hold
101-
signals = pd.Series([0, 0, 0, 0, 0])
102-
portfolio = backtest_signals(
103-
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
104-
)
105-
expected_portfolio = pd.DataFrame(
106-
data={
107-
"price": [10, 5, 10, 8, 10],
108-
"signal": [0, 0, 0, 0, 0],
109-
"shares": [0, 0, 0, 0, 0],
110-
"holdings": [0, 0, 0, 0, 0],
111-
"cash": [100, 100, 100, 100, 100],
112-
"assets": [100, 100, 100, 100, 100],
113-
},
114-
index=index,
115-
)
116-
pdt.assert_frame_equal(portfolio, expected_portfolio)
117173

118-
# Buy once
119-
signals = pd.Series([0, 2, 0, 0, 0])
120-
portfolio = backtest_signals(
121-
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
122-
)
123-
expected_portfolio = pd.DataFrame(
124-
data={
125-
"price": [10, 5, 10, 8, 10],
126-
"signal": [0, 2, 0, 0, 0],
127-
"shares": [0, 20, 20, 20, 20],
128-
"holdings": [0, 100, 200, 160, 200],
129-
"cash": [100, 0, 0, 0, 0],
130-
"assets": [100, 100, 200, 160, 200],
131-
},
132-
index=index,
133-
)
134-
pdt.assert_frame_equal(portfolio, expected_portfolio)
174+
@pytest.mark.parametrize(
175+
("data", "price_col", "expected_error"),
176+
[
177+
([100, 101], "Close", "data must be a pandas DataFrame, got list."),
178+
({"Close": [100, 101]}, "Close", "data must be a pandas DataFrame, got dict."),
179+
],
180+
)
181+
def test_validate_data_invalid_type(data, price_col, expected_error):
182+
"""Test invalid data types for _validate_data."""
183+
with pytest.raises(TypeError, match=expected_error):
184+
_validate_data(data, price_col)
135185

136-
# Buy and sell once
137-
signals = pd.Series([0, 2, 0, 0, 1])
138-
portfolio = backtest_signals(
139-
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
140-
)
141-
expected_portfolio = pd.DataFrame(
142-
data={
143-
"price": [10, 5, 10, 8, 10],
144-
"signal": [0, 2, 0, 0, 1],
145-
"shares": [0, 20, 20, 20, 4],
146-
"holdings": [0, 100, 200, 160, 40],
147-
"cash": [100, 0, 0, 0, 160],
148-
"assets": [100, 100, 200, 160, 200],
149-
},
150-
index=index,
151-
)
152-
pdt.assert_frame_equal(portfolio, expected_portfolio)
186+
187+
@pytest.mark.parametrize(
188+
("data", "price_col", "expected_error"),
189+
[
190+
(pd.DataFrame({"Open": [100]}), "Close", "data must contain a 'Close' column."),
191+
(pd.DataFrame({"Close": [100]}), "Open", "data must contain a 'Open' column."),
192+
],
193+
)
194+
def test_validate_data_missing_price_column(data, price_col, expected_error):
195+
"""Test missing price column for _validate_data."""
196+
with pytest.raises(ValueError, match=expected_error):
197+
_validate_data(data, price_col)
198+
199+
200+
@pytest.mark.parametrize(
201+
("data", "price_col", "expected_error"),
202+
[
203+
(
204+
pd.DataFrame({"Close": ["100", 100]}),
205+
"Close",
206+
"The 'Close' column must contain numeric values.",
207+
),
208+
(
209+
pd.DataFrame({"Close": ["a", "b", "c", "d"]}),
210+
"Close",
211+
"The 'Close' column must contain numeric values.",
212+
),
213+
],
214+
)
215+
def test_validate_data_non_numeric_price_column(data, price_col, expected_error):
216+
"""Test non-numeric price column for _validate_data."""
217+
with pytest.raises(ValueError, match=expected_error):
218+
_validate_data(data, price_col)
219+
220+
221+
# Tests for _validate_initial_cash
222+
@pytest.mark.parametrize("initial_cash", [1000, 1000.50, 0.01])
223+
def test_validate_initial_cash_valid_input(initial_cash):
224+
"""Test valid initial_cash values."""
225+
_validate_initial_cash(initial_cash)
226+
227+
228+
@pytest.mark.parametrize(
229+
("initial_cash", "expected_error"),
230+
[
231+
(0, "initial_cash must be a positive number."), # Zero value
232+
("1000", "initial_cash must be a number, got str."), # String
233+
],
234+
)
235+
def test_validate_initial_cash_invalid_input(initial_cash, expected_error):
236+
"""Test invalid initial_cash values."""
237+
with pytest.raises((TypeError, ValueError), match=expected_error):
238+
_validate_initial_cash(initial_cash)
239+
240+
241+
# Tests for _validate_tac
242+
@pytest.mark.parametrize("tac", [0, 0.05, 1])
243+
def test_validate_tac_valid_input(tac):
244+
"""Test valid tac values."""
245+
_validate_tac(tac)
246+
247+
248+
@pytest.mark.parametrize(
249+
("tac", "expected_error"),
250+
[
251+
(-0.01, "tac must be between 0 and 1."),
252+
(1.2, "tac must be between 0 and 1."),
253+
("0.05", "tac must be a number, got str."),
254+
],
255+
)
256+
def test_validate_tac_invalid_input(tac, expected_error):
257+
"""Test invalid tac values."""
258+
with pytest.raises((TypeError, ValueError), match=expected_error):
259+
_validate_tac(tac)
260+
261+
262+
# Tests for _validate_trade_pct
263+
@pytest.mark.parametrize("trade_pct", [0.01, 0.5, 1.0])
264+
def test_validate_trade_pct_valid_input(trade_pct):
265+
"""Test valid trade_pct values."""
266+
_validate_trade_pct(trade_pct)
267+
268+
269+
@pytest.mark.parametrize(
270+
("trade_pct", "expected_error"),
271+
[
272+
(0.0, "trade_pct must be between 0 and 1. Zero is not possible."),
273+
(-1.0, "trade_pct must be between 0 and 1. Zero is not possible."),
274+
(1, "trade_pct must be a float, got int."),
275+
],
276+
)
277+
def test_validate_trade_pct_invalid_input(trade_pct, expected_error):
278+
"""Test invalid trade_pct values."""
279+
with pytest.raises((TypeError, ValueError), match=expected_error):
280+
_validate_trade_pct(trade_pct)

0 commit comments

Comments
 (0)