Skip to content

Commit 33796c1

Browse files
Add task_plot.py for plotting the backtesting results. Plots are included for visualizing signals and portfolio over time. Adjust task_backtesting.py to implement plotting.
1 parent 5f3c280 commit 33796c1

File tree

5 files changed

+270
-9
lines changed

5 files changed

+270
-9
lines changed

src/backtest_bay/backtest/task_backtest.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,45 @@
1515
id_backtest = (
1616
f"{row.stock}_{row.start_date}_{row.end_date}_" f"{row.interval}_{row.strategy}"
1717
)
18-
data_path = (
18+
stock_data_path = (
1919
BLD / "data" / f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}.pkl"
2020
)
2121
produces = BLD / "backtest" / f"{id_backtest}.pkl"
2222
strategy = row.strategy
2323

2424
@pytask.task(id=id_backtest)
2525
def task_backtest(
26-
scripts=scripts, data_path=data_path, produces=produces, strategy=strategy
26+
scripts=scripts,
27+
stock_data_path=stock_data_path,
28+
produces=produces,
29+
strategy=strategy,
2730
):
28-
data = pd.read_pickle(data_path)
29-
signals = generate_signals(data=data, method=strategy)
30-
portfolio = backtest_signals(
31-
data=data,
31+
stock_data = pd.read_pickle(stock_data_path)
32+
signals = generate_signals(data=stock_data, method=strategy)
33+
backtested_portfolio = backtest_signals(
34+
data=stock_data,
3235
signals=signals,
3336
initial_cash=INITIAL_CASH,
3437
tac=TAC,
3538
trade_pct=TRADE_PCT,
3639
)
37-
portfolio.to_pickle(produces)
40+
41+
merged_portfolio = _merge_stock_data_with_portfolio(
42+
stock_data, backtested_portfolio
43+
)
44+
45+
merged_portfolio.to_pickle(produces)
46+
47+
48+
def _merge_stock_data_with_portfolio(data, portfolio):
49+
"""Merge data with portfolio using the index.
50+
51+
Args:
52+
data (pd.DataFrame): DataFrame with downloaded data.
53+
portfolio (pd.DataFrame): DataFrame to be merged with data using the index.
54+
55+
Returns:
56+
pd.DataFrame: Merged DataFrame.
57+
"""
58+
data.columns = data.columns.droplevel(1)
59+
return data.merge(portfolio, how="left", left_index=True, right_index=True)

src/backtest_bay/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
INTERVALS = ["1d"]
1818
STRATEGIES = ["bollinger", "macd", "roc", "rsi"]
1919

20-
INITIAL_CASH = 1000
21-
TAC = 0.05
20+
INITIAL_CASH = 1000000
21+
TAC = 0.005
2222
TRADE_PCT = 0.05
2323

2424
# Define PARAMS using input data
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import pandas as pd
2+
import plotly.graph_objects as go
3+
from plotly.subplots import make_subplots
4+
5+
6+
def plot_portfolio(portfolio, title, tac):
7+
"""Main function to plot the portfolio performance and metrics."""
8+
portfolio_return = _calculate_portfolio_return(portfolio["assets"])
9+
annualized_return = _calculate_annualized_return(portfolio["assets"])
10+
buy_and_hold_return = _calculate_annualized_return(portfolio["Close"])
11+
trades = _calculate_trades(portfolio["shares"])
12+
13+
fig = make_subplots(
14+
rows=2,
15+
cols=1,
16+
shared_xaxes=True,
17+
row_heights=[0.70, 0.30],
18+
vertical_spacing=0.2,
19+
specs=[[{"type": "xy"}], [{"type": "domain"}]],
20+
)
21+
22+
for trace in _create_portfolio_traces(portfolio):
23+
fig.add_trace(trace, row=1, col=1)
24+
25+
metrics_table = _create_metrics_table(
26+
portfolio_return, annualized_return, trades, buy_and_hold_return, tac
27+
)
28+
fig.add_trace(metrics_table, row=2, col=1)
29+
30+
fig.update_layout(_create_plot_layout(title))
31+
32+
return fig
33+
34+
35+
def _calculate_portfolio_return(stock):
36+
"""Calculate the total return of the stock."""
37+
initial_value = stock.iloc[0]
38+
final_value = stock.iloc[-1]
39+
total_return = (final_value / initial_value) - 1
40+
portfolio_return = total_return * 100
41+
return portfolio_return
42+
43+
44+
def _calculate_annualized_return(stock):
45+
"""Calculate the annualized return for a stock."""
46+
initial_value = stock.iloc[0]
47+
final_value = stock.iloc[-1]
48+
total_return = (final_value / initial_value) - 1
49+
50+
years = _calculate_years(stock)
51+
annualized_return = ((1 + total_return) ** (1 / years) - 1) * 100
52+
53+
return annualized_return
54+
55+
56+
def _calculate_years(stock):
57+
"""Calculate the number of years between the first and last date of a stock."""
58+
total_seconds = (stock.index[-1] - stock.index[0]).total_seconds()
59+
years = total_seconds / pd.Timedelta(days=365).total_seconds()
60+
return years
61+
62+
63+
def _calculate_trades(stock):
64+
"""Calculate the number of trades by counting changes in the shares held."""
65+
trades = stock.diff().ne(0).sum()
66+
return trades
67+
68+
69+
def _create_portfolio_traces(portfolio):
70+
"""Create Plotly traces for cash and assets over time."""
71+
traces = [
72+
go.Scatter(x=portfolio.index, y=portfolio["cash"], mode="lines", name="Cash"),
73+
go.Scatter(
74+
x=portfolio.index,
75+
y=portfolio["assets"],
76+
mode="lines",
77+
name="Assets (Cash + Holdings)",
78+
),
79+
]
80+
return traces
81+
82+
83+
def _create_metrics_table(
84+
portfolio_return, annualized_return, trades, buy_and_hold_return, tac
85+
):
86+
"""Create a Plotly table for the portfolio metrics."""
87+
table = go.Table(
88+
header={
89+
"values": ["Metric", "Value"],
90+
"fill_color": "lightgrey",
91+
"align": "center",
92+
},
93+
cells={
94+
"values": [
95+
[
96+
"Total Return",
97+
"Annualized Return",
98+
"Trades",
99+
"Assumed TAC",
100+
"Benchmark: Annualized Buy and Hold Return",
101+
],
102+
[
103+
f"{portfolio_return:.2f}%",
104+
f"{annualized_return:.2f}%",
105+
trades,
106+
f"{tac * 100:.2f}%",
107+
f"{buy_and_hold_return:.2f}%",
108+
],
109+
],
110+
"align": "left",
111+
},
112+
)
113+
return table
114+
115+
116+
def _create_plot_layout(title):
117+
"""Create the layout configuration for the Plotly figure."""
118+
layout = {
119+
"title": title,
120+
"xaxis_title": "Date",
121+
"yaxis_title": "Value",
122+
"legend_title": "Portfolio Components",
123+
"template": "plotly",
124+
}
125+
return layout
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import plotly.graph_objects as go
2+
from plotly.subplots import make_subplots
3+
4+
5+
def plot_signals(df, title):
6+
"""Create Plotly figure to plot trading signals."""
7+
df = _map_signals_for_plotting(df)
8+
9+
fig = make_subplots(
10+
rows=2, cols=1, shared_xaxes=True, row_heights=[0.7, 0.3], vertical_spacing=0.1
11+
)
12+
13+
fig.add_trace(_create_candlestick_trace(df), row=1, col=1)
14+
15+
buy_trace, sell_trace = _create_signal_traces(df)
16+
fig.add_trace(buy_trace, row=2, col=1)
17+
fig.add_trace(sell_trace, row=2, col=1)
18+
19+
fig.update_layout(
20+
title=title,
21+
xaxis_title="Date",
22+
yaxis_title="Stock",
23+
xaxis_rangeslider_visible=False,
24+
legend_title="Legend",
25+
)
26+
27+
fig.update_yaxes(
28+
title_text="Signal",
29+
row=2,
30+
col=1,
31+
tickvals=[-1, 0, 1],
32+
ticktext=["Sell", "Hold", "Buy"],
33+
)
34+
35+
return fig
36+
37+
38+
def _map_signals_for_plotting(df):
39+
"""Map signals to -1, 0, 1 for plotting."""
40+
signal_mapping = {0: 0, 1: -1, 2: 1}
41+
df["signal_plot"] = df["signal"].map(signal_mapping)
42+
return df
43+
44+
45+
def _create_candlestick_trace(df):
46+
"""Create candlestick trace for stock data."""
47+
return go.Candlestick(
48+
x=df.index,
49+
open=df["Open"],
50+
high=df["High"],
51+
low=df["Low"],
52+
close=df["Close"],
53+
name="Stock",
54+
)
55+
56+
57+
def _create_signal_traces(df):
58+
"""Create signal bar traces for buy and sell signals."""
59+
buy_signals = df["signal_plot"] > 0
60+
buy_trace = go.Bar(
61+
x=df.index[buy_signals],
62+
y=df.loc[buy_signals, "signal_plot"],
63+
marker_color="green",
64+
name="Buy Signal",
65+
)
66+
67+
sell_signals = df["signal_plot"] < 0
68+
sell_trace = go.Bar(
69+
x=df.index[sell_signals],
70+
y=df.loc[sell_signals, "signal_plot"],
71+
marker_color="red",
72+
name="Sell Signal",
73+
)
74+
75+
return buy_trace, sell_trace

src/backtest_bay/plot/task_plot.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pandas as pd
2+
import pytask
3+
4+
from backtest_bay.config import BLD, PARAMS, SRC, TAC
5+
from backtest_bay.plot.plot_portfolio import plot_portfolio
6+
from backtest_bay.plot.plot_signals import plot_signals
7+
8+
scripts = [
9+
SRC / "config.py",
10+
SRC / "plot" / "plot_signals.py",
11+
SRC / "plot" / "plot_portfolio.py",
12+
]
13+
14+
for row in PARAMS.itertuples(index=False):
15+
id_backtest = (
16+
f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}_{row.strategy}"
17+
)
18+
backtest_path = BLD / "backtest" / f"{id_backtest}.pkl"
19+
plot_path = f"{row.stock}_{row.start_date}_{row.end_date}_{row.interval}"
20+
produces = {
21+
"plot_signals": BLD / "plot" / plot_path / f"plot_signals_{row.strategy}.html",
22+
"plot_portfolio": BLD
23+
/ "plot"
24+
/ plot_path
25+
/ f"plot_portfolio_{row.strategy}.html",
26+
}
27+
28+
@pytask.task(id=id_backtest)
29+
def task_plot(
30+
scripts=scripts,
31+
backtest_path=backtest_path,
32+
produces=produces,
33+
id_backtest=id_backtest,
34+
):
35+
portfolio = pd.read_pickle(backtest_path)
36+
fig = plot_signals(portfolio, id_backtest)
37+
fig.write_html(produces.get("plot_signals"))
38+
fig = plot_portfolio(portfolio, id_backtest, TAC)
39+
fig.write_html(produces.get("plot_portfolio"))

0 commit comments

Comments
 (0)