improve endpoint argument validation: use Pydantic #32

Closed
@dshemetov

We might be able to remove a lot of boilerplate argument-validation code by using Pydantic's `validate_call` decorator and type hints. Here is a proof-of-concept script:

# Testing Pydantic for validation work.
#
# My profiling results below indicate that Pydantic is good for validating
# function input arguments and small JSON responses. Simply using pandas may
# be better for larger JSON responses, especially if we expect to eventually
# output a DataFrame anyway.
#
# So my recommendation is to use `validate_call` for endpoint function arguments.
#

import cProfile
import datetime
from typing import List, Literal, Optional, Union

import pandas as pd  # needed by the profiling calls at the end of the script
import requests
from epiweeks import Week
from pydantic import (
    BaseModel,
    ConfigDict,
    PositiveFloat,
    condate,
    field_validator,
    validate_call,
)

GeoType = Literal["nation", "msa", "hrr", "hhs", "state", "county"]
TimeType = Literal["day", "week"]
EpiDateLike = Union[int, str, condate(gt=datetime.date(1990, 1, 1)), Week]


# The default error takes some getting used to: the validation error message
# follows a positional index or a keyword argument name.
# https://docs.pydantic.dev/2.8/errors/errors/
@validate_call(config=dict(arbitrary_types_allowed=True))
def test_function(a: int, b: PositiveFloat, c: GeoType, d: EpiDateLike) -> str:
    return f"{a + b} {c} {d}"

# Every call below fails validation, so catch the error to let the script
# continue. Pydantic's ValidationError is a subclass of ValueError.
for d in (
    datetime.date(1989, 4, 5),  # casts `a` to int and errors on the next 3
    19890405,  # casts `a` to int and errors on the next 2
    Week(1989, 14),  # casts `a` to int and errors on the next 2
):
    try:
        test_function(5.0, -5, c="hey", d=d)
    except ValueError as exc:
        print(exc)


# Mutually exclusive arguments require some extra work.
# https://stackoverflow.com/a/72087084
class MyModel(BaseModel):
    a: Optional[str] = None
    b: Optional[str] = None

    # Pydantic v2: `field_validator` has no `always` argument, and previously
    # validated fields are read from `info.data` rather than `values`.
    @field_validator("b")
    @classmethod
    def mutually_exclusive(cls, v, info):
        if info.data.get("a") is not None and v:
            raise ValueError("'a' and 'b' are mutually exclusive.")

        return v


# You can create a model for a JSON row and validate it.
class Covidcast(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    source: str
    signal: str
    geo_type: GeoType
    geo_value: str
    time_type: TimeType
    time_value: EpiDateLike
    issue: EpiDateLike
    lag: int
    value: float
    stderr: float
    sample_size: int
    direction: Union[float, None]
    missing_value: int
    missing_stderr: int
    missing_sample_size: int

row = """{"geo_value":"us","signal":"smoothed_cli","source":"fb-survey","geo_type":"nation","time_type":"day","time_value":20210405,"direction":null,"issue":20210410,"lag":5,"missing_value":0,"missing_stderr":0,"missing_sample_size":0,"value":0.6758323,"stderr":0.0148258,"sample_size":244046.0}"""
vrow = Covidcast.model_validate_json(row)


# You can create a model for the whole JSON response, made up of the rows above, and validate them all at once.
class Response(BaseModel):
    result: Union[str, int]
    message: str
    epidata: List[Covidcast]

data = requests.get(
    "https://api.delphi.cmu.edu/epidata/covidcast/?data_source=fb-survey&signals=smoothed_cli&time_type=day&time_values=20210405-20210410&geo_type=nation&geo_values=us"
)
vdata = Response.model_validate_json(data.text)


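# Profile constructing pandas DataFrames from the validated JSON data.
# Repeat the rows 10**5 times to simulate a large response. Note that
# `vdata2` is an alias for `vdata`, not a copy.
vdata2 = vdata
vdata2.epidata = vdata2.epidata * 10**5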

cProfile.run("pd.DataFrame([s.model_dump() for s in vdata2.epidata])")  # around 3.4s
cProfile.run("pd.DataFrame.from_records(vdata2.model_dump()['epidata'])")  # 0.62s
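
To make the recommendation concrete, here is a minimal sketch of `validate_call` on an endpoint-style function, assuming Pydantic v2. The function `covidcast_params`, its parameters, and the helper `error_locations` are hypothetical names made up for illustration, not part of the existing client. It shows the point made in the script's comments: a bad positional argument is reported by its index, a bad keyword argument by its name.

```python
from typing import Literal, Optional

from pydantic import ValidationError, validate_call

GeoType = Literal["nation", "msa", "hrr", "hhs", "state", "county"]
TimeType = Literal["day", "week"]


# Hypothetical endpoint helper; the name and parameters are assumptions.
@validate_call
def covidcast_params(
    data_source: str,
    signal: str,
    geo_type: GeoType,
    time_type: TimeType = "day",
    lag: Optional[int] = None,
) -> dict:
    """Build query parameters; arguments are validated before the body runs."""
    return {
        "data_source": data_source,
        "signal": signal,
        "geo_type": geo_type,
        "time_type": time_type,
        "lag": lag,
    }


def error_locations(func, *args, **kwargs) -> list:
    """Call `func` and return the `loc` of each validation error (empty if valid)."""
    try:
        func(*args, **kwargs)
    except ValidationError as exc:
        return [err["loc"] for err in exc.errors()]
    return []


# A bad positional argument is reported by index, a bad keyword argument by name.
print(error_locations(covidcast_params, "fb-survey", "smoothed_cli", "planet"))
print(error_locations(covidcast_params, "fb-survey", "smoothed_cli", geo_type="planet", lag="abc"))
```

With this pattern the function body contains no hand-written type or membership checks; the type hints carry the whole contract.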

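The mutually exclusive check can also be written as a model-level validator, which sees all fields at once and does not depend on field order. This is a sketch assuming Pydantic v2; `ExclusiveArgs` and `is_valid` are hypothetical names used only for this example.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class ExclusiveArgs(BaseModel):  # hypothetical model name
    a: Optional[str] = None
    b: Optional[str] = None

    # Runs after both fields have been validated individually.
    @model_validator(mode="after")
    def check_mutually_exclusive(self):
        if self.a is not None and self.b is not None:
            raise ValueError("'a' and 'b' are mutually exclusive.")
        return self


def is_valid(**kwargs) -> bool:
    """Return True if the keyword arguments pass validation."""
    try:
        ExclusiveArgs(**kwargs)
    except ValidationError:
        return False
    return True


print(is_valid(a="x"))
print(is_valid(a="x", b="y"))
```

The trade-off is that the error is attached to the model as a whole rather than to field `b`.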
Labels: enhancement (New feature or request)