Skip to content

Commit 27aa09e

Browse files
Add trading strategies macd, roc and rsi. Implement input validation for hyperparamters of the trading strategies. Add test cases for all implemented functions. Adjust config.py for additional trading strategies.
1 parent ea20faa commit 27aa09e

File tree

3 files changed

+446
-41
lines changed

3 files changed

+446
-41
lines changed
Lines changed: 215 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pandas as pd
32

43

@@ -10,16 +9,22 @@ def generate_signals(data, method, **kwargs):
109
(must include 'Close' column).
1110
method (str): The signal generation method. Currently supported:
1211
- 'bollinger_bands': Uses Bollinger Bands for signal generation.
13-
**kwargs: Additional parameters specific to the chosen method
14-
(e.g., window size, number of standard deviations).
12+
**kwargs: Additional parameters specific to the chosen method.
1513
1614
Returns:
1715
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
1816
"""
19-
if method == "flip":
20-
signal = _flip_signals(prices=data["Close"].squeeze())
17+
_validate_input_method(method)
18+
closing_prices = data["Close"].squeeze()
19+
2120
if method == "bollinger":
22-
signal = _bollinger_signals(prices=data["Close"].squeeze(), **kwargs)
21+
signal = _bollinger_signals(prices=closing_prices, **kwargs)
22+
if method == "macd":
23+
signal = _macd_signals(prices=closing_prices, **kwargs)
24+
if method == "roc":
25+
signal = _roc_signals(prices=closing_prices, **kwargs)
26+
if method == "rsi":
27+
signal = _rsi_signals(prices=closing_prices, **kwargs)
2328
return signal
2429

2530

@@ -37,36 +42,223 @@ def _bollinger_signals(prices, window=20, num_std_dev=2):
3742
Returns:
3843
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
3944
"""
40-
moving_avg = prices.rolling(window=window).mean().fillna(0)
45+
_validate_input_bollinger_signals(window, num_std_dev)
46+
47+
moving_avg = prices.rolling(window=window).mean()
4148
std_dev = prices.rolling(window=window).std()
4249
upper_band = moving_avg + (num_std_dev * std_dev)
4350
lower_band = moving_avg - (num_std_dev * std_dev)
4451

4552
signals = pd.Series(0, index=prices.index)
46-
signals[prices < lower_band] = 2
47-
signals[prices > upper_band] = 1
53+
signals.loc[prices < lower_band] = 2
54+
signals.loc[prices > upper_band] = 1
4855

49-
signals = _shift_signals_to_right(signals)
50-
return pd.Series(signals, index=prices.index)
56+
signals = signals.shift(periods=1, fill_value=0)
57+
return signals
5158

5259

