Skip to content

Commit 0a7f0a1

Browse files
Add tests for backtest_signals.py.
1 parent 785a99c commit 0a7f0a1

File tree

2 files changed

+206
-37
lines changed

2 files changed

+206
-37
lines changed

src/backtest_bay/analysis/backtest_signals.py

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,60 +3,76 @@
33
import pandas as pd
44

55

6-
def execute_buy(cash, price, shares, trade_vol, tac):
6+
def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Close"):
7+
prices = data[price_col].squeeze()
8+
cash, holdings, shares = initial_cash, 0.0, 0
9+
assets = cash + holdings
10+
portfolio = []
11+
12+
for price, signal in zip(prices, signals, strict=False):
13+
trade_vol = trade_pct * assets
14+
cash, shares = _execute_trade(signal, cash, price, shares, trade_vol, tac)
15+
assets, holdings = _update_portfolio(cash, shares, price)
16+
portfolio.append((price, signal, shares, holdings, cash, assets))
17+
18+
portfolio = pd.DataFrame(
19+
portfolio,
20+
columns=["price", "signal", "shares", "holdings", "cash", "assets"],
21+
index=data.index,
22+
)
23+
return portfolio
24+
25+
26+
def _execute_trade(signal, cash, price, shares, trade_vol, tac):
27+
buy_signal = 2
28+
sell_signal = 1
29+
30+
if signal == buy_signal:
31+
cash, shares = _execute_buy(cash, price, shares, trade_vol, tac)
32+
33+
elif signal == sell_signal:
34+
cash, shares = _execute_sell(cash, price, shares, trade_vol, tac)
35+
36+
return cash, shares
37+
38+
39+
def _execute_buy(cash, price, shares, trade_vol, tac):
740
buy_shares = math.floor(trade_vol / (price * (1 + tac)))
8-
if buy_shares < 1:
9-
return cash, shares
1041
cost = buy_shares * price * (1 + tac)
11-
if cost > cash:
42+
43+
if not _is_buy_trade_affordable(buy_shares, cost, cash):
1244
return cash, shares
45+
1346
cash -= cost
1447
shares += buy_shares
1548
return cash, shares
1649

1750

18-
def execute_sell(cash, price, shares, trade_vol, tac):
51+
def _is_buy_trade_affordable(buy_shares, cost, cash):
52+
is_trade_vol_enough = buy_shares >= 1
53+
is_cash_enough = cash >= cost
54+
return is_trade_vol_enough and is_cash_enough
55+
56+
57+
def _execute_sell(cash, price, shares, trade_vol, tac):
1958
sell_shares = math.floor(trade_vol / (price * (1 - tac)))
2059

21-
if sell_shares > shares:
22-
if shares < 1:
23-
return cash, shares
24-
sell_shares = shares
60+
if not _is_sell_trade_affordable(shares):
61+
return cash, shares
62+
63+
sell_shares = min(sell_shares, shares)
2564

2665
profit = sell_shares * price * (1 - tac)
2766
cash += profit
2867
shares -= sell_shares
2968
return cash, shares
3069

3170

32-
def update_portfolio(cash, shares, price):
33-
holdings = shares * price
34-
assets = cash + holdings
35-
return assets, holdings
71+
def _is_sell_trade_affordable(shares):
72+
return shares >= 1
3673

3774

