|
1 | 1 | import pandas as pd |
2 | 2 | import pandas.testing as pdt |
| 3 | +import pytest |
3 | 4 |
|
4 | 5 | from backtest_bay.backtest.backtest_signals import ( |
5 | 6 | _execute_buy, |
6 | 7 | _execute_sell, |
7 | 8 | _is_buy_trade_affordable, |
8 | 9 | _is_sell_trade_affordable, |
9 | 10 | _update_portfolio, |
| 11 | + _validate_data, |
| 12 | + _validate_initial_cash, |
| 13 | + _validate_tac, |
| 14 | + _validate_trade_pct, |
10 | 15 | backtest_signals, |
11 | 16 | ) |
12 | 17 |
|
13 | 18 |
|
| 19 | +# tests for backtest_signals |
| 20 | +def test_backtest_portfolio_correct_calculation(): |
| 21 | + """Test backtest_portfolio for correct calculation.""" |
| 22 | + index = pd.date_range("2023-01-01", periods=5, freq="D") |
| 23 | + data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index) |
| 24 | + |
| 25 | + # Only hold |
| 26 | + signals = pd.Series([0, 0, 0, 0, 0]) |
| 27 | + portfolio = backtest_signals( |
| 28 | + data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
| 29 | + ) |
| 30 | + expected_portfolio = pd.DataFrame( |
| 31 | + data={ |
| 32 | + "price": [10, 5, 10, 8, 10], |
| 33 | + "signal": [0, 0, 0, 0, 0], |
| 34 | + "shares": [0, 0, 0, 0, 0], |
| 35 | + "holdings": [0, 0, 0, 0, 0], |
| 36 | + "cash": [100, 100, 100, 100, 100], |
| 37 | + "assets": [100, 100, 100, 100, 100], |
| 38 | + }, |
| 39 | + index=index, |
| 40 | + ) |
| 41 | + pdt.assert_frame_equal(portfolio, expected_portfolio) |
| 42 | + |
| 43 | + # Buy once |
| 44 | + signals = pd.Series([0, 2, 0, 0, 0]) |
| 45 | + portfolio = backtest_signals( |
| 46 | + data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
| 47 | + ) |
| 48 | + expected_portfolio = pd.DataFrame( |
| 49 | + data={ |
| 50 | + "price": [10, 5, 10, 8, 10], |
| 51 | + "signal": [0, 2, 0, 0, 0], |
| 52 | + "shares": [0, 20, 20, 20, 20], |
| 53 | + "holdings": [0, 100, 200, 160, 200], |
| 54 | + "cash": [100, 0, 0, 0, 0], |
| 55 | + "assets": [100, 100, 200, 160, 200], |
| 56 | + }, |
| 57 | + index=index, |
| 58 | + ) |
| 59 | + pdt.assert_frame_equal(portfolio, expected_portfolio) |
| 60 | + |
| 61 | + # Buy and sell once |
| 62 | + signals = pd.Series([0, 2, 0, 0, 1]) |
| 63 | + portfolio = backtest_signals( |
| 64 | + data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
| 65 | + ) |
| 66 | + expected_portfolio = pd.DataFrame( |
| 67 | + data={ |
| 68 | + "price": [10, 5, 10, 8, 10], |
| 69 | + "signal": [0, 2, 0, 0, 1], |
| 70 | + "shares": [0, 20, 20, 20, 4], |
| 71 | + "holdings": [0, 100, 200, 160, 40], |
| 72 | + "cash": [100, 0, 0, 0, 160], |
| 73 | + "assets": [100, 100, 200, 160, 200], |
| 74 | + }, |
| 75 | + index=index, |
| 76 | + ) |
| 77 | + pdt.assert_frame_equal(portfolio, expected_portfolio) |
| 78 | + |
| 79 | + |
14 | 80 | # tests for _is_buy_affordable |
15 | 81 | def test_is_buy_trade_affordable_enough_cash(): |
16 | 82 | """Test buying when there is enough cash.""" |
@@ -91,62 +157,124 @@ def test_update_portfolio_correct_calculation(): |
91 | 157 | assert holdings == expected_holdings |
92 | 158 |
|
93 | 159 |
|
94 | | -# tests for backtest_signals |
95 | | -def test_backtest_portfolio_correct_calculation(): |
96 | | - """Test backtest_portfolio for correct calculation.""" |
97 | | - index = pd.date_range("2023-01-01", periods=5, freq="D") |
98 | | - data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index) |
| 160 | +# Tests for _validate_data |
| 161 | +@pytest.mark.parametrize( |
| 162 | + ("data", "price_col"), |
| 163 | + [ |
| 164 | + (pd.DataFrame({"Close": [100, 101, 102]}), "Close"), |
| 165 | + (pd.DataFrame({"Open": [99, 100, 101], "Close": [100, 101, 102]}), "Close"), |
| 166 | + (pd.DataFrame({"Close": [100.5, 101.2, 102.1, 103.8]}), "Close"), |
| 167 | + ], |
| 168 | +) |
| 169 | +def test_validate_data_valid_input(data, price_col): |
| 170 | + """Test valid data input for _validate_data.""" |
| 171 | + _validate_data(data, price_col) |
99 | 172 |
|
100 | | - # Only hold |
101 | | - signals = pd.Series([0, 0, 0, 0, 0]) |
102 | | - portfolio = backtest_signals( |
103 | | - data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
104 | | - ) |
105 | | - expected_portfolio = pd.DataFrame( |
106 | | - data={ |
107 | | - "price": [10, 5, 10, 8, 10], |
108 | | - "signal": [0, 0, 0, 0, 0], |
109 | | - "shares": [0, 0, 0, 0, 0], |
110 | | - "holdings": [0, 0, 0, 0, 0], |
111 | | - "cash": [100, 100, 100, 100, 100], |
112 | | - "assets": [100, 100, 100, 100, 100], |
113 | | - }, |
114 | | - index=index, |
115 | | - ) |
116 | | - pdt.assert_frame_equal(portfolio, expected_portfolio) |
117 | 173 |
|
118 | | - # Buy once |
119 | | - signals = pd.Series([0, 2, 0, 0, 0]) |
120 | | - portfolio = backtest_signals( |
121 | | - data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
122 | | - ) |
123 | | - expected_portfolio = pd.DataFrame( |
124 | | - data={ |
125 | | - "price": [10, 5, 10, 8, 10], |
126 | | - "signal": [0, 2, 0, 0, 0], |
127 | | - "shares": [0, 20, 20, 20, 20], |
128 | | - "holdings": [0, 100, 200, 160, 200], |
129 | | - "cash": [100, 0, 0, 0, 0], |
130 | | - "assets": [100, 100, 200, 160, 200], |
131 | | - }, |
132 | | - index=index, |
133 | | - ) |
134 | | - pdt.assert_frame_equal(portfolio, expected_portfolio) |
| 174 | +@pytest.mark.parametrize( |
| 175 | + ("data", "price_col", "expected_error"), |
| 176 | + [ |
| 177 | + ([100, 101], "Close", "data must be a pandas DataFrame, got list."), |
| 178 | + ({"Close": [100, 101]}, "Close", "data must be a pandas DataFrame, got dict."), |
| 179 | + ], |
| 180 | +) |
| 181 | +def test_validate_data_invalid_type(data, price_col, expected_error): |
| 182 | + """Test invalid data types for _validate_data.""" |
| 183 | + with pytest.raises(TypeError, match=expected_error): |
| 184 | + _validate_data(data, price_col) |
135 | 185 |
|
136 | | - # Buy and sell once |
137 | | - signals = pd.Series([0, 2, 0, 0, 1]) |
138 | | - portfolio = backtest_signals( |
139 | | - data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close" |
140 | | - ) |
141 | | - expected_portfolio = pd.DataFrame( |
142 | | - data={ |
143 | | - "price": [10, 5, 10, 8, 10], |
144 | | - "signal": [0, 2, 0, 0, 1], |
145 | | - "shares": [0, 20, 20, 20, 4], |
146 | | - "holdings": [0, 100, 200, 160, 40], |
147 | | - "cash": [100, 0, 0, 0, 160], |
148 | | - "assets": [100, 100, 200, 160, 200], |
149 | | - }, |
150 | | - index=index, |
151 | | - ) |
152 | | - pdt.assert_frame_equal(portfolio, expected_portfolio) |
| 186 | + |
| 187 | +@pytest.mark.parametrize( |
| 188 | + ("data", "price_col", "expected_error"), |
| 189 | + [ |
| 190 | + (pd.DataFrame({"Open": [100]}), "Close", "data must contain a 'Close' column."), |
| 191 | + (pd.DataFrame({"Close": [100]}), "Open", "data must contain a 'Open' column."), |
| 192 | + ], |
| 193 | +) |
| 194 | +def test_validate_data_missing_price_column(data, price_col, expected_error): |
| 195 | + """Test missing price column for _validate_data.""" |
| 196 | + with pytest.raises(ValueError, match=expected_error): |
| 197 | + _validate_data(data, price_col) |
| 198 | + |
| 199 | + |
| 200 | +@pytest.mark.parametrize( |
| 201 | + ("data", "price_col", "expected_error"), |
| 202 | + [ |
| 203 | + ( |
| 204 | + pd.DataFrame({"Close": ["100", 100]}), |
| 205 | + "Close", |
| 206 | + "The 'Close' column must contain numeric values.", |
| 207 | + ), |
| 208 | + ( |
| 209 | + pd.DataFrame({"Close": ["a", "b", "c", "d"]}), |
| 210 | + "Close", |
| 211 | + "The 'Close' column must contain numeric values.", |
| 212 | + ), |
| 213 | + ], |
| 214 | +) |
| 215 | +def test_validate_data_non_numeric_price_column(data, price_col, expected_error): |
| 216 | + """Test non-numeric price column for _validate_data.""" |
| 217 | + with pytest.raises(ValueError, match=expected_error): |
| 218 | + _validate_data(data, price_col) |
| 219 | + |
| 220 | + |
| 221 | +# Tests for _validate_initial_cash |
| 222 | +@pytest.mark.parametrize("initial_cash", [1000, 1000.50, 0.01]) |
| 223 | +def test_validate_initial_cash_valid_input(initial_cash): |
| 224 | + """Test valid initial_cash values.""" |
| 225 | + _validate_initial_cash(initial_cash) |
| 226 | + |
| 227 | + |
| 228 | +@pytest.mark.parametrize( |
| 229 | + ("initial_cash", "expected_error"), |
| 230 | + [ |
| 231 | + (0, "initial_cash must be a positive number."), # Zero value |
| 232 | + ("1000", "initial_cash must be a number, got str."), # String |
| 233 | + ], |
| 234 | +) |
| 235 | +def test_validate_initial_cash_invalid_input(initial_cash, expected_error): |
| 236 | + """Test invalid initial_cash values.""" |
| 237 | + with pytest.raises((TypeError, ValueError), match=expected_error): |
| 238 | + _validate_initial_cash(initial_cash) |
| 239 | + |
| 240 | + |
| 241 | +# Tests for _validate_tac |
| 242 | +@pytest.mark.parametrize("tac", [0, 0.05, 1]) |
| 243 | +def test_validate_tac_valid_input(tac): |
| 244 | + """Test valid tac values.""" |
| 245 | + _validate_tac(tac) |
| 246 | + |
| 247 | + |
| 248 | +@pytest.mark.parametrize( |
| 249 | + ("tac", "expected_error"), |
| 250 | + [ |
| 251 | + (-0.01, "tac must be between 0 and 1."), |
| 252 | + (1.2, "tac must be between 0 and 1."), |
| 253 | + ("0.05", "tac must be a number, got str."), |
| 254 | + ], |
| 255 | +) |
| 256 | +def test_validate_tac_invalid_input(tac, expected_error): |
| 257 | + """Test invalid tac values.""" |
| 258 | + with pytest.raises((TypeError, ValueError), match=expected_error): |
| 259 | + _validate_tac(tac) |
| 260 | + |
| 261 | + |
| 262 | +# Tests for _validate_trade_pct |
| 263 | +@pytest.mark.parametrize("trade_pct", [0.01, 0.5, 1.0]) |
| 264 | +def test_validate_trade_pct_valid_input(trade_pct): |
| 265 | + """Test valid trade_pct values.""" |
| 266 | + _validate_trade_pct(trade_pct) |
| 267 | + |
| 268 | + |
| 269 | +@pytest.mark.parametrize( |
| 270 | + ("trade_pct", "expected_error"), |
| 271 | + [ |
| 272 | + (0.0, "trade_pct must be between 0 and 1. Zero is not possible."), |
| 273 | + (-1.0, "trade_pct must be between 0 and 1. Zero is not possible."), |
| 274 | + (1, "trade_pct must be a float, got int."), |
| 275 | + ], |
| 276 | +) |
| 277 | +def test_validate_trade_pct_invalid_input(trade_pct, expected_error): |
| 278 | + """Test invalid trade_pct values.""" |
| 279 | + with pytest.raises((TypeError, ValueError), match=expected_error): |
| 280 | + _validate_trade_pct(trade_pct) |
0 commit comments