53-
def _flip_signals(prices):
54-
"""Generate trading signals based on price changes from the previous price.
60+
def _macd_signals(prices, short_window=12, long_window=26, signal_window=9):
61+
"""Generate trading signals based on the MACD indicator.
62+
63+
A buy signal (2) is generated when the MACD line crosses above the Signal Line.
64+
A sell signal (1) is generated when the MACD line crosses below the Signal Line.
5565
5666
Args:
57-
prices (np.Series): Series of asset prices without index.
67+
prices (pd.Series): Series of asset prices.
68+
short_window (int): Window size for the short EMA (default: 12).
69+
long_window (int): Window size for the long EMA (default: 26).
70+
signal_window (int): Window size for the signal line EMA (default: 9).
5871
5972
Returns:
60-
np.ndarray: Trading signals (2: buy if previous price is lower,
61-
1: sell if previous price is higher,
62-
0: do nothing for the first price).
73+
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
6374
"""
64-
price_diff = np.diff(prices, prepend=prices.iloc[0])
65-
signals = np.zeros(len(prices), dtype=int)
66-
signals[1:][price_diff[1:] > 0] = 1
67-
signals[1:][price_diff[1:] < 0] = 2
75+
_validate_input_macd_signals(short_window, long_window, signal_window)
76+
77+
short_ema = prices.ewm(span=short_window, adjust=False).mean()
78+
long_ema = prices.ewm(span=long_window, adjust=False).mean()
79+
macd_line = short_ema - long_ema
80+
signal_line = macd_line.ewm(span=signal_window, adjust=False).mean()
81+
82+
signals = pd.Series(0, index=prices.index)
83+
signals.loc[macd_line > signal_line] = 2
84+
signals.loc[macd_line < signal_line] = 1
85+
86+
signals = signals.shift(periods=1, fill_value=0)
6887
return signals
6988

7089

71-
def _shift_signals_to_right(signals, shift=1):
72-
return np.concatenate(([0] * shift, signals[:-shift]))
90+
def _roc_signals(prices, window=10):
91+
"""Generate trading signals based on the Rate of Change (ROC) indicator.
92+
93+
A buy signal (2) is generated when the ROC is positive,
94+
and a sell signal (1) is generated when the ROC is negative.
95+
96+
Args:
97+
prices (pd.Series): Series of asset prices.
98+
window (int): Window size for computing the ROC.
99+
100+
Returns:
101+
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
102+
"""
103+
_validate_input_window(window)
104+
105+
roc = prices.pct_change(periods=window - 1)
106+
107+
signals = pd.Series(0, index=prices.index, dtype=int)
108+
signals.loc[roc > 0] = 2
109+
signals.loc[roc < 0] = 1
110+
111+
signals = signals.shift(periods=1, fill_value=0)
112+
return signals
113+
114+
115+
def _rsi_signals(prices, window=14):
116+
"""Generate trading signals based on the Relative Strength Index (RSI).
117+
118+
A buy signal (2) is generated when RSI is below 30 (oversold),
119+
and a sell signal (1) is generated when RSI is above 70 (overbought).
120+
121+
Args:
122+
prices (pd.Series): Series of asset prices.
123+
window (int): Window size for computing RSI.
124+
125+
Returns:
126+
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
127+
"""
128+
_validate_input_window(window)
129+
130+
delta = prices.diff()
131+
gain = delta.where(delta > 0, 0)
132+
loss = -delta.where(delta < 0, 0)
133+
134+
avg_gain = gain.rolling(window=window, min_periods=1).mean()
135+
avg_loss = loss.rolling(window=window, min_periods=1).mean()
136+
137+
rs = avg_gain / avg_loss
138+
rsi = 100 - (100 / (1 + rs))
139+
140+
signals = pd.Series(0, index=prices.index, dtype=int)
141+
upper_cutoff = 70
142+
lower_cutoff = 30
143+
signals.loc[rsi < lower_cutoff] = 2
144+
signals.loc[rsi > upper_cutoff] = 1
145+
146+
signals = signals.shift(periods=1, fill_value=0)
147+
return signals
148+
149+
150+
def _validate_input_method(method):
151+
"""Validate the input method for the `generate_signals` function.
152+
153+
Args:
154+
method (str): The signal generation method.
155+
156+
Raises:
157+
TypeError: If the method is not a string.
158+
ValueError: If the method is not one of the supported methods.
159+
"""
160+
if not isinstance(method, str):
161+
error_msg = (
162+
f"Invalid type for method: expected str, got {type(method).__name__}."
163+
)
164+
raise TypeError(error_msg)
165+
166+
supported_methods = ["bollinger", "macd", "roc", "rsi"]
167+
168+
if method not in supported_methods:
169+
error_msg = (
170+
f"Invalid method '{method}'. "
171+
f"Supported methods are: {', '.join(supported_methods)}."
172+
)
173+
raise ValueError(error_msg)
174+
175+
176+
def _validate_input_bollinger_signals(window, num_std_dev):
177+
"""Validate inputs for the Bollinger Bands trading signal function.
178+
179+
Args:
180+
window (int): Window size for moving average.
181+
num_std_dev (int, float): Number of standard deviations for the bands.
182+
"""
183+
_validate_input_window(window)
184+
_validate_input_num_std_dev(num_std_dev)
185+
186+
187+
def _validate_input_macd_signals(short_window, long_window, signal_window):
188+
"""Validate inputs for the MACD trading signal function.
189+
190+
Args:
191+
short_window (int): Window size for the short-term EMA.
192+
long_window (int): Window size for the long-term EMA.
193+
signal_window (int): Window size for the signal line EMA.
194+
"""
195+
_validate_input_window(short_window)
196+
_validate_input_window(long_window)
197+
_validate_input_window(signal_window)
198+
_validate_window_relationships(short_window, long_window, signal_window)
199+
200+
201+
def _validate_input_window(window):
202+
"""Validate the window parameter.
203+
204+
Args:
205+
window (int): Window size for moving average.
206+
207+
Raises:
208+
TypeError: If window is not an integer.
209+
ValueError: If window is not greater than 1.
210+
"""
211+
if not isinstance(window, int):
212+
error_msg = f"'window' must be an integer, got {type(window).__name__}."
213+
raise TypeError(error_msg)
214+
215+
if window <= 1:
216+
error_msg = f"'window' must be greater than 1, got {window}."
217+
raise ValueError(error_msg)
218+
219+
220+
def _validate_input_num_std_dev(num_std_dev):
221+
"""Validate the num_std_dev parameter for Bollinger Bands.
222+
223+
Args:
224+
num_std_dev (int, float): Number of standard deviations for the bands.
225+
226+
Raises:
227+
TypeError: If num_std_dev is not a number.
228+
ValueError: If num_std_dev is not positive.
229+
"""
230+
if not isinstance(num_std_dev, int | float):
231+
error_msg = f"'num_std_dev' must be a number, got {type(num_std_dev).__name__}."
232+
raise TypeError(error_msg)
233+
234+
if num_std_dev <= 0:
235+
error_msg = f"'num_std_dev' must be a positive number, got {num_std_dev}."
236+
raise ValueError(error_msg)
237+
238+
239+
def _validate_window_relationships(short_window, long_window, signal_window):
240+
"""Validate logical relationships between MACD windows.
241+
242+
Args:
243+
short_window (int): Window size for the short-term EMA.
244+
long_window (int): Window size for the long-term EMA.
245+
signal_window (int): Window size for the signal line EMA.
246+
247+
Raises:
248+
ValueError: If:
249+
- `short_window` is greater than or equal to `long_window`.
250+
- `signal_window` is greater than `short_window`.
251+
"""
252+
if short_window >= long_window:
253+
error_msg = (
254+
"'short_window' must be less than 'long_window', "
255+
f"got short_window={short_window} and long_window={long_window}."
256+
)
257+
raise ValueError(error_msg)
258+
259+
if signal_window > short_window:
260+
error_msg = (
261+
"'signal_window' must be less than or equal to 'short_window',"
262+
f"got signal_window={signal_window} and short_window={short_window}."
263+
)
264+
raise ValueError(error_msg)

src/backtest_bay/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
START_DATES = ["2022-01-01"]
1616
END_DATES = ["2025-01-01"]
1717
INTERVALS = ["1d"]
18-
STRATEGIES = ["flip", "bollinger"]
18+
STRATEGIES = ["bollinger", "macd", "roc", "rsi"]
1919

2020
INITIAL_CASH = 1000
2121
TAC = 0.05

0 commit comments

Comments
 (0)