Skip to content

Commit b0b64ae

Browse files
Add download_data.py to download data from Yahoo Finance. Implement corresponding task and test functions. Add config parameters to config.py
1 parent 73938c7 commit b0b64ae

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

src/backtest_bay/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@
66
ROOT = SRC.joinpath("..", "..").resolve()
77

88
BLD = ROOT.joinpath("bld").resolve()
9+
10+
# Configure input data
11+
STOCKS = ["AAPL", "MSFT"]
12+
START_DATES = ["2022-01-01", "2022-01-01"]
13+
END_DATES = ["2025-01-01", "2025-01-01"]
14+
INTERVALS = ["1d", "1d"]
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from datetime import datetime
2+
3+
import yfinance as yf
4+
5+
6+
def download_data(symbol, interval, start_date, end_date):
7+
_validate_input(symbol, interval, start_date, end_date)
8+
data = yf.download(symbol, start=start_date, end=end_date, interval=interval)
9+
_validate_output(data, symbol, start_date, end_date, interval)
10+
return data
11+
12+
13+
def _validate_input(symbol, interval, start_date, end_date):
14+
_validate_symbol(symbol)
15+
_validate_interval(interval)
16+
_validate_date_format(start_date)
17+
_validate_date_format(end_date)
18+
_validate_date_range(start_date, end_date)
19+
20+
21+
def _validate_symbol(symbol):
22+
is_symbol_string = isinstance(symbol, str)
23+
if not is_symbol_string:
24+
error_msg = "Symbol must be a non-empty string."
25+
raise ValueError(error_msg)
26+
27+
is_symbol_existent = not yf.Ticker(symbol).history(period="1d").empty
28+
if not is_symbol_existent:
29+
error_msg = f"Invalid symbol: '{symbol}'. Please provide a valid ticker symbol."
30+
raise ValueError(error_msg)
31+
32+
33+
def _validate_interval(interval):
34+
valid_intervals = {
35+
"1m",
36+
"2m",
37+
"5m",
38+
"15m",
39+
"30m",
40+
"60m",
41+
"90m",
42+
"1h",
43+
"1d",
44+
"5d",
45+
"1wk",
46+
"1mo",
47+
"3mo",
48+
}
49+
if interval not in valid_intervals:
50+
error_msg = f"Invalid interval. Choose from {', '.join(valid_intervals)}."
51+
raise ValueError(error_msg)
52+
53+
54+
def _validate_date_format(date_str):
55+
try:
56+
datetime.fromisoformat(date_str)
57+
except ValueError as e:
58+
error_msg = "Date must be in 'YYYY-MM-DD' format."
59+
raise ValueError(error_msg) from e
60+
61+
62+
def _validate_date_range(start_date, end_date):
63+
if start_date > end_date:
64+
error_msg = "Start date must be before end date."
65+
raise ValueError(error_msg)
66+
67+
68+
def _validate_output(data, symbol, start_date, end_date, interval):
69+
if data.empty:
70+
ticker = yf.Ticker(symbol)
71+
hist = ticker.history(period="max")
72+
min_date = hist.index.min().strftime("%Y-%m-%d")
73+
max_date = hist.index.max().strftime("%Y-%m-%d")
74+
75+
error_msg = (
76+
f"No data found for {symbol} between {start_date} and {end_date} "
77+
f"with interval '{interval}'.\n"
78+
f"Available data range for {symbol}: {min_date} to {max_date}."
79+
)
80+
raise ValueError(error_msg)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytask
2+
3+
from backtest_bay.config import BLD, END_DATES, INTERVALS, SRC, START_DATES, STOCKS
4+
from backtest_bay.data.download_data import download_data
5+
6+
dependencies = [SRC / "config.py", SRC / "data" / "download_data.py"]
7+
8+
for index, _ in enumerate(STOCKS):
9+
_id = (
10+
STOCKS[index]
11+
+ "_"
12+
+ START_DATES[index]
13+
+ "_"
14+
+ END_DATES[index]
15+
+ "_"
16+
+ INTERVALS[index]
17+
)
18+
19+
@pytask.task(id=_id)
20+
def task_download_data(
21+
symbol=STOCKS[index],
22+
start_date=START_DATES[index],
23+
end_date=END_DATES[index],
24+
interval=INTERVALS[index],
25+
depends_on=dependencies,
26+
produces=BLD / (_id + ".pkl"),
27+
):
28+
data = download_data(
29+
symbol=symbol, start_date=start_date, end_date=end_date, interval=interval
30+
)
31+
data.to_pickle(produces)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
3+
from backtest_bay.data.download_data import _validate_symbol
4+
5+
6+
# Tests for _validate_symbol
7+
def test_validate_symbol_valid_cases():
8+
assert _validate_symbol("AAPL") is None
9+
assert _validate_symbol("MSFT") is None
10+
11+
12+
def test_validate_symbol_invalid_symbol():
13+
with pytest.raises(ValueError, match="Invalid symbol: 'INVALID'"):
14+
_validate_symbol("INVALID")
15+
16+
17+
def test_validate_symbol_empty_string():
18+
with pytest.raises(ValueError, match="Invalid symbol: ''"):
19+
_validate_symbol("")
20+
21+
22+
def test_validate_symbol_non_string():
23+
with pytest.raises(ValueError, match="Symbol must be a non-empty string."):
24+
_validate_symbol(123)

0 commit comments

Comments
 (0)