Skip to content

Commit 785a99c

Browse files
Add strategy and backtest hyperparameters to config.py. Add scripts to implement simple trading strategies and a scripts to backtest the strategies.
1 parent 75ee8cb commit 785a99c

File tree

7 files changed

+240
-27
lines changed

7 files changed

+240
-27
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import math
2+
3+
import pandas as pd
4+
5+
6+
def execute_buy(cash, price, shares, trade_vol, tac):
7+
buy_shares = math.floor(trade_vol / (price * (1 + tac)))
8+
if buy_shares < 1:
9+
return cash, shares
10+
cost = buy_shares * price * (1 + tac)
11+
if cost > cash:
12+
return cash, shares
13+
cash -= cost
14+
shares += buy_shares
15+
return cash, shares
16+
17+
18+
def execute_sell(cash, price, shares, trade_vol, tac):
19+
sell_shares = math.floor(trade_vol / (price * (1 - tac)))
20+
21+
if sell_shares > shares:
22+
if shares < 1:
23+
return cash, shares
24+
sell_shares = shares
25+
26+
profit = sell_shares * price * (1 - tac)
27+
cash += profit
28+
shares -= sell_shares
29+
return cash, shares
30+
31+
32+
def update_portfolio(cash, shares, price):
33+
holdings = shares * price
34+
assets = cash + holdings
35+
return assets, holdings
36+
37+
38+
def backtest_signals(data, signals, initial_cash, tac, trade_pct, price_col="Close"):
39+
prices = data[price_col].squeeze()
40+
cash, holdings, shares = initial_cash, 0.0, 0
41+
assets = cash + holdings
42+
portfolio = []
43+
44+
buy_signal = 2
45+
sell_signal = 1
46+
47+
for price, signal in zip(prices, signals, strict=False):
48+
trade_vol = trade_pct * assets
49+
50+
if signal == buy_signal:
51+
cash, shares = execute_buy(cash, price, shares, trade_vol, tac)
52+
53+
elif signal == sell_signal:
54+
cash, shares = execute_sell(cash, price, shares, trade_vol, tac)
55+
56+
assets, holdings = update_portfolio(cash, shares, price)
57+
portfolio.append((cash, holdings, shares, assets, signal, price))
58+
59+
portfolio = pd.DataFrame(
60+
portfolio, columns=["cash", "holdings", "shares", "total", "signal", "price"]
61+
)
62+
return portfolio
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def generate_signals(data, method, **kwargs):
6+
"""Derive trading signals using the specified method and parameters.
7+
8+
Args:
9+
data (pd.DataFrame): DataFrame containing asset price data
10+
(must include 'Close' column).
11+
method (str): The signal generation method. Currently supported:
12+
- 'bollinger_bands': Uses Bollinger Bands for signal generation.
13+
**kwargs: Additional parameters specific to the chosen method
14+
(e.g., window size, number of standard deviations).
15+
16+
Returns:
17+
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
18+
"""
19+
if method == "flip":
20+
signal = _flip_signals(prices=data["Close"].squeeze())
21+
if method == "bollinger":
22+
signal = _bollinger_signals(prices=data["Close"].squeeze(), **kwargs)
23+
return signal
24+
25+
26+
def _bollinger_signals(prices, window=20, num_std_dev=2):
27+
"""Generate anticyclical trading signals based on Bollinger Bands.
28+
29+
Bollinger Bands are computed using a moving average and standard deviations
30+
to identify overbought (sell signal) and oversold (buy signal) conditions.
31+
32+
Args:
33+
prices (pd.Series): Series of asset prices.
34+
window (int): Window size for Bollinger Bands calculation (default is 20).
35+
num_std_dev (float): Number of standard deviations for the bands (default is 2).
36+
37+
Returns:
38+
pd.Series: Trading signals (2: buy, 1: sell, 0: do nothing).
39+
"""
40+
moving_avg = prices.rolling(window=window).mean().fillna(0)
41+
std_dev = prices.rolling(window=window).std()
42+
upper_band = moving_avg + (num_std_dev * std_dev)
43+
lower_band = moving_avg - (num_std_dev * std_dev)
44+
45+
signals = pd.Series(0, index=prices.index)
46+
signals[prices < lower_band] = 2
47+
signals[prices > upper_band] = 1
48+
49+
signals = _shift_signals_to_right(signals)
50+
return pd.Series(signals, index=prices.index)
51+
52+
53+
def _flip_signals(prices):
54+
"""Generate trading signals based on price changes from the previous price.
55+
56+
Args:
57+
prices (np.Series): Series of asset prices without index.
58+
59+
Returns:
60+
np.ndarray: Trading signals (2: buy if previous price is lower,
61+
1: sell if previous price is higher,
62+
0: do nothing for the first price).
63+
"""
64+
price_diff = np.diff(prices, prepend=prices.iloc[0])
65+
signals = np.zeros(len(prices), dtype=int)
66+
signals[1:][price_diff[1:] > 0] = 1
67+
signals[1:][price_diff[1:] < 0] = 2
68+
return signals
69+
70+
71+
def _shift_signals_to_right(signals, shift=1):
72+
return np.concatenate(([0] * shift, signals[:-shift]))
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pandas as pd
2+
import pytask
3+
4+
from backtest_bay.analysis.backtest_signals import backtest_signals
5+
from backtest_bay.analysis.generate_signals import generate_signals
6+
from backtest_bay.config import BLD, INITIAL_CASH, PARAMS, SRC, TAC, TRADE_PCT
7+
8+
scripts = [
9+
SRC / "config.py",
10+
SRC / "analysis" / "generate_signals.py",
11+
SRC / "analysis" / "backtest_signals.py",
12+
]
13+
14+
for row in PARAMS.itertuples(index=False):
15+
id_backtest = (
16+
f"{row.stock}_{row.start_date}_{row.end_date}_" f"{row.interval}_{row.strategy}"
17+
)
18+
data_path = BLD / f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}.pkl"
19+
data = pd.read_pickle(data_path)
20+
produces = BLD / f"{id_backtest}.pkl"
21+
22+
signals = generate_signals(data, row.strategy)
23+
24+
@pytask.task(id=id_backtest)
25+
def task_backtest_signals(
26+
data=data,
27+
signals=signals,
28+
initial_cash=INITIAL_CASH,
29+
tac=TAC,
30+
investment_pct=TRADE_PCT,
31+
scripts=scripts,
32+
data_path=data_path,
33+
produces=produces,
34+
):
35+
portfolio = backtest_signals(data, signals, initial_cash, tac, investment_pct)
36+
portfolio.to_pickle(produces)

