Skip to content

Commit 9399a6d

Browse files
Reduce input for helper functions of plot_signals.py so that only necessary dateframe columns are considered.
1 parent f4a622c commit 9399a6d

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

src/backtest_bay/plot/plot_signals.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44

55
def plot_signals(df, title):
66
"""Create Plotly figure to plot trading signals."""
7-
df = _map_signals_for_plotting(df)
7+
df["signal_plot"] = _map_signals_for_plotting(df["signal"])
88

99
fig = make_subplots(
1010
rows=2, cols=1, shared_xaxes=True, row_heights=[0.7, 0.3], vertical_spacing=0.1
1111
)
1212

13-
fig.add_trace(_create_candlestick_trace(df), row=1, col=1)
13+
fig.add_trace(
14+
_create_candlestick_trace(df[["Open", "High", "Low", "Close"]]), row=1, col=1
15+
)
1416

15-
buy_trace, sell_trace = _create_signal_traces(df)
17+
buy_trace, sell_trace = _create_signal_traces(df["signal_plot"])
1618
fig.add_trace(buy_trace, row=2, col=1)
1719
fig.add_trace(sell_trace, row=2, col=1)
1820

@@ -35,11 +37,10 @@ def plot_signals(df, title):
3537
return fig
3638

3739

38-
def _map_signals_for_plotting(df):
40+
def _map_signals_for_plotting(signal):
3941
"""Map signals to -1, 0, 1 for plotting."""
4042
signal_mapping = {0: 0, 1: -1, 2: 1}
41-
df["signal_plot"] = df["signal"].map(signal_mapping)
42-
return df
43+
return signal.map(signal_mapping)
4344

4445

4546
def _create_candlestick_trace(df):
@@ -54,20 +55,20 @@ def _create_candlestick_trace(df):
5455
)
5556

5657

57-
def _create_signal_traces(df):
58+
def _create_signal_traces(signal_plot):
5859
"""Create signal bar traces for buy and sell signals."""
59-
buy_signals = df["signal_plot"] > 0
60+
buy_signals = signal_plot > 0
6061
buy_trace = go.Bar(
61-
x=df.index[buy_signals],
62-
y=df.loc[buy_signals, "signal_plot"],
62+
x=signal_plot.index[buy_signals],
63+
y=signal_plot[buy_signals],
6364
marker_color="green",
6465
name="Buy Signal",
6566
)
6667

67-
sell_signals = df["signal_plot"] < 0
68+
sell_signals = signal_plot < 0
6869
sell_trace = go.Bar(
69-
x=df.index[sell_signals],
70-
y=df.loc[sell_signals, "signal_plot"],
70+
x=signal_plot.index[sell_signals],
71+
y=signal_plot[sell_signals],
7172
marker_color="red",
7273
name="Sell Signal",
7374
)

0 commit comments

Comments
 (0)