Skip to content

Commit 6be8cd8

Browse files
Add test cases for plot_portfolio.py. Bug fix for _calculate_trades, since the previous implementation falsy added one share.
1 parent c7eb088 commit 6be8cd8

File tree

2 files changed

+125
-34
lines changed

2 files changed

+125
-34
lines changed

src/backtest_bay/plot/plot_portfolio.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,40 +36,6 @@ def plot_portfolio(portfolio, title, tac):
3636
return fig
3737

3838

39-
def _calculate_portfolio_return(stock):
40-
"""Calculate the total return of the stock."""
41-
initial_value = stock.iloc[0]
42-
final_value = stock.iloc[-1]
43-
total_return = (final_value / initial_value) - 1
44-
portfolio_return = total_return * 100
45-
return portfolio_return
46-
47-
48-
def _calculate_annualized_return(stock):
49-
"""Calculate the annualized return for a stock."""
50-
initial_value = stock.iloc[0]
51-
final_value = stock.iloc[-1]
52-
total_return = (final_value / initial_value) - 1
53-
54-
years = _calculate_years(stock)
55-
annualized_return = ((1 + total_return) ** (1 / years) - 1) * 100
56-
57-
return annualized_return
58-
59-
60-
def _calculate_years(stock):
61-
"""Calculate the number of years between the first and last date of a stock."""
62-
total_seconds = (stock.index[-1] - stock.index[0]).total_seconds()
63-
years = total_seconds / pd.Timedelta(days=365).total_seconds()
64-
return years
65-
66-
67-
def _calculate_trades(shares):
68-
"""Calculate the number of trades by counting changes in the shares held."""
69-
trades = shares.diff().ne(0).sum()
70-
return trades
71-
72-
7339
def _create_portfolio_traces(portfolio):
7440
"""Create Plotly traces for cash and assets over time."""
7541
traces = [
@@ -127,3 +93,41 @@ def _create_plot_layout(title):
12793
"template": "plotly",
12894
}
12995
return layout
96+
97+
98+
def _calculate_portfolio_return(stock):
99+
"""Calculate the total return of the stock."""
100+
initial_value = stock.iloc[0]
101+
final_value = stock.iloc[-1]
102+
103+
if initial_value == 0:
104+
return float("nan")
105+
106+
total_return = (final_value / initial_value) - 1
107+
portfolio_return = total_return * 100
108+
return portfolio_return
109+
110+
111+
def _calculate_annualized_return(stock):
112+
"""Calculate the annualized return for a stock."""
113+
total_return = _calculate_portfolio_return(stock) / 100
114+
years = _calculate_years(stock.index)
115+
116+
if years == 0:
117+
return 0
118+
119+
annualized_return = ((1 + total_return) ** (1 / years) - 1) * 100
120+
return annualized_return
121+
122+
123+
def _calculate_years(stock_index):
124+
"""Calculate the number of years between the first and last date of a stock."""
125+
total_seconds = (stock_index[-1] - stock_index[0]).total_seconds()
126+
years = total_seconds / pd.Timedelta(days=365).total_seconds()
127+
return years
128+
129+
130+
def _calculate_trades(shares):
131+
"""Calculate the number of trades by counting changes in the shares held."""
132+
trades = shares.diff().fillna(0).ne(0).sum()
133+
return trades

tests/plot/test_plot_portfolio.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from backtest_bay.plot.plot_portfolio import (
6+
_calculate_annualized_return,
7+
_calculate_portfolio_return,
8+
_calculate_trades,
9+
_calculate_years,
10+
)
11+
12+
13+
# Tests for _calculate_portfolio_return
14+
@pytest.mark.parametrize(
15+
("stock", "expected_return"),
16+
[
17+
(pd.Series([100]), 0.0),
18+
(pd.Series([100, 150]), 50.0),
19+
(pd.Series([200, 100]), -50.0),
20+
(pd.Series([100, 100]), 0.0),
21+
(pd.Series([50, 75, 100]), 100.0),
22+
],
23+
)
24+
def test_calculate_portfolio_return_valid_calculation(stock, expected_return):
25+
"""Check if calculated return equals expected value."""
26+
result = _calculate_portfolio_return(stock)
27+
assert result == expected_return
28+
29+
30+
# Tests for _calculate_years
31+
@pytest.mark.parametrize(
32+
("index", "expected"),
33+
[
34+
(pd.to_datetime(["2020-01-01", "2020-12-31"]), 1),
35+
(pd.to_datetime(["2020-01-01", "2020-07-01"]), 0.5),
36+
(pd.to_datetime(["2010-01-01", "2020-01-01"]), 10),
37+
(pd.to_datetime(["2020-01-01", "2020-01-01"]), 0),
38+
(pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 12:00:00"]), 0.5 / 365),
39+
],
40+
)
41+
def test_calculate_years(index, expected):
42+
result = _calculate_years(index)
43+
assert np.isclose(result, expected, atol=1e-2)
44+
45+
46+
# Tests for _calculate_annualized_return
47+
@pytest.mark.parametrize(
48+
("stock", "expected"),
49+
[
50+
(
51+
pd.Series(
52+
[100, 100, 100],
53+
index=pd.to_datetime(["2020-01-01", "2020-07-01", "2021-01-01"]),
54+
),
55+
0,
56+
),
57+
(
58+
pd.Series([100, 110], index=pd.to_datetime(["2020-01-01", "2021-01-01"])),
59+
10.00,
60+
),
61+
(
62+
pd.Series([200, 100], index=pd.to_datetime(["2020-01-01", "2021-01-01"])),
63+
-50.00,
64+
),
65+
(
66+
pd.Series([100, 200], index=pd.to_datetime(["2018-01-01", "2021-01-01"])),
67+
round(((200 / 100) ** (1 / 3) - 1) * 100, 2),
68+
),
69+
],
70+
)
71+
def test_calculate_annualized_return(stock, expected):
72+
result = _calculate_annualized_return(stock)
73+
assert np.isclose(result, expected, atol=1e-1)
74+
75+
76+
# Tests for _calculate_trades
77+
@pytest.mark.parametrize(
78+
("shares", "expected"),
79+
[
80+
(pd.Series([10, 10, 10, 10]), 0),
81+
(pd.Series([0, 10, 0, 10, 0]), 4),
82+
(pd.Series([0, -10, 0, 0, 0]), 2),
83+
],
84+
)
85+
def test_calculate_trades(shares, expected):
86+
result = _calculate_trades(shares)
87+
assert result == expected

0 commit comments

Comments
 (0)