Skip to content

Commit 2173a7a

Browse files
Update version of library yfinance. Add functions to validate input of backtest_signals.py. Add doc strings for generate_signals.py.
1 parent 27aa09e commit 2173a7a

File tree

7 files changed

+238
-33
lines changed

7 files changed

+238
-33
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies:
2626
- plotly >=5.2.0,<6
2727

2828
# Data dependencies
29-
- yfinance
29+
- yfinance >=0.2.54
3030

3131
# Install project
3232
- pip: [-e ., pdbp, kaleido]

src/backtest_bay/analysis/backtest_signals.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,34 @@
44

55

66
def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Close"):
7+
"""Backtest trading signals to simulate portfolio performance.
8+
9+
Args:
10+
data (pd.DataFrame): DataFrame containing asset price data.
11+
- Must include a column specified by `price_col` (default: 'Close').
12+
- The index should be datetime or sequential for portfolio tracking.
13+
signals (pd.Series): Series of trading signals.
14+
- 2: Buy Signal
15+
- 1: Sell Signal
16+
- 0: Do Nothing
17+
initial_cash (int, float): Initial cash available for trading.
18+
tac (int, float): Transaction cost as a percentage (e.g., 0.05 for 5%).
19+
trade_pct (float): Percentage of 'initial_cash' to trade per signal.
20+
price_col (str): Column name for the asset's price. Default is 'Close'.
21+
22+
Returns:
23+
pd.DataFrame: Portfolio performance over time with columns:
24+
- 'price': The price of the stock.
25+
- 'signal': Trading signal used (2: Buy, 1: Sell, 0: Do Nothing).
26+
- 'shares': Number of shares.
27+
- 'holdings': Total value of shares (price * shares)
28+
- 'cash': Cash.
29+
- 'assets': Portfolio value (cash + holdings).
30+
"""
31+
_validate_backtest_signals_input(
32+
data, signals, initial_cash, tac, trade_pct, price_col
33+
)
34+
735
prices = data[price_col].squeeze()
836
cash, holdings, shares = initial_cash, 0.0, 0
937
assets = cash + holdings
@@ -24,6 +52,19 @@ def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Clo
2452

2553

2654
def _execute_trade(signal, cash, price, shares, trade_vol, tac):
55+
"""Execute a trade based on the trading signal.
56+
57+
Args:
58+
signal (int): Current trading signal.
59+
cash (int, float): Current cash.
60+
price (float): Current price of the stock.
61+
shares (int): Current shares.
62+
trade_vol (float): Volume of portfolio to trade.
63+
tac (int, float): Transaction cost.
64+
65+
Returns:
66+
tuple: Updated cash and shares after the trade.
67+
"""
2768
buy_signal = 2
2869
sell_signal = 1
2970

@@ -37,6 +78,20 @@ def _execute_trade(signal, cash, price, shares, trade_vol, tac):
3778

3879

3980
def _execute_buy(cash, price, shares, trade_vol, tac):
81+
"""Execute a buy trade.
82+
83+
Args:
84+
cash (int, float): Current cash.
85+
price (float): Current price of the asset.
86+
shares (int): Current shares.
87+
trade_vol (float): Volume of portfolio to trade.
88+
tac (float): Transaction cost.
89+
90+
Returns:
91+
tuple: Updated cash and shares after the buy trade.
92+
- cash (float): Remaining cash after the trade.
93+
- shares (int): Updated number of shares held.
94+
"""
4095
buy_shares = math.floor(trade_vol / (price * (1 + tac)))
4196
cost = buy_shares * price * (1 + tac)
4297

@@ -49,12 +104,36 @@ def _execute_buy(cash, price, shares, trade_vol, tac):
49104

50105

51106
def _is_buy_trade_affordable(buy_shares, cost, cash):
107+
"""Check if the buy trade is affordable.
108+
109+
Args:
110+
buy_shares (int): Number of shares to buy.
111+
cost (int, float): Total cost of the shares.
112+
cash (int, float): Current cash.
113+
114+
Returns:
115+
bool: True if the trade is affordable, False otherwise.
116+
"""
52117
is_trade_vol_enough = buy_shares >= 1
53118
is_cash_enough = cash >= cost
54119
return is_trade_vol_enough and is_cash_enough
55120

56121

57122
def _execute_sell(cash, price, shares, trade_vol, tac):
123+
"""Execute a sell trade.
124+
125+
Args:
126+
cash (float): Current cash.
127+
price (float): Current price.
128+
shares (int): Current shares.
129+
trade_vol (float): Volume of portfolio to trade.
130+
tac (float): Transaction cost.
131+
132+
Returns:
133+
tuple: Updated cash and shares after the sell trade.
134+
- cash (float): Updated cash.
135+
- shares (int): Updated shares.
136+
"""
58137
sell_shares = math.floor(trade_vol / (price * (1 - tac)))
59138

60139
if not _is_sell_trade_affordable(shares):
@@ -69,10 +148,133 @@ def _execute_sell(cash, price, shares, trade_vol, tac):
69148

70149

