from datetime import datetime

import pandas as pd
import yfinance as yf
45
56
def download_data(symbol, interval, start_date, end_date):
    """Download historical stock data and validate it.

    The inputs are validated, the data is fetched with ``yf.download``, the
    downloaded frame is validated, and finally the column MultiIndex is
    flattened to its level-0 names.

    Args:
        symbol (str): Stock symbol to download data for (e.g., 'AAPL' for Apple).
        interval (str): Data interval (e.g., '1d' for daily, '1h' for hourly).
        start_date (str): Start date for the data in 'YYYY-MM-DD' format.
        end_date (str): End date for the data in 'YYYY-MM-DD' format.

    Returns:
        pd.DataFrame: The downloaded stock data, with single-level columns
            named after level 0 of the original column MultiIndex.

    Raises:
        TypeError: Propagated from input or data validation (e.g. a
            non-string symbol or a non-DataFrame download result).
        ValueError: Propagated from input or data validation (e.g. an empty
            download result or missing price columns).
    """
    _validate_input(symbol, interval, start_date, end_date)
    data = yf.download(symbol, start=start_date, end=end_date, interval=interval)
    # Validate before flattening: the column checks expect the MultiIndex.
    _validate_data(data, symbol, start_date, end_date, interval)
    data.columns = _remove_multiindex_from_cols(data.columns)
    return data
1128
1229
1330def _validate_input (symbol , interval , start_date , end_date ):
31+ """Validate symbol, interval, and date inputs."""
1432 _validate_symbol (symbol )
1533 _validate_interval (interval )
1634 _validate_date_format (start_date )
@@ -19,13 +37,15 @@ def _validate_input(symbol, interval, start_date, end_date):
1937
2038
def _validate_symbol(symbol):
    """Check that symbol is a non-empty string.

    Args:
        symbol: Value to validate as a stock symbol.

    Raises:
        TypeError: If ``symbol`` is not a string.
        ValueError: If ``symbol`` is empty or contains only whitespace.
    """
    error_msg = "Symbol must be a non-empty string."
    if not isinstance(symbol, str):
        raise TypeError(error_msg)
    # The message promises "non-empty", so blank strings must be rejected too.
    if not symbol.strip():
        raise ValueError(error_msg)