src/backtest_bay/config.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
"""All the general configuration of the project."""
22

3+
import itertools
34
from pathlib import Path
45

6+
import pandas as pd
7+
58
SRC = Path(__file__).parent.resolve()
69
ROOT = SRC.joinpath("..", "..").resolve()
710

811
BLD = ROOT.joinpath("bld").resolve()
912

1013
# Configure input data
1114
STOCKS = ["AAPL", "MSFT"]
12-
START_DATES = ["2022-01-01", "2022-01-01"]
13-
END_DATES = ["2025-01-01", "2025-01-01"]
14-
INTERVALS = ["1d", "1d"]
15+
START_DATES = ["2022-01-01"]
16+
END_DATES = ["2025-01-01"]
17+
INTERVALS = ["1d"]
18+
STRATEGIES = ["flip", "bollinger"]
19+
20+
INITIAL_CASH = 1000
21+
TAC = 0.05
22+
TRADE_PCT = 0.05
23+
24+
# Define PARAMS using input data
25+
PARAMS = pd.DataFrame(
26+
list(itertools.product(STOCKS, START_DATES, END_DATES, INTERVALS, STRATEGIES)),
27+
columns=["stock", "start_date", "end_date", "interval", "strategy"],
28+
)

src/backtest_bay/data/task_download_data.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
11
import pytask
22

