Skip to content

Commit 22d653f

Browse files
Add doc strings for plot_portfolio.py and plot_signals.py. Bug fix for _buy_and_hold_strategy, since not invested cash was not taken into account.
1 parent 59af977 commit 22d653f

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed

src/backtest_bay/plot/plot_portfolio.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,31 @@
1111

1212

1313
def plot_portfolio(portfolio, title, tac, cash):
14-
"""Main function to plot the portfolio performance and metrics."""
14+
"""Plots the portfolio performance and associated financial metrics.
15+
16+
Args:
17+
portfolio (pd.DataFrame): A DataFrame containing the portfolio's assets, shares,
18+
and prices.
19+
title (str): Title of the plot.
20+
tac (float): Transaction costs.
21+
cash (float): Initial cash.
22+
23+
Returns:
24+
(go.Figure) Figure with portfolio performance and metrics.
25+
"""
26+
# Note that there is no need to validate the inputs 'portfolio', 'tac' and 'cash',
27+
# since they are already checked in 'download_data.py' and 'backtest_signals.py".
1528
portfolio_return = _calculate_portfolio_return(portfolio["assets"])
1629
annualized_return = _calculate_annualized_return(portfolio["assets"])
1730
annualized_volatility = _calculate_annualized_volatility(portfolio["assets"])
31+
trades = _calculate_trades(portfolio["shares"])
1832

19-
# Benchmark: Buy and Hold Strategy
2033
portfolio["buy_and_hold"] = _buy_and_hold_strategy(cash, portfolio["Close"])
2134
buy_and_hold_return = _calculate_annualized_return(portfolio["buy_and_hold"])
2235
buy_and_hold_volatility = _calculate_annualized_volatility(
2336
portfolio["buy_and_hold"]
2437
)
2538

26-
trades = _calculate_trades(portfolio["shares"])
2739
fig = make_subplots(
2840
rows=2,
2941
cols=1,
@@ -33,7 +45,10 @@ def plot_portfolio(portfolio, title, tac, cash):
3345
specs=[[{"type": "xy"}], [{"type": "domain"}]],
3446
)
3547

36-
for trace in _create_portfolio_traces(portfolio):
48+
portfolio_traces = _create_portfolio_traces(
49+
portfolio.index, portfolio["cash"], portfolio["assets"]
50+
)
51+
for trace in portfolio_traces:
3752
fig.add_trace(trace, row=1, col=1)
3853

