Skip to content

Commit 9b03569

Browse files
Implement test cases for merge_data_with_backtest_portfolio and move the main function to backtest_signals.py.
1 parent 655c792 commit 9b03569

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

src/backtest_bay/backtest/backtest_signals.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,16 @@ def _validate_price_col(data, price_col):
236236
if price_col not in data.columns:
237237
error_msg = f"data must contain a '{price_col}' column."
238238
raise ValueError(error_msg)
239+
240+
241+
def merge_data_with_backtest_portfolio(data, portfolio):
242+
"""Merge downloaded data with backtested portfolio using the index.
243+
244+
Args:
245+
data (pd.DataFrame): DataFrame with downloaded data.
246+
portfolio (pd.DataFrame): DataFrame to be merged with data using the index.
247+
248+
Returns:
249+
pd.DataFrame: Merged DataFrame.
250+
"""
251+
return data.merge(portfolio, how="left", left_index=True, right_index=True)

src/backtest_bay/backtest/task_backtest.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import pandas as pd
22
import pytask
33

4-
from backtest_bay.backtest.backtest_signals import backtest_signals
4+
from backtest_bay.backtest.backtest_signals import (
5+
backtest_signals,
6+
merge_data_with_backtest_portfolio,
7+
)
58
from backtest_bay.backtest.generate_signals import generate_signals
69
from backtest_bay.config import BLD, INITIAL_CASH, PARAMS, SRC, TAC, TRADE_PCT
710

@@ -38,21 +41,8 @@ def task_backtest(
3841
trade_pct=TRADE_PCT,
3942
)
4043

41-
merged_portfolio = _merge_stock_data_with_portfolio(
44+
merged_portfolio = merge_data_with_backtest_portfolio(
4245
stock_data, backtested_portfolio
4346
)
4447

4548
merged_portfolio.to_pickle(produces)
46-
47-
48-
def _merge_stock_data_with_portfolio(data, portfolio):
49-
"""Merge data with portfolio using the index.
50-
51-
Args:
52-
data (pd.DataFrame): DataFrame with downloaded data.
53-
portfolio (pd.DataFrame): DataFrame to be merged with data using the index.
54-
55-
Returns:
56-
pd.DataFrame: Merged DataFrame.
57-
"""
58-
return data.merge(portfolio, how="left", left_index=True, right_index=True)

tests/backtest/test_backtest_signals.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_validate_tac,
1414
_validate_trade_pct,
1515
backtest_signals,
16+
merge_data_with_backtest_portfolio,
1617
)
1718

1819

@@ -225,3 +226,29 @@ def test_validate_trade_pct_invalid_input(trade_pct, expected_error):
225226
def test_validate_data_valid_input(data, price_col):
226227
"""Test valid data input for _validate_data."""
227228
_validate_price_col(data, price_col)
229+
230+
231+
# Tests for merge_data_with_backtest_portfolio
232+
@pytest.mark.parametrize(
233+
("data", "portfolio", "expected"),
234+
[
235+
(
236+
pd.DataFrame({"price": [10, 6]}, index=["2023-01-01", "2023-01-02"]),
237+
pd.DataFrame({"signal": [1, 0]}, index=["2023-01-01", "2023-01-02"]),
238+
pd.DataFrame(
239+
{"price": [10, 6], "signal": [1, 0]}, index=["2023-01-01", "2023-01-02"]
240+
),
241+
),
242+
(
243+
pd.DataFrame({"price": [5, 44]}, index=["2023-05-01", "2023-05-08"]),
244+
pd.DataFrame({"signal": [1, 1]}, index=["2023-05-01", "2023-05-08"]),
245+
pd.DataFrame(
246+
{"price": [5, 44], "signal": [1, 1]}, index=["2023-05-01", "2023-05-08"]
247+
),
248+
),
249+
],
250+
)
251+
def test_merge_data_with_backtest_portfolio(data, portfolio, expected):
252+
"Test correct merge for merge_data_with_backtest_portfolio."
253+
result = merge_data_with_backtest_portfolio(data, portfolio)
254+
result.equals(expected)

0 commit comments

Comments
 (0)