2645
2746
2847def _validate_interval (interval ):
48+ """Validate if interval is within the allowed set."""
2949 valid_intervals = {
3050 "1m" ,
3151 "2m" ,
@@ -47,6 +67,7 @@ def _validate_interval(interval):
4767
4868
4969def _validate_date_format (date_str ):
70+ """Check if date string is in 'YYYY-MM-DD' format."""
5071 if not isinstance (date_str , str ):
5172 error_msg = "Date must be a string in 'YYYY-MM-DD' format."
5273 raise TypeError (error_msg )
@@ -58,15 +79,111 @@ def _validate_date_format(date_str):
5879
5980
def _validate_date_range(start_date, end_date):
    """Ensure the start date is not after the end date.

    Args:
        start_date: Start of the range, normally a 'YYYY-MM-DD' string.
        end_date: End of the range, normally a 'YYYY-MM-DD' string.

    Raises:
        ValueError: If ``start_date`` is later than ``end_date``.
    """
    start, end = start_date, end_date
    if isinstance(start_date, str) and isinstance(end_date, str):
        # Compare as dates, not text: strings such as '2024-9-1' and
        # '2024-10-01' order incorrectly when compared lexicographically.
        try:
            start = datetime.strptime(start_date, "%Y-%m-%d")
            end = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            # Not parseable in this format; keep the raw comparison.
            start, end = start_date, end_date

    if start > end:
        error_msg = "Start date must be before end date."
        raise ValueError(error_msg)
6486
6587
def _validate_data(data, symbol, start_date, end_date, interval):
    """Validate the downloaded data.

    Runs the individual checks in order: DataFrame type, non-emptiness,
    datetime index, column MultiIndex layout, and numeric price columns.

    Args:
        data (pd.DataFrame): DataFrame containing stock data.
        symbol (str): Stock symbol for the data (used in error messages).
        start_date (str): Start date for the data (used in error messages).
        end_date (str): End date for the data (used in error messages).
        interval (str): Interval for the data (used in error messages).

    Raises:
        TypeError: If data is not a pandas DataFrame, its index is not a
            DatetimeIndex, or its columns are not a MultiIndex.
        ValueError: If the DataFrame is empty, required columns are missing,
            or required columns contain non-numeric values.
    """
    _validate_data_type_dataframe(data)
    _validate_data_empty(data, symbol, start_date, end_date, interval)
    _validate_data_index_datetime(data.index)
    _validate_data_multiindex(data.columns)
    _validate_data_numeric(data)
108+
109+
def _validate_data_type_dataframe(data):
    """Check that the input is a pandas DataFrame.

    Args:
        data: Object to check.

    Raises:
        TypeError: If ``data`` is not a ``pd.DataFrame``.
    """
    if isinstance(data, pd.DataFrame):
        return
    error_msg = f"data must be a pandas DataFrame, got {type(data).__name__}."
    raise TypeError(error_msg)
115+
116+
def _validate_data_empty(data, symbol, start_date, end_date, interval):
    """Check that the DataFrame is not empty.

    Args:
        data (pd.DataFrame): Downloaded data to check.
        symbol (str): Stock symbol, used in the error message.
        start_date (str): Start date, used in the error message.
        end_date (str): End date, used in the error message.
        interval (str): Data interval, used in the error message.

    Raises:
        ValueError: If ``data.empty`` is true.
    """
    if not data.empty:
        return
    error_msg = (
        f"No data found for {symbol} between {start_date} and {end_date} "
        f"with interval '{interval}'."
    )
    raise ValueError(error_msg)
125+
126+
def _validate_data_index_datetime(index):
    """Check that the index is a pandas DatetimeIndex.

    Args:
        index: Index object of the downloaded DataFrame.

    Raises:
        TypeError: If ``index`` is not a ``pd.DatetimeIndex``.
    """
    if isinstance(index, pd.DatetimeIndex):
        return
    error_msg = (
        f"data index must be a pandas DatetimeIndex, got {type(index).__name__}."
    )
    raise TypeError(error_msg)
134+
135+
def _validate_data_multiindex(columns):
    """Check that the columns form the expected two-level MultiIndex.

    Args:
        columns (pd.MultiIndex): Column index of the downloaded DataFrame.

    Raises:
        TypeError: If ``columns`` is not a ``pd.MultiIndex``.
        ValueError: If the MultiIndex does not have exactly two levels, or
            if any of 'Close', 'Open', 'High', 'Low' is missing from level 0.
    """
    required_cols = {"Close", "Open", "High", "Low"}

    if not isinstance(columns, pd.MultiIndex):
        error_msg = "DataFrame must have a MultiIndex for columns."
        raise TypeError(error_msg)

    yfinance_index_levels = 2
    if columns.nlevels != yfinance_index_levels:
        error_msg = (
            f"MultiIndex must have exactly 2 levels, got {columns.nlevels} levels."
        )
        raise ValueError(error_msg)

    missing_cols = required_cols - set(columns.get_level_values(0))
    if missing_cols:
        # Sort the names: set iteration order would make the message vary.
        error_msg = (
            "Level 0 of MultiIndex must contain the following columns: "
            f"{', '.join(sorted(missing_cols))}."
        )
        raise ValueError(error_msg)
167+
168+
def _validate_data_numeric(data):
    """Check that the required price columns contain numeric values.

    Each of 'Close', 'Open', 'High' and 'Low' is selected from ``data``. The
    selection may be a single Series or, with MultiIndex columns, a DataFrame
    holding one sub-column per second-level label; every selected column
    must have a numeric dtype.

    Args:
        data (pd.DataFrame): Downloaded data containing the price columns.

    Raises:
        ValueError: If any required column holds non-numeric values.
    """
    required_cols = ["Close", "Open", "High", "Low"]
    is_numeric = pd.api.types.is_numeric_dtype

    non_numeric_cols = []
    for col in required_cols:
        selected = data[col]
        # A DataFrame selection exposes one dtype per sub-column; checking
        # each avoids rejecting numeric data that spans several sub-columns.
        if isinstance(selected, pd.DataFrame):
            dtypes = list(selected.dtypes)
        else:
            dtypes = [selected.dtype]
        if not all(is_numeric(dtype) for dtype in dtypes):
            non_numeric_cols.append(col)

    if non_numeric_cols:
        error_msg = (
            "The following columns must contain numeric values: "
            f"{', '.join(non_numeric_cols)}."
        )
        raise ValueError(error_msg)
184+
185+
def _remove_multiindex_from_cols(cols):
    """Return the level-0 labels of ``cols`` for use as flat column names.

    Args:
        cols (pd.MultiIndex): Column index whose first level holds the
            column names to keep.

    Returns:
        pd.Index: The values of level 0, one entry per original column.
    """
    return cols.get_level_values(0)
0 commit comments