3-
from backtest_bay.config import BLD, END_DATES, INTERVALS, SRC, START_DATES, STOCKS
3+
from backtest_bay.config import BLD, PARAMS, SRC
44
from backtest_bay.data.download_data import download_data
55

6-
dependencies = [SRC / "config.py", SRC / "data" / "download_data.py"]
6+
scripts = [SRC / "config.py", SRC / "data" / "download_data.py"]
77

8-
for index, _ in enumerate(STOCKS):
9-
_id = (
10-
STOCKS[index]
11-
+ "_"
12-
+ START_DATES[index]
13-
+ "_"
14-
+ END_DATES[index]
15-
+ "_"
16-
+ INTERVALS[index]
17-
)
8+
data_to_download = PARAMS.drop_duplicates(
9+
subset=["stock", "start_date", "end_date", "interval"]
10+
)
1811

19-
@pytask.task(id=_id)
12+
for row in data_to_download.itertuples(index=False):
13+
id_download = f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}"
14+
15+
produces = BLD / f"{id_download}.pkl"
16+
17+
if produces.exists():
18+
continue
19+
20+
@pytask.task(id=id_download)
2021
def task_download_data(
21-
symbol=STOCKS[index],
22-
start_date=START_DATES[index],
23-
end_date=END_DATES[index],
24-
interval=INTERVALS[index],
25-
depends_on=dependencies,
26-
produces=BLD / (_id + ".pkl"),
22+
symbol=row.stock,
23+
start_date=row.start_date,
24+
end_date=row.end_date,
25+
interval=row.interval,
26+
depends_on=scripts,
27+
produces=produces,
2728
):
2829
data = download_data(
2930
symbol=symbol, start_date=start_date, end_date=end_date, interval=interval
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from backtest_bay.analysis.generate_signals import _bollinger_signals
5+
6+
# Tests for _bollinger_signals
7+
8+
9+
def test_bollinger_signals_constant_prices():
10+
prices = pd.Series([100] * 30)
11+
signals = _bollinger_signals(prices)
12+
expected = pd.Series([0] * 30, index=prices.index)
13+
pd.testing.assert_series_equal(signals, expected)
14+
15+
16+
def test_bollinger_signals_buy_signal():
17+
prices = pd.Series([100] * 20 + [90])
18+
signals = _bollinger_signals(prices)
19+
buy_signal = 2
20+
assert signals.iloc[-1] == buy_signal
21+
22+
23+
def test_bollinger_signals_sell_signal():
24+
prices = pd.Series([100] * 20 + [110])
25+
signals = _bollinger_signals(prices)
26+
assert signals.iloc[-1] == 1
27+
28+
29+
def test_bollinger_signals_window_effect():
30+
prices = pd.Series(np.linspace(100, 200, 25))
31+
signals_small_window = _bollinger_signals(prices, window=2, num_std_dev=1)
32+
signals_large_window = _bollinger_signals(prices, window=20, num_std_dev=1)
33+
assert not signals_small_window.equals(signals_large_window)

tests/data/test_download_data.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# Tests for _validate_symbol
1414
def test_validate_symbol_valid_cases():
1515
_validate_symbol("AAPL")
16-
_validate_symbol("MSFT")
1716

1817

1918
def test_validate_symbol_invalid_invalid_symbol():
@@ -98,8 +97,6 @@ def test_validate_date_format_non_string():
9897

9998

10099
# tests for _validate_date_range
101-
102-
103100
def test_valid_date_range():
104101
_validate_date_range("2024-01-01", "2024-12-31")
105102

@@ -110,8 +107,6 @@ def test_invalid_date_range():
110107

111108

112109
# tests for _validate_output
113-
114-
115110
def test_validate_output_valid_input():
116111
# generate typical Yahoo Finance output format
117112
symbol = "AAPL"

0 commit comments

Comments
 (0)