improve endpoint argument validation: use Pydantic #32

Open
dshemetov opened this issue Jul 9, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

@dshemetov
Contributor

We can remove a lot of boilerplate argument validation code by using Pydantic's `validate_call` decorator together with type hints. It is fast and produces good default error messages. See the proof-of-concept script below.
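As a smaller illustration of the idea before the full script, here is a minimal sketch of `validate_call` on a single endpoint-style function. The names `fetch_signal` and `error_locations` are hypothetical and exist only for this example; they are not part of the package. The sketch assumes Pydantic v2.

```python
from typing import Literal

from pydantic import PositiveInt, ValidationError, validate_call

GeoType = Literal["nation", "msa", "hrr", "hhs", "state", "county"]


# Hypothetical endpoint wrapper, used only for illustration.
# The type hints replace hand-written isinstance/range checks.
@validate_call
def fetch_signal(source: str, geo_type: GeoType, lag: PositiveInt = 1) -> str:
    return f"{source}/{geo_type}?lag={lag}"


def error_locations(*args, **kwargs) -> list:
    """Call fetch_signal and return the location of each validation error."""
    try:
        fetch_signal(*args, **kwargs)
    except ValidationError as exc:
        # Each error records where it occurred: a positional index or a
        # keyword argument name.
        return [err["loc"] for err in exc.errors()]
    return []


print(fetch_signal("fb-survey", "nation"))
print(error_locations("fb-survey", "planet", lag=-1))
```

All invalid arguments are reported together in one `ValidationError`, which is what lets the decorator stand in for a chain of manual checks.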

# Testing Pydantic for validation work.
#
# My profiling results below indicate that Pydantic works well for validating
# function input arguments and small JSON responses. For larger JSON responses,
# plain Pandas may be the better choice, especially if we expect to eventually
# output a DataFrame anyway.
#
# So my recommendation is to use `validate_call` for endpoint function arguments.
#

import cProfile
import datetime
from typing import List, Literal, Optional, Union

import pandas as pd
import requests
from epiweeks import Week
from pydantic import (
    BaseModel,
    ConfigDict,
    PositiveFloat,
    condate,
    field_validator,
    validate_call,
)

GeoType = Literal["nation", "msa", "hrr", "hhs", "state", "county"]
TimeType = Literal["day", "week"]
EpiDateLike = Union[int, str, condate(gt=datetime.date(1990, 1, 1)), Week]


# The default error format takes some getting used to: each error is listed
# under the positional index or the keyword name of the offending argument.
# https://docs.pydantic.dev/2.8/errors/errors/
@validate_call(config=dict(arbitrary_types_allowed=True))
def test_function(a: int, b: PositiveFloat, c: GeoType, d: EpiDateLike) -> str:
    return f"{a + b} {c} {d}"

# Each call below raises a ValidationError (a ValueError subclass in Pydantic v2),
# so catch it to let the script continue. The first argument is always cast to int.
for d in (
    datetime.date(1989, 4, 5),  # errors on b, c, and d (date is before 1990)
    19890405,  # errors on b and c only (d is accepted as an int)
    Week(1989, 14),  # errors on b and c only (d is accepted as a Week)
):
    try:
        test_function(5.0, -5, c="hey", d=d)
    except ValueError as e:
        print(e)


# Mutually exclusive arguments require some extra work.
# https://stackoverflow.com/a/72087084
class MyModel(BaseModel):
    a: Optional[str] = None
    b: Optional[str] = None

    # Pydantic v2: `always=True` no longer exists, and previously validated
    # fields are available in `info.data` instead of a `values` dict.
    @field_validator("b")
    @classmethod
    def mutually_exclusive(cls, v, info):
        if info.data.get("a") is not None and v:
            raise ValueError("'a' and 'b' are mutually exclusive.")

        return v


# You can create a model for a JSON row and validate it.
class Covidcast(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    source: str
    signal: str
    geo_type: GeoType
    geo_value: str
    time_type: TimeType
    time_value: EpiDateLike
    issue: EpiDateLike
    lag: int
    value: float
    stderr: float
    sample_size: int
    direction: Union[float, None]
    missing_value: int
    missing_stderr: int
    missing_sample_size: int

row = """{"geo_value":"us","signal":"smoothed_cli","source":"fb-survey","geo_type":"nation","time_type":"day","time_value":20210405,"direction":null,"issue":20210410,"lag":5,"missing_value":0,"missing_stderr":0,"missing_sample_size":0,"value":0.6758323,"stderr":0.0148258,"sample_size":244046.0}"""
vrow = Covidcast.model_validate_json(row)


# You can also create a model for the whole JSON response, which wraps a list
# of the rows above, and validate everything at once.
class Response(BaseModel):
    result: Union[str, int]
    message: str
    epidata: List[Covidcast]

data = requests.get(
    "https://api.delphi.cmu.edu/epidata/covidcast/?data_source=fb-survey&signals=smoothed_cli&time_type=day&time_values=20210405-20210410&geo_type=nation&geo_values=us"
)
vdata = Response.model_validate_json(data.text)


# Profile constructing Pandas DataFrames from the validated JSON data.
# Work on a copy so the original response is left untouched, and repeat the
# rows 10**5 times to get a large payload.
vdata2 = vdata.model_copy()
vdata2.epidata = vdata2.epidata * 10**5

cProfile.run("pd.DataFrame([s.model_dump() for s in vdata2.epidata])")  # around 3.4s
cProfile.run("pd.DataFrame.from_records(vdata2.model_dump()['epidata'])")  # 0.62s
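For the large-response case, where the script's header comment suggests leaning on Pandas, one possible shape is sketched below: validate only the response envelope with Pydantic and hand the rows to Pandas untyped. `Envelope`, `rows_to_frame`, and the inline sample payload are hypothetical names and data made up for this illustration, and skipping per-row validation is an assumption about what is acceptable, not a measured recommendation.

```python
import json
from typing import Any, Dict, List, Union

import pandas as pd
from pydantic import BaseModel


class Envelope(BaseModel):
    """Hypothetical model: checks the response wrapper, leaves rows untyped."""

    result: Union[str, int]
    message: str
    epidata: List[Dict[str, Any]]


def rows_to_frame(text: str) -> pd.DataFrame:
    # Pydantic checks that the envelope has the expected keys and shapes;
    # Pandas then builds the frame directly from the raw row dicts.
    envelope = Envelope.model_validate_json(text)
    return pd.DataFrame.from_records(envelope.epidata)


# Made-up payload in the same shape as the API response used above.
sample = json.dumps(
    {
        "result": 1,
        "message": "success",
        "epidata": [
            {"geo_value": "us", "time_value": 20210405, "value": 0.67},
            {"geo_value": "us", "time_value": 20210406, "value": 0.68},
        ],
    }
)
df = rows_to_frame(sample)
print(df)
```

Column dtypes are then inferred by Pandas, so any stricter guarantees (for example on `geo_type`) would need a separate check on the DataFrame.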
@dshemetov added the enhancement label Jul 16, 2024