71150
def _is_sell_trade_affordable(shares):
151+
"""Checks if there are enough shares to sell."""
72152
return shares >= 1
73153

74154

75155
def _update_portfolio(cash, shares, price):
156+
"""Updates holdings and assets after trade."""
76157
holdings = shares * price
77158
assets = cash + holdings
78159
return assets, holdings
160+
161+
162+
def _validate_backtest_signals_input(
163+
data, signals, initial_cash, tac, trade_pct, price_col
164+
):
165+
"""Validates input for backtesting signals."""
166+
_validate_data(data, price_col)
167+
_validate_signals(signals, data, price_col)
168+
_validate_initial_cash(initial_cash)
169+
_validate_tac(tac)
170+
_validate_trade_pct(trade_pct)
171+
172+
173+
def _validate_data(data, price_col):
174+
"""Validate the input data for backtesting.
175+
176+
Args:
177+
data (pd.DataFrame): DataFrame containing stock data.
178+
price_col (str): Column name for the stock price.
179+
180+
Raises:
181+
TypeError: If data is not a pandas DataFrame.
182+
ValueError: If the price column is missing or contains non-numeric values.
183+
"""
184+
if not isinstance(data, pd.DataFrame):
185+
error_msg = f"data must be a pandas DataFrame, got {type(data).__name__}."
186+
raise TypeError(error_msg)
187+
188+
if price_col not in data.columns:
189+
error_msg = f"data must contain a '{price_col}' column."
190+
raise ValueError(error_msg)
191+
192+
if not pd.api.types.is_numeric_dtype(data[price_col].squeeze()):
193+
error_msg = f"The '{price_col}' column must contain numeric values."
194+
raise ValueError(error_msg)
195+
196+
197+
def _validate_signals(signals, data, price_col):
198+
"""Validate the trading signals.
199+
200+
Args:
201+
signals (pd.Series): Trading signals for backtesting.
202+
data (pd.DataFrame): DataFrame containing stock data.
203+
price_col (str): Column name for the stock price.
204+
205+
Raises:
206+
TypeError: If signals is not a pandas Series.
207+
ValueError: If signals contain invalid values.
208+
ValueError: If signals do not match the length of the price column.
209+
"""
210+
if not isinstance(signals, pd.Series):
211+
error_msg = f"signals must be a pandas Series, got {type(signals).__name__}."
212+
raise TypeError(error_msg)
213+
214+
if not all(isinstance(signal, int) for signal in signals):
215+
error_msg = "signals must contain only integers."
216+
raise ValueError(error_msg)
217+
218+
if not all(signal in [0, 1, 2] for signal in signals):
219+
error_msg = "signals must contain only 0 (Hold), 1 (Sell), or 2 (Buy)."
220+
raise ValueError(error_msg)
221+
222+
if len(signals) != len(data[price_col]):
223+
error_msg = f"signals must have the same length as the '{price_col}' column."
224+
raise ValueError(error_msg)
225+
226+
227+
def _validate_initial_cash(initial_cash):
228+
"""Validate the initial cash value for backtesting.
229+
230+
Args:
231+
initial_cash (int or float): The initial cash available for trading.
232+
233+
Raises:
234+
TypeError: If initial_cash is not an integer or float.
235+
ValueError: If initial_cash is not positive.
236+
"""
237+
if not isinstance(initial_cash, int | float):
238+
error_msg = f"initial_cash must be a number, got {type(initial_cash).__name__}."
239+
raise TypeError(error_msg)
240+
if initial_cash <= 0:
241+
error_msg = "initial_cash must be a positive number."
242+
raise ValueError(error_msg)
243+
244+
245+
def _validate_tac(tac):
246+
"""Validate the transaction cost (tac) for backtesting.
247+
248+
Args:
249+
tac (int or float): Transaction cost as a percentage.
250+
251+
Raises:
252+
TypeError: If tac is not an integer or float.
253+
ValueError: If tac is negative or greater than 1.
254+
"""
255+
if not isinstance(tac, int | float):
256+
error_msg = f"tac must be a number, got {type(tac).__name__}."
257+
raise TypeError(error_msg)
258+
259+
if not (0 <= tac <= 1):
260+
error_msg = "tac must be between 0 and 1."
261+
raise ValueError(error_msg)
262+
263+
264+
def _validate_trade_pct(trade_pct):
265+
"""Validate the trade percentage (trade_pct) for backtesting.
266+
267+
Args:
268+
trade_pct (float): Trade percentage of total assets per trade.
269+
270+
Raises:
271+
TypeError: If trade_pct is not a float.
272+
ValueError: If trade_pct is not between 0 and 1.
273+
"""
274+
if not isinstance(trade_pct, float):
275+
error_msg = f"trade_pct must be a float, got {type(trade_pct).__name__}."
276+
raise TypeError(error_msg)
277+
278+
if not (0 < trade_pct <= 1):
279+
error_msg = "trade_pct must be between 0 and 1. Zero is not possible."
280+
raise ValueError(error_msg)

