44
55
66def backtest_signals (data , signals , initial_cash , tac , trade_pct , price_col = "Close" ):
7+ """Backtest trading signals to simulate portfolio performance.
8+
9+ Args:
10+ data (pd.DataFrame): DataFrame containing asset price data.
11+ - Must include a column specified by `price_col` (default: 'Close').
12+ - The index should be datetime or sequential for portfolio tracking.
13+ signals (pd.Series): Series of trading signals.
14+ - 2: Buy Signal
15+ - 1: Sell Signal
16+ - 0: Do Nothing
17+ initial_cash (int, float): Initial cash available for trading.
18+ tac (int, float): Transaction cost as a percentage (e.g., 0.05 for 5%).
19+ trade_pct (float): Percentage of 'initial_cash' to trade per signal.
20+ price_col (str): Column name for the asset's price. Default is 'Close'.
21+
22+ Returns:
23+ pd.DataFrame: Portfolio performance over time with columns:
24+ - 'price': The price of the stock.
25+ - 'signal': Trading signal used (2: Buy, 1: Sell, 0: Do Nothing).
26+ - 'shares': Number of shares.
27+ - 'holdings': Total value of shares (price * shares)
28+ - 'cash': Cash.
29+ - 'assets': Portfolio value (cash + holdings).
30+ """
31+ _validate_backtest_signals_input (
32+ data , signals , initial_cash , tac , trade_pct , price_col
33+ )
34+
735 prices = data [price_col ].squeeze ()
836 cash , holdings , shares = initial_cash , 0.0 , 0
937 assets = cash + holdings
@@ -24,6 +52,19 @@ def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Clo
2452
2553
2654def _execute_trade (signal , cash , price , shares , trade_vol , tac ):
55+ """Execute a trade based on the trading signal.
56+
57+ Args:
58+ signal (int): Current trading signal.
59+ cash (int, float): Current cash.
60+ price (float): Current price of the stock.
61+ shares (int): Current shares.
62+ trade_vol (float): Volume of portfolio to trade.
63+ tac (int, float): Transaction cost.
64+
65+ Returns:
66+ tuple: Updated cash and shares after the trade.
67+ """
2768 buy_signal = 2
2869 sell_signal = 1
2970
@@ -37,6 +78,20 @@ def _execute_trade(signal, cash, price, shares, trade_vol, tac):
3778
3879
3980def _execute_buy (cash , price , shares , trade_vol , tac ):
81+ """Execute a buy trade.
82+
83+ Args:
84+ cash (int, float): Current cash.
85+ price (float): Current price of the asset.
86+ shares (int): Current shares.
87+ trade_vol (float): Volume of portfolio to trade.
88+ tac (float): Transaction cost.
89+
90+ Returns:
91+ tuple: Updated cash and shares after the buy trade.
92+ - cash (float): Remaining cash after the trade.
93+ - shares (int): Updated number of shares held.
94+ """
4095 buy_shares = math .floor (trade_vol / (price * (1 + tac )))
4196 cost = buy_shares * price * (1 + tac )
4297
@@ -49,12 +104,36 @@ def _execute_buy(cash, price, shares, trade_vol, tac):
49104
50105
51106def _is_buy_trade_affordable (buy_shares , cost , cash ):
107+ """Check if the buy trade is affordable.
108+
109+ Args:
110+ buy_shares (int): Number of shares to buy.
111+ cost (int, float): Total cost of the shares.
112+ cash (int, float): Current cash.
113+
114+ Returns:
115+ bool: True if the trade is affordable, False otherwise.
116+ """
52117 is_trade_vol_enough = buy_shares >= 1
53118 is_cash_enough = cash >= cost
54119 return is_trade_vol_enough and is_cash_enough
55120
56121
57122def _execute_sell (cash , price , shares , trade_vol , tac ):
123+ """Execute a sell trade.
124+
125+ Args:
126+ cash (float): Current cash.
127+ price (float): Current price.
128+ shares (int): Current shares.
129+ trade_vol (float): Volume of portfolio to trade.
130+ tac (float): Transaction cost.
131+
132+ Returns:
133+ tuple: Updated cash and shares after the sell trade.
134+ - cash (float): Updated cash.
135+ - shares (int): Updated shares.
136+ """
58137 sell_shares = math .floor (trade_vol / (price * (1 - tac )))
59138
60139 if not _is_sell_trade_affordable (shares ):
@@ -69,10 +148,133 @@ def _execute_sell(cash, price, shares, trade_vol, tac):
69148
70149
71150def _is_sell_trade_affordable (shares ):
151+ """Checks if there are enough shares to sell."""
72152 return shares >= 1
73153
74154
75155def _update_portfolio (cash , shares , price ):
156+ """Updates holdings and assets after trade."""
76157 holdings = shares * price
77158 assets = cash + holdings
78159 return assets , holdings
160+
161+
162+ def _validate_backtest_signals_input (
163+ data , signals , initial_cash , tac , trade_pct , price_col
164+ ):
165+ """Validates input for backtesting signals."""
166+ _validate_data (data , price_col )
167+ _validate_signals (signals , data , price_col )
168+ _validate_initial_cash (initial_cash )
169+ _validate_tac (tac )
170+ _validate_trade_pct (trade_pct )
171+
172+
173+ def _validate_data (data , price_col ):
174+ """Validate the input data for backtesting.
175+
176+ Args:
177+ data (pd.DataFrame): DataFrame containing stock data.
178+ price_col (str): Column name for the stock price.
179+
180+ Raises:
181+ TypeError: If data is not a pandas DataFrame.
182+ ValueError: If the price column is missing or contains non-numeric values.
183+ """
184+ if not isinstance (data , pd .DataFrame ):
185+ error_msg = f"data must be a pandas DataFrame, got { type (data ).__name__ } ."
186+ raise TypeError (error_msg )
187+
188+ if price_col not in data .columns :
189+ error_msg = f"data must contain a '{ price_col } ' column."
190+ raise ValueError (error_msg )
191+
192+ if not pd .api .types .is_numeric_dtype (data [price_col ].squeeze ()):
193+ error_msg = f"The '{ price_col } ' column must contain numeric values."
194+ raise ValueError (error_msg )
195+
196+
197+ def _validate_signals (signals , data , price_col ):
198+ """Validate the trading signals.
199+
200+ Args:
201+ signals (pd.Series): Trading signals for backtesting.
202+ data (pd.DataFrame): DataFrame containing stock data.
203+ price_col (str): Column name for the stock price.
204+
205+ Raises:
206+ TypeError: If signals is not a pandas Series.
207+ ValueError: If signals contain invalid values.
208+ ValueError: If signals do not match the length of the price column.
209+ """
210+ if not isinstance (signals , pd .Series ):
211+ error_msg = f"signals must be a pandas Series, got { type (signals ).__name__ } ."
212+ raise TypeError (error_msg )
213+
214+ if not all (isinstance (signal , int ) for signal in signals ):
215+ error_msg = "signals must contain only integers."
216+ raise ValueError (error_msg )
217+
218+ if not all (signal in [0 , 1 , 2 ] for signal in signals ):
219+ error_msg = "signals must contain only 0 (Hold), 1 (Sell), or 2 (Buy)."
220+ raise ValueError (error_msg )
221+
222+ if len (signals ) != len (data [price_col ]):
223+ error_msg = f"signals must have the same length as the '{ price_col } ' column."
224+ raise ValueError (error_msg )
225+
226+
227+ def _validate_initial_cash (initial_cash ):
228+ """Validate the initial cash value for backtesting.
229+
230+ Args:
231+ initial_cash (int or float): The initial cash available for trading.
232+
233+ Raises:
234+ TypeError: If initial_cash is not an integer or float.
235+ ValueError: If initial_cash is not positive.
236+ """
237+ if not isinstance (initial_cash , int | float ):
238+ error_msg = f"initial_cash must be a number, got { type (initial_cash ).__name__ } ."
239+ raise TypeError (error_msg )
240+ if initial_cash <= 0 :
241+ error_msg = "initial_cash must be a positive number."
242+ raise ValueError (error_msg )
243+
244+
245+ def _validate_tac (tac ):
246+ """Validate the transaction cost (tac) for backtesting.
247+
248+ Args:
249+ tac (int or float): Transaction cost as a percentage.
250+
251+ Raises:
252+ TypeError: If tac is not an integer or float.
253+ ValueError: If tac is negative or greater than 1.
254+ """
255+ if not isinstance (tac , int | float ):
256+ error_msg = f"tac must be a number, got { type (tac ).__name__ } ."
257+ raise TypeError (error_msg )
258+
259+ if not (0 <= tac <= 1 ):
260+ error_msg = "tac must be between 0 and 1."
261+ raise ValueError (error_msg )
262+
263+
264+ def _validate_trade_pct (trade_pct ):
265+ """Validate the trade percentage (trade_pct) for backtesting.
266+
267+ Args:
268+ trade_pct (float): Trade percentage of total assets per trade.
269+
270+ Raises:
271+ TypeError: If trade_pct is not a float.
272+ ValueError: If trade_pct is not between 0 and 1.
273+ """
274+ if not isinstance (trade_pct , float ):
275+ error_msg = f"trade_pct must be a float, got { type (trade_pct ).__name__ } ."
276+ raise TypeError (error_msg )
277+
278+ if not (0 < trade_pct <= 1 ):
279+ error_msg = "trade_pct must be between 0 and 1. Zero is not possible."
280+ raise ValueError (error_msg )
0 commit comments