3954
metrics_table = _create_metrics_table(
@@ -52,13 +67,13 @@ def plot_portfolio(portfolio, title, tac, cash):
5267
return fig
5368

5469

55-
def _create_portfolio_traces(portfolio):
70+
def _create_portfolio_traces(index, cash, assets):
5671
"""Create Plotly traces for cash and assets over time."""
5772
traces = [
58-
go.Scatter(x=portfolio.index, y=portfolio["cash"], mode="lines", name="Cash"),
73+
go.Scatter(x=index, y=cash, mode="lines", name="Cash"),
5974
go.Scatter(
60-
x=portfolio.index,
61-
y=portfolio["assets"],
75+
x=index,
76+
y=assets,
6277
mode="lines",
6378
name="Assets (Cash + Holdings)",
6479
),
@@ -160,6 +175,7 @@ def _calculate_trades(shares):
160175

161176

162177
def _calculate_annualized_volatility(stock):
178+
"""Calculates the annualized volatility for a stock."""
163179
if len(stock) <= 1:
164180
return 0
165181

@@ -176,10 +192,13 @@ def _calculate_annualized_volatility(stock):
176192

177193

178194
def _buy_and_hold_strategy(initial_cash, prices):
195+
"""Calculates the portfolio development using a buy and hold strategy."""
179196
first_price = prices.iloc[0]
180197

181198
if first_price == 0:
182199
return 0
183200

184201
shares = np.floor(initial_cash / first_price)
185-
return shares * prices
202+
not_invested_cash = initial_cash - shares * first_price
203+
portfolio_buy_and_hold = shares * prices + not_invested_cash
204+
return portfolio_buy_and_hold

src/backtest_bay/plot/plot_signals.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,30 @@
99
pd.options.plotting.backend = "plotly"
1010

1111

12-
def plot_signals(df, title):
13-
"""Create Plotly figure to plot trading signals."""
14-
df["signal_plot"] = _map_signals_for_plotting(df["signal"])
12+
def plot_signals(data, title):
13+
"""Plots trading signals alongside candlestick charts.
14+
15+
Args:
16+
data (pd.DataFrame): A DataFrame containing stock price data and trading
17+
signals. Expected columns: "Open", "High","Low","Close","signal".
18+
title (str): The title of the plot.
19+
20+
Returns:
21+
(go.Figure): Figure with trading signals alongside candlestick charts.
22+
"""
23+
# Note that there is no need to validate the input 'data', since 'data' is already
24+
# checked in 'download_data.py'.
25+
data["signal_plot"] = _map_signals_for_plotting(data["signal"])
1526

1627
fig = make_subplots(
1728
rows=2, cols=1, shared_xaxes=True, row_heights=[0.7, 0.3], vertical_spacing=0.1
1829
)
1930

2031
fig.add_trace(
21-
_create_candlestick_trace(df[["Open", "High", "Low", "Close"]]), row=1, col=1
32+
_create_candlestick_trace(data[["Open", "High", "Low", "Close"]]), row=1, col=1
2233
)
2334

24-
buy_trace, sell_trace = _create_signal_traces(df["signal_plot"])
35+
buy_trace, sell_trace = _create_signal_traces(data["signal_plot"])
2536
fig.add_trace(buy_trace, row=2, col=1)
2637
fig.add_trace(sell_trace, row=2, col=1)
2738

@@ -38,7 +49,7 @@ def plot_signals(df, title):
3849
row=2,
3950
col=1,
4051
tickvals=[-1, 0, 1],
41-
ticktext=["Sell", "Hold", "Buy"],
52+
ticktext=["Sell", "", "Buy"],
4253
)
4354

4455
return fig

tests/plot/test_plot_portfolio.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def test_calculate_trades(shares, expected):
102102
],
103103
)
104104
def test_calculate_annualized_volatility(stock_prices, expected_volatility):
105+
"""Test if _calculate_annualized_volatility correctly calculates annualized
106+
volatility."""
105107
stock = pd.Series(
106108
stock_prices, index=pd.date_range(start="2022-01-01", periods=len(stock_prices))
107109
)
@@ -113,16 +115,15 @@ def test_calculate_annualized_volatility(stock_prices, expected_volatility):
113115
@pytest.mark.parametrize(
114116
("initial_cash", "prices", "expected"),
115117
[
116-
# Standard case with positive prices
117-
(100, pd.Series([10, 12, 15, 18]), pd.Series([100.0, 120.0, 150.0, 180.0])),
118+
# Enough cash
119+
(105, pd.Series([10, 12, 15, 18]), pd.Series([105.0, 125.0, 155.0, 185.0])),
118120
# Not enough cash to buy even one share
119-
(5, pd.Series([10, 12, 15, 18]), pd.Series([0.0, 0.0, 0.0, 0.0])),
121+
(5, pd.Series([10, 12, 15, 18]), pd.Series([5.0, 5.0, 5.0, 5.0])),
120122
# Cash exactly enough to buy one share
121123
(10, pd.Series([10, 12, 15, 18]), pd.Series([10.0, 12.0, 15.0, 18.0])),
122124
],
123125
)
124-
def test_buy_and_hold_strategy(initial_cash, prices, expected):
125-
"""Test if _calculate_annualized_volatility correctly calculates annualized
126-
volatility."""
126+
def test_buy_and_hold_strategy_correct_calculation(initial_cash, prices, expected):
127+
"""Test if _buy_and_hold_strategy correctly calculates buy and hold strategy."""
127128
result = _buy_and_hold_strategy(initial_cash, prices)
128129
pd.testing.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)