src/backtest_bay/analysis/generate_signals.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,22 @@ def generate_signals(data, method, **kwargs):
88
data (pd.DataFrame): DataFrame containing asset price data
99
(must include 'Close' column).
1010
method (str): The signal generation method. Currently supported:
11-
- 'bollinger_bands': Uses Bollinger Bands for signal generation.
11+
- 'bollinger'
12+
- 'macd'
13+
- 'roc'
14+
- 'rsi'
1215
**kwargs: Additional parameters specific to the chosen method.
16+
- 'bollinger':
17+
- window (int): Window size for moving average.
18+
- num_std_dev (float, int): Standard deviation multiplier for bands.
19+
- 'macd':
20+
- short_window (int): Window size for the fast EMA.
21+
- long_window (int): Window size for the slow EMA.
22+
- signal_window (int): Window size for the signal line EMA.
23+
- 'roc':
24+
- window (int): Window size for calculating the rate of change.
25+
- 'rsi':
26+
- window (int): Window size for calculating RSI.
1327
1428
Returns:
1529
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).

src/backtest_bay/analysis/task_backtest_signals.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,19 @@
1515
id_backtest = (
1616
f"{row.stock}_{row.start_date}_{row.end_date}_" f"{row.interval}_{row.strategy}"
1717
)
18+
1819
data_path = BLD / f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}.pkl"
19-
data = pd.read_pickle(data_path)
2020
produces = BLD / f"{id_backtest}.pkl"
2121

22-
signals = generate_signals(data, row.strategy)
23-
2422
@pytask.task(id=id_backtest)
25-
def task_backtest_signals(
26-
data=data,
27-
signals=signals,
28-
initial_cash=INITIAL_CASH,
29-
tac=TAC,
30-
investment_pct=TRADE_PCT,
31-
scripts=scripts,
32-
data_path=data_path,
33-
produces=produces,
34-
):
35-
portfolio = backtest_signals(data, signals, initial_cash, tac, investment_pct)
23+
def task_backtest(scripts=scripts, data_path=data_path, produces=produces, row=row):
24+
data = pd.read_pickle(data_path)
25+
signals = generate_signals(data=data, method=row.strategy)
26+
portfolio = backtest_signals(
27+
data=data,
28+
signals=signals,
29+
initial_cash=INITIAL_CASH,
30+
tac=TAC,
31+
trade_pct=TRADE_PCT,
32+
)
3633
portfolio.to_pickle(produces)

src/backtest_bay/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
# Configure input data
1414
STOCKS = ["AAPL", "MSFT"]
15-
START_DATES = ["2022-01-01"]
15+
START_DATES = ["2019-01-01"]
1616
END_DATES = ["2025-01-01"]
1717
INTERVALS = ["1d"]
1818
STRATEGIES = ["bollinger", "macd", "roc", "rsi"]

src/backtest_bay/data/task_download_data.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,12 @@
1414

1515
produces = BLD / f"{id_download}.pkl"
1616

17-
if produces.exists():
18-
continue
19-
2017
@pytask.task(id=id_download)
21-
def task_download_data(
22-
symbol=row.stock,
23-
start_date=row.start_date,
24-
end_date=row.end_date,
25-
interval=row.interval,
26-
depends_on=scripts,
27-
produces=produces,
28-
):
18+
def task_download_data(depends_on=scripts, produces=produces, row=row):
2919
data = download_data(
30-
symbol=symbol, start_date=start_date, end_date=end_date, interval=interval
20+
symbol=row.stock,
21+
start_date=row.start_date,
22+
end_date=row.end_date,
23+
interval=row.interval,
3124
)
3225
data.to_pickle(produces)

tests/analysis/test_backtest_signals.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pandas as pd
32
import pandas.testing as pdt
43

@@ -99,7 +98,7 @@ def test_backtest_portfolio_correct_calculation():
9998
data = pd.DataFrame({"Close": [10, 5, 10, 8, 10]}, index=index)
10099

101100
# Only hold
102-
signals = np.array([0, 0, 0, 0, 0])
101+
signals = pd.Series([0, 0, 0, 0, 0])
103102
portfolio = backtest_signals(
104103
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
105104
)
@@ -117,7 +116,7 @@ def test_backtest_portfolio_correct_calculation():
117116
pdt.assert_frame_equal(portfolio, expected_portfolio)
118117

119118
# Buy once
120-
signals = np.array([0, 2, 0, 0, 0])
119+
signals = pd.Series([0, 2, 0, 0, 0])
121120
portfolio = backtest_signals(
122121
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
123122
)
@@ -135,7 +134,7 @@ def test_backtest_portfolio_correct_calculation():
135134
pdt.assert_frame_equal(portfolio, expected_portfolio)
136135

137136
# Buy and sell once
138-
signals = np.array([0, 2, 0, 0, 1])
137+
signals = pd.Series([0, 2, 0, 0, 1])
139138
portfolio = backtest_signals(
140139
data, signals, initial_cash=100, tac=0, trade_pct=1.0, price_col="Close"
141140
)

0 commit comments

Comments
 (0)