38-
def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Close"):
39-
prices = data[price_col].squeeze()
40-
cash, holdings, shares = initial_cash, 0.0, 0
75+
def _update_portfolio(cash, shares, price):
76+
holdings = shares * price
4177
assets = cash + holdings
42-
portfolio = []
43-
44-
buy_signal = 2
45-
sell_signal = 1
46-
47-
for price, signal in zip(prices, signals, strict=False):
48-
trade_vol = trade_pct * assets
49-
50-
if signal == buy_signal:
51-
cash, shares = execute_buy(cash, price, shares, trade_vol, tac)
52-
53-
elif signal == sell_signal:
54-
cash, shares = execute_sell(cash, price, shares, trade_vol, tac)
55-
56-
assets, holdings = update_portfolio(cash, shares, price)
57-
portfolio.append((cash, holdings, shares, assets, signal, price))
58-
59-
portfolio = pd.DataFrame(
60-
portfolio, columns=["cash", "holdings", "shares", "total", "signal", "price"]
61-
)
62-
return portfolio
78+
return assets, holdings
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pandas.testing as pdt
4+
5+
from backtest_bay.analysis.backtest_signals import (
6+
_execute_buy,
7+
_execute_sell,
8+
_is_buy_trade_affordable,
9+
_is_sell_trade_affordable,
10+
_update_portfolio,
11+
backtest_signals,
12+
)
13+
14+
15+
# tests for _is_buy_affordable
16+
def test_is_buy_trade_affordable_enough_cash():
17+
"""Test buying when there is enough cash."""
18+
affordable = _is_buy_trade_affordable(buy_shares=2, cost=10, cash=10)
19+
assert affordable
20+
21+
22+
def test_is_buy_trade_affordable_not_enough_cash():
23+
"""Test buying when there is not enough cash."""
24+
affordable = _is_buy_trade_affordable(buy_shares=2, cost=10, cash=9)
25+
assert not affordable
26+
27+
28+
def test_is_buy_trade_affordable_not_enough_shares():
29+
"""Test buying when buy shares are less than 1."""
30+
affordable = _is_buy_trade_affordable(buy_shares=0.1, cost=10, cash=9)
31+
assert not affordable
32+
33+
34+
# tests for _execute_buy
35+
def test_execute_buy_enough_cash():
36+
"""Test buying when there is enough cash to purchase at least one share."""
37+
cash, shares = _execute_buy(cash=30, price=2, shares=2, trade_vol=30, tac=0.5)
38+
expected_cash, expected_shares = 0, 12
39+
assert cash == expected_cash
40+
assert shares == expected_shares
41+
42+
43+
def test_execute_buy_to_less_cash():
44+
"""Test buying when there isn't enough cash to purchase at least one share."""
45+
cash, shares = _execute_buy(cash=20, price=2, shares=2, trade_vol=30, tac=0.5)
46+
expected_cash, expected_shares = 20, 2
47+
assert cash == expected_cash
48+
assert shares == expected_shares
49+
50+
51+
# tests for _is_sell_trade_affordable
52+
def test_is_sell_trade_affordable_enough_shares():
53+
"""Test selling when there are enough shares."""
54+
affordable = _is_sell_trade_affordable(shares=1)
55+
assert affordable
56+
57+
58+
def test_is_sell_trade_affordable_not_enough_shares():
59+
"""Test selling when there are not enough shares."""
60+
affordable = _is_sell_trade_affordable(shares=0)
61+
assert not affordable
62+
63+
64+
# tests for _execute_sell
65+
def test_execute_sell_enough_shares():
66+
"""Test selling when there are enough shares."""
67+
cash, shares = _execute_sell(cash=40, price=2, shares=50, trade_vol=30, tac=0.5)
68+
expected_cash, expected_shares = 70, 20
69+
assert cash == expected_cash
70+
assert shares == expected_shares
71+
72+
cash, shares = _execute_sell(cash=40, price=2, shares=8, trade_vol=30, tac=0.5)
73+
expected_cash, expected_shares = 48, 0
74+
assert cash == expected_cash
75+
assert shares == expected_shares
76+
77+
78+
def test_execute_sell_not_enough_shares():
79+
"""Test selling when there are not enough shares."""
80+
cash, shares = _execute_sell(cash=40, price=2, shares=0, trade_vol=30, tac=0.5)
81+
expected_cash, expected_shares = 40, 0
82+
assert cash == expected_cash
83+
assert shares == expected_shares
84+
85+
86+
# tests for _update_portfolio
87+
def test_update_portfolio_correct_calculation():
88+
"""Test _update_portfolio for correct calculation."""
89+
assets, holdings = _update_portfolio(cash=100, shares=5, price=20)
90+
expected_assets, expected_holdings = 200, 100
91+
assert assets == expected_assets
92+
assert holdings == expected_holdings
93+
94+
95+
# tests for backtest_signals
96+
def test_backtest_portfolio_correct_calculation():
97+
"""Test backtest_portfolio for correct calculation."""
98+
index = pd.date_range("2023-01-01", periods=5, freq="D")
99+
data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index)
100+
101+
# Only hold
102+
signals = np.array([0, 0, 0, 0, 0])
103+
portfolio = backtest_signals(
104+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
105+
)
106+
expected_portfolio = pd.DataFrame(
107+
data={
108+
"price": [10, 5, 10, 8, 10],
109+
"signal": [0, 0, 0, 0, 0],
110+
"shares": [0, 0, 0, 0, 0],
111+
"holdings": [0, 0, 0, 0, 0],
112+
"cash": [100, 100, 100, 100, 100],
113+
"assets": [100, 100, 100, 100, 100],
114+
},
115+
index=index,
116+
)
117+
pdt.assert_frame_equal(portfolio, expected_portfolio)
118+
119+
# Buy once
120+
signals = np.array([0, 2, 0, 0, 0])
121+
portfolio = backtest_signals(
122+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
123+
)
124+
expected_portfolio = pd.DataFrame(
125+
data={
126+
"price": [10, 5, 10, 8, 10],
127+
"signal": [0, 2, 0, 0, 0],
128+
"shares": [0, 20, 20, 20, 20],
129+
"holdings": [0, 100, 200, 160, 200],
130+
"cash": [100, 0, 0, 0, 0],
131+
"assets": [100, 100, 200, 160, 200],
132+
},
133+
index=index,
134+
)
135+
pdt.assert_frame_equal(portfolio, expected_portfolio)
136+
137+
# Buy and sell once
138+
signals = np.array([0, 2, 0, 0, 1])
139+
portfolio = backtest_signals(
140+
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
141+
)
142+
expected_portfolio = pd.DataFrame(
143+
data={
144+
"price": [10, 5, 10, 8, 10],
145+
"signal": [0, 2, 0, 0, 1],
146+
"shares": [0, 20, 20, 20, 4],
147+
"holdings": [0, 100, 200, 160, 40],
148+
"cash": [100, 0, 0, 0, 160],
149+
"assets": [100, 100, 200, 160, 200],
150+
},
151+
index=index,
152+
)
153+
pdt.assert_frame_equal(portfolio, expected_portfolio)

0 commit comments

Comments
 (0)