1- import numpy as np
21import pandas as pd
32
43
@@ -10,16 +9,22 @@ def generate_signals(data, method, **kwargs):
109 (must include 'Close' column).
1110 method (str): The signal generation method. Currently supported:
1211 - 'bollinger_bands': Uses Bollinger Bands for signal generation.
13- **kwargs: Additional parameters specific to the chosen method
14- (e.g., window size, number of standard deviations).
12+ **kwargs: Additional parameters specific to the chosen method.
1513
1614 Returns:
1715 pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
1816 """
19- if method == "flip" :
20- signal = _flip_signals (prices = data ["Close" ].squeeze ())
17+ _validate_input_method (method )
18+ closing_prices = data ["Close" ].squeeze ()
19+
2120 if method == "bollinger" :
22- signal = _bollinger_signals (prices = data ["Close" ].squeeze (), ** kwargs )
21+ signal = _bollinger_signals (prices = closing_prices , ** kwargs )
22+ if method == "macd" :
23+ signal = _macd_signals (prices = closing_prices , ** kwargs )
24+ if method == "roc" :
25+ signal = _roc_signals (prices = closing_prices , ** kwargs )
26+ if method == "rsi" :
27+ signal = _rsi_signals (prices = closing_prices , ** kwargs )
2328 return signal
2429
2530
@@ -37,36 +42,223 @@ def _bollinger_signals(prices, window=20, num_std_dev=2):
3742 Returns:
3843 pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
3944 """
40- moving_avg = prices .rolling (window = window ).mean ().fillna (0 )
45+ _validate_input_bollinger_signals (window , num_std_dev )
46+
47+ moving_avg = prices .rolling (window = window ).mean ()
4148 std_dev = prices .rolling (window = window ).std ()
4249 upper_band = moving_avg + (num_std_dev * std_dev )
4350 lower_band = moving_avg - (num_std_dev * std_dev )
4451
4552 signals = pd .Series (0 , index = prices .index )
46- signals [prices < lower_band ] = 2
47- signals [prices > upper_band ] = 1
53+ signals . loc [prices < lower_band ] = 2
54+ signals . loc [prices > upper_band ] = 1
4855
49- signals = _shift_signals_to_right ( signals )
50- return pd . Series ( signals , index = prices . index )
56+ signals = signals . shift ( periods = 1 , fill_value = 0 )
57+ return signals
5158
5259
53- def _flip_signals (prices ):
54- """Generate trading signals based on price changes from the previous price.
60+ def _macd_signals (prices , short_window = 12 , long_window = 26 , signal_window = 9 ):
61+ """Generate trading signals based on the MACD indicator.
62+
63+ A buy signal (2) is generated when the MACD line crosses above the Signal Line.
64+ A sell signal (1) is generated when the MACD line crosses below the Signal Line.
5565
5666 Args:
57- prices (np.Series): Series of asset prices without index.
67+ prices (pd.Series): Series of asset prices.
68+ short_window (int): Window size for the short EMA (default: 12).
69+ long_window (int): Window size for the long EMA (default: 26).
70+ signal_window (int): Window size for the signal line EMA (default: 9).
5871
5972 Returns:
60- np.ndarray: Trading signals (2: buy if previous price is lower,
61- 1: sell if previous price is higher,
62- 0: do nothing for the first price).
73+ pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
6374 """
64- price_diff = np .diff (prices , prepend = prices .iloc [0 ])
65- signals = np .zeros (len (prices ), dtype = int )
66- signals [1 :][price_diff [1 :] > 0 ] = 1
67- signals [1 :][price_diff [1 :] < 0 ] = 2
75+ _validate_input_macd_signals (short_window , long_window , signal_window )
76+
77+ short_ema = prices .ewm (span = short_window , adjust = False ).mean ()
78+ long_ema = prices .ewm (span = long_window , adjust = False ).mean ()
79+ macd_line = short_ema - long_ema
80+ signal_line = macd_line .ewm (span = signal_window , adjust = False ).mean ()
81+
82+ signals = pd .Series (0 , index = prices .index )
83+ signals .loc [macd_line > signal_line ] = 2
84+ signals .loc [macd_line < signal_line ] = 1
85+
86+ signals = signals .shift (periods = 1 , fill_value = 0 )
6887 return signals
6988
7089
71- def _shift_signals_to_right (signals , shift = 1 ):
72- return np .concatenate (([0 ] * shift , signals [:- shift ]))
90+ def _roc_signals (prices , window = 10 ):
91+ """Generate trading signals based on the Rate of Change (ROC) indicator.
92+
93+ A buy signal (2) is generated when the ROC is positive,
94+ and a sell signal (1) is generated when the ROC is negative.
95+
96+ Args:
97+ prices (pd.Series): Series of asset prices.
98+ window (int): Window size for computing the ROC.
99+
100+ Returns:
101+ pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
102+ """
103+ _validate_input_window (window )
104+
105+ roc = prices .pct_change (periods = window - 1 )
106+
107+ signals = pd .Series (0 , index = prices .index , dtype = int )
108+ signals .loc [roc > 0 ] = 2
109+ signals .loc [roc < 0 ] = 1
110+
111+ signals = signals .shift (periods = 1 , fill_value = 0 )
112+ return signals
113+
114+
115+ def _rsi_signals (prices , window = 14 ):
116+ """Generate trading signals based on the Relative Strength Index (RSI).
117+
118+ A buy signal (2) is generated when RSI is below 30 (oversold),
119+ and a sell signal (1) is generated when RSI is above 70 (overbought).
120+
121+ Args:
122+ prices (pd.Series): Series of asset prices.
123+ window (int): Window size for computing RSI.
124+
125+ Returns:
126+ pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
127+ """
128+ _validate_input_window (window )
129+
130+ delta = prices .diff ()
131+ gain = delta .where (delta > 0 , 0 )
132+ loss = - delta .where (delta < 0 , 0 )
133+
134+ avg_gain = gain .rolling (window = window , min_periods = 1 ).mean ()
135+ avg_loss = loss .rolling (window = window , min_periods = 1 ).mean ()
136+
137+ rs = avg_gain / avg_loss
138+ rsi = 100 - (100 / (1 + rs ))
139+
140+ signals = pd .Series (0 , index = prices .index , dtype = int )
141+ upper_cutoff = 70
142+ lower_cutoff = 30
143+ signals .loc [rsi < lower_cutoff ] = 2
144+ signals .loc [rsi > upper_cutoff ] = 1
145+
146+ signals = signals .shift (periods = 1 , fill_value = 0 )
147+ return signals
148+
149+
150+ def _validate_input_method (method ):
151+ """Validate the input method for the `generate_signals` function.
152+
153+ Args:
154+ method (str): The signal generation method.
155+
156+ Raises:
157+ TypeError: If the method is not a string.
158+ ValueError: If the method is not one of the supported methods.
159+ """
160+ if not isinstance (method , str ):
161+ error_msg = (
162+ f"Invalid type for method: expected str, got { type (method ).__name__ } ."
163+ )
164+ raise TypeError (error_msg )
165+
166+ supported_methods = ["bollinger" , "macd" , "roc" , "rsi" ]
167+
168+ if method not in supported_methods :
169+ error_msg = (
170+ f"Invalid method '{ method } '. "
171+ f"Supported methods are: { ', ' .join (supported_methods )} ."
172+ )
173+ raise ValueError (error_msg )
174+
175+
176+ def _validate_input_bollinger_signals (window , num_std_dev ):
177+ """Validate inputs for the Bollinger Bands trading signal function.
178+
179+ Args:
180+ window (int): Window size for moving average.
181+ num_std_dev (int, float): Number of standard deviations for the bands.
182+ """
183+ _validate_input_window (window )
184+ _validate_input_num_std_dev (num_std_dev )
185+
186+
187+ def _validate_input_macd_signals (short_window , long_window , signal_window ):
188+ """Validate inputs for the MACD trading signal function.
189+
190+ Args:
191+ short_window (int): Window size for the short-term EMA.
192+ long_window (int): Window size for the long-term EMA.
193+ signal_window (int): Window size for the signal line EMA.
194+ """
195+ _validate_input_window (short_window )
196+ _validate_input_window (long_window )
197+ _validate_input_window (signal_window )
198+ _validate_window_relationships (short_window , long_window , signal_window )
199+
200+
201+ def _validate_input_window (window ):
202+ """Validate the window parameter.
203+
204+ Args:
205+ window (int): Window size for moving average.
206+
207+ Raises:
208+ TypeError: If window is not an integer.
209+ ValueError: If window is not greater than 1.
210+ """
211+ if not isinstance (window , int ):
212+ error_msg = f"'window' must be an integer, got { type (window ).__name__ } ."
213+ raise TypeError (error_msg )
214+
215+ if window <= 1 :
216+ error_msg = f"'window' must be greater than 1, got { window } ."
217+ raise ValueError (error_msg )
218+
219+
220+ def _validate_input_num_std_dev (num_std_dev ):
221+ """Validate the num_std_dev parameter for Bollinger Bands.
222+
223+ Args:
224+ num_std_dev (int, float): Number of standard deviations for the bands.
225+
226+ Raises:
227+ TypeError: If num_std_dev is not a number.
228+ ValueError: If num_std_dev is not positive.
229+ """
230+ if not isinstance (num_std_dev , int | float ):
231+ error_msg = f"'num_std_dev' must be a number, got { type (num_std_dev ).__name__ } ."
232+ raise TypeError (error_msg )
233+
234+ if num_std_dev <= 0 :
235+ error_msg = f"'num_std_dev' must be a positive number, got { num_std_dev } ."
236+ raise ValueError (error_msg )
237+
238+
239+ def _validate_window_relationships (short_window , long_window , signal_window ):
240+ """Validate logical relationships between MACD windows.
241+
242+ Args:
243+ short_window (int): Window size for the short-term EMA.
244+ long_window (int): Window size for the long-term EMA.
245+ signal_window (int): Window size for the signal line EMA.
246+
247+ Raises:
248+ ValueError: If:
249+ - `short_window` is greater than or equal to `long_window`.
250+ - `signal_window` is greater than `short_window`.
251+ """
252+ if short_window >= long_window :
253+ error_msg = (
254+ "'short_window' must be less than 'long_window', "
255+ f"got short_window={ short_window } and long_window={ long_window } ."
256+ )
257+ raise ValueError (error_msg )
258+
259+ if signal_window > short_window :
260+ error_msg = (
261+ "'signal_window' must be less than or equal to 'short_window',"
262+ f"got signal_window={ signal_window } and short_window={ short_window } ."
263+ )
264+ raise ValueError (error_msg )
0 commit comments