Skip to content

Commit

Permalink
refactor: center DataFrame return in EpidataCall
Browse files Browse the repository at this point in the history
* remove json, csv, iter formats
* remove format_type option, always request classic
* consolidate DataFrame code
* parse types only if classic, otherwise let Pandas do it
  • Loading branch information
dshemetov committed Jul 12, 2024
1 parent 2a6c3f7 commit d2b41f0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 119 deletions.
60 changes: 2 additions & 58 deletions epidatpy/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from datetime import date
from enum import Enum
from typing import (
Any,
Dict,
Final,
List,
Mapping,
Expand All @@ -18,10 +16,8 @@
from urllib.parse import urlencode

from epiweeks import Week
from pandas import CategoricalDtype, DataFrame, Series

from ._parse import (
fields_to_predicate,
parse_api_date,
parse_api_date_or_week,
parse_api_week,
Expand Down Expand Up @@ -90,17 +86,6 @@ def __str__(self) -> str:
return f"{format_date(self.start)}-{format_date(self.end)}"


class EpiDataFormatType(str, Enum):
    """
    Possible response-formatting options for API calls.

    Each member's value is the literal string sent as the request's
    ``format`` query parameter (see ``_formatted_parameters``); the
    ``classic`` format is the server default and is not sent explicitly.
    """

    json = "json"
    classic = "classic"
    csv = "csv"
    jsonl = "jsonl"


class InvalidArgumentException(Exception):
"""
exception for an invalid argument
Expand Down Expand Up @@ -180,41 +165,36 @@ def _verify_parameters(self) -> None:

def _formatted_parameters(
    self,
    format_type: Optional[EpiDataFormatType] = None,
    fields: Optional[Sequence[str]] = None,
) -> Mapping[str, str]:
    """
    Build the URL query-parameter mapping for this call.

    NOTE(review): a previous docstring claimed this returns a
    ``[URL, Params]`` tuple; it returns only the parameter mapping —
    the (url, params) pair is assembled by ``request_arguments``.

    :param format_type: requested response format; ``classic`` (the
        server default) and ``None`` add no ``format`` parameter
    :param fields: optional subset of fields to request
    :returns: mapping of parameter name to formatted string value
    """
    all_params = dict(self._params)
    # Only send an explicit format for non-default (non-classic) formats.
    if format_type and format_type != EpiDataFormatType.classic:
        all_params["format"] = format_type
    if fields:
        all_params["fields"] = fields
    # Drop unset parameters; format_list renders each remaining value
    # as a request-ready string.
    return {k: format_list(v) for k, v in all_params.items() if v is not None}

def request_arguments(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
) -> Tuple[str, Mapping[str, str]]:
"""
format this call into a [URL, Params] tuple
"""
formatted_params = self._formatted_parameters(format_type, fields)
formatted_params = self._formatted_parameters(fields)
full_url = add_endpoint_to_url(self._base_url, self._endpoint)
return full_url, formatted_params

def request_url(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
) -> str:
"""
format this call into a full HTTP request url with encoded parameters
"""
self._verify_parameters()
u, p = self.request_arguments(format_type, fields)
u, p = self.request_arguments(fields)
query = urlencode(p)
if query:
return f"{u}?{query}"
Expand Down Expand Up @@ -253,39 +233,3 @@ def _parse_row(
if not self.meta:
return row
return {k: self._parse_value(k, v, disable_date_parsing) for k, v in row.items()}

def _as_df(
    self,
    rows: Sequence[Mapping[str, Union[str, float, int, date, None]]],
    fields: Optional[Sequence[str]] = None,
    disable_date_parsing: Optional[bool] = False,
) -> DataFrame:
    """
    Convert parsed API rows into a pandas DataFrame with typed columns.

    :param rows: row mappings as returned (and already parsed) by the API
    :param fields: optional field filter; only matching metadata columns
        are kept (empty/None keeps all)
    :param disable_date_parsing: if True, date-like columns are cast to
        ``int`` instead of ``datetime64[ns]``
    :returns: DataFrame whose dtypes are derived from ``self.meta``
    """
    pred = fields_to_predicate(fields)
    # Column order follows the endpoint's field metadata.
    columns: List[str] = [info.name for info in self.meta if pred(info.name)]
    df = DataFrame(rows, columns=columns or None)

    # Map each metadata field type onto a pandas/numpy dtype.
    data_types: Dict[str, Any] = {}
    for info in self.meta:
        # Skip filtered-out columns and all-null columns (casting an
        # all-NaN column to e.g. int would fail).
        if not pred(info.name) or df[info.name].isnull().all():
            continue
        if info.type == EpidataFieldType.bool:
            data_types[info.name] = bool
        elif info.type == EpidataFieldType.categorical:
            # Ordered categorical, with the category order taken from the
            # metadata when available.
            data_types[info.name] = CategoricalDtype(
                categories=Series(info.categories) if info.categories else None, ordered=True
            )
        elif info.type == EpidataFieldType.int:
            data_types[info.name] = int
        elif info.type in (
            EpidataFieldType.date,
            EpidataFieldType.epiweek,
            EpidataFieldType.date_or_epiweek,
        ):
            # Date-like columns stay as ints when parsing was disabled
            # upstream; otherwise coerce to proper datetimes.
            data_types[info.name] = int if disable_date_parsing else "datetime64[ns]"
        elif info.type == EpidataFieldType.float:
            data_types[info.name] = float
        else:
            # Fallback for text-like field types.
            data_types[info.name] = str
    if data_types:
        df = df.astype(data_types)
    return df
100 changes: 45 additions & 55 deletions epidatpy/request.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import date
from typing import (
Any,
Dict,
Final,
List,
Mapping,
Expand All @@ -9,7 +10,7 @@
cast,
)

from pandas import DataFrame
from pandas import CategoricalDtype, DataFrame, Series
from requests import Response, Session
from requests.auth import HTTPBasicAuth
from tenacity import retry, stop_after_attempt
Expand All @@ -21,13 +22,14 @@
from ._model import (
AEpiDataCall,
EpidataFieldInfo,
EpiDataFormatType,
EpidataFieldType,
EpiDataResponse,
EpiRange,
EpiRangeParam,
OnlySupportsClassicFormatException,
add_endpoint_to_url,
)
from ._parse import fields_to_predicate

# Make the linter happy about the unused variables
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
Expand Down Expand Up @@ -83,23 +85,25 @@ def with_session(self, session: Session) -> "EpiDataCall":

def _call(
self,
format_type: Optional[EpiDataFormatType] = None,
fields: Optional[Sequence[str]] = None,
stream: bool = False,
) -> Response:
url, params = self.request_arguments(format_type, fields)
url, params = self.request_arguments(fields)
return _request_with_retry(url, params, self._session, stream)

def classic(
self,
fields: Optional[Sequence[str]] = None,
disable_date_parsing: Optional[bool] = False,
disable_type_parsing: Optional[bool] = False,
) -> EpiDataResponse:
"""Request and parse epidata in CLASSIC message format."""
self._verify_parameters()
try:
response = self._call(None, fields)
response = self._call(fields)
r = cast(EpiDataResponse, response.json())
if disable_type_parsing:
return r
epidata = r.get("epidata")
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
r["epidata"] = [self._parse_row(row, disable_date_parsing=disable_date_parsing) for row in epidata]
Expand All @@ -111,25 +115,11 @@ def __call__(
self,
fields: Optional[Sequence[str]] = None,
disable_date_parsing: Optional[bool] = False,
) -> EpiDataResponse:
"""Request and parse epidata in CLASSIC message format."""
return self.classic(fields, disable_date_parsing=disable_date_parsing)

def json(
self,
fields: Optional[Sequence[str]] = None,
disable_date_parsing: Optional[bool] = False,
) -> List[Mapping[str, Union[str, int, float, date, None]]]:
"""Request and parse epidata in JSON format"""
) -> Union[EpiDataResponse, DataFrame]:
"""Request and parse epidata in df message format."""
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
response = self._call(EpiDataFormatType.json, fields)
response.raise_for_status()
return [
self._parse_row(row, disable_date_parsing=disable_date_parsing)
for row in cast(List[Mapping[str, Union[str, int, float, None]]], response.json())
]
return self.classic(fields, disable_date_parsing=disable_date_parsing, disable_type_parsing = False)
return self.df(fields, disable_date_parsing=disable_date_parsing)

def df(
self,
Expand All @@ -140,37 +130,37 @@ def df(
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
r = self.json(fields, disable_date_parsing=disable_date_parsing)
return self._as_df(r, fields, disable_date_parsing=disable_date_parsing)

def csv(self, fields: Optional[Iterable[str]] = None) -> str:
"""Request and parse epidata in CSV format"""
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
response = self._call(EpiDataFormatType.csv, fields)
response.raise_for_status()
return response.text

def iter(
self,
fields: Optional[Iterable[str]] = None,
disable_date_parsing: Optional[bool] = False,
) -> Generator[Mapping[str, Union[str, int, float, date, None]], None, Response]:
"""Request and streams epidata rows"""
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()
response = self._call(EpiDataFormatType.jsonl, fields, stream=True)
response.raise_for_status()
for line in response.iter_lines():
yield self._parse_row(loads(line), disable_date_parsing=disable_date_parsing)
return response

def __iter__(
self,
) -> Generator[Mapping[str, Union[str, int, float, date, None]], None, Response]:
return self.iter()
json = self.classic(fields, disable_type_parsing=True)
rows = json.get("epidata", [])
pred = fields_to_predicate(fields)
columns: List[str] = [info.name for info in self.meta if pred(info.name)]
df = DataFrame(rows, columns=columns or None)

data_types: Dict[str, Any] = {}
for info in self.meta:
if not pred(info.name) or df[info.name].isnull().all():
continue
if info.type == EpidataFieldType.bool:
data_types[info.name] = bool
elif info.type == EpidataFieldType.categorical:
data_types[info.name] = CategoricalDtype(
categories=Series(info.categories) if info.categories else None, ordered=True
)
elif info.type == EpidataFieldType.int:
data_types[info.name] = "Int64"
elif info.type in (
EpidataFieldType.date,
EpidataFieldType.epiweek,
EpidataFieldType.date_or_epiweek,
):
data_types[info.name] = "Int64" if disable_date_parsing else "datetime64[ns]"
elif info.type == EpidataFieldType.float:
data_types[info.name] = "Float64"
else:
data_types[info.name] = "string"
if data_types:
df = df.astype(data_types)
return df


class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
Expand Down
6 changes: 0 additions & 6 deletions smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
classic = apicall.classic()
print(classic)

data = apicall.json()
print(data[0])

df = apicall.df()
print(df.columns)
print(df.dtypes)
Expand Down Expand Up @@ -53,9 +50,6 @@
classic = apicall.classic()
print(classic)

data = apicall.json()
print(data[0])

df = apicall.df()
print(df.columns)
print(df.dtypes)
Expand Down

0 comments on commit d2b41f0

Please sign in to comment.