Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Pydantic v2 #656

Merged
merged 15 commits into from
Aug 23, 2023
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ dependencies:
- conda-content-trust
- pyinstrument
- pytest-asyncio
- pydantic <2
- pydantic >=2
- pip:
- git+https://github.com/jupyter-server/jupyter_releaser.git@v2
2 changes: 1 addition & 1 deletion quetz/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class User(Base):
'Profile', uselist=False, back_populates='user', cascade="all,delete-orphan"
)

role = Column(String)
role = Column(String, nullable=True)

@classmethod
def find(cls, db, name):
Expand Down
4 changes: 2 additions & 2 deletions quetz/jobs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from quetz.rest_models import PaginatedResponse

from .models import JobStatus, TaskStatus
from .rest_models import Job, JobBase, JobUpdateModel, Task
from .rest_models import Job, JobCreate, JobUpdateModel, Task

api_router = APIRouter()

Expand All @@ -44,7 +44,7 @@ def get_jobs(

@api_router.post("/api/jobs", tags=["Jobs"], status_code=201, response_model=Job)
def create_job(
job: JobBase,
job: JobCreate,
dao: Dao = Depends(get_dao),
auth: authorization.Rules = Depends(get_rules),
):
Expand Down
24 changes: 14 additions & 10 deletions quetz/jobs/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

from importlib_metadata import entry_points as get_entry_points
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from . import handlers
from .models import JobStatus, TaskStatus
Expand Down Expand Up @@ -83,7 +83,6 @@ def parse_job_name(v):
class JobBase(BaseModel):
"""New job spec"""

items_spec: str = Field(..., title='Item selector spec')
manifest: str = Field(None, title='Name of the function')

start_at: Optional[datetime] = Field(
Expand All @@ -97,7 +96,8 @@ class JobBase(BaseModel):
),
)

@validator("manifest", pre=True)
@field_validator("manifest", mode="before")
@classmethod
def validate_job_name(cls, function_name):
if isinstance(function_name, bytes):
return parse_job_name(function_name)
Expand All @@ -107,6 +107,12 @@ def validate_job_name(cls, function_name):
return function_name.encode('ascii')


class JobCreate(JobBase):
"""Create job spec"""

items_spec: str = Field(..., title='Item selector spec')


class JobUpdateModel(BaseModel):
"""Modify job spec items (status and items_spec)"""

Expand All @@ -123,10 +129,8 @@ class Job(JobBase):

status: JobStatus = Field(None, title='Status of the job (running, paused, ...)')

items_spec: str = Field(None, title='Item selector spec')

class Config:
orm_mode = True
items_spec: Optional[str] = Field(None, title='Item selector spec')
model_config = ConfigDict(from_attributes=True)


class Task(BaseModel):
Expand All @@ -136,12 +140,12 @@ class Task(BaseModel):
created: datetime = Field(None, title='Created at')
status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)')

@validator("package_version", pre=True)
@field_validator("package_version", mode="before")
@classmethod
def convert_package_version(cls, v):
if v:
return {'filename': v.filename, 'id': uuid.UUID(bytes=v.id).hex}
else:
return {}

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)
13 changes: 6 additions & 7 deletions quetz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Awaitable, Callable, List, Optional, Tuple, Type

import pydantic
import pydantic.error_wrappers
import requests
from fastapi import (
APIRouter,
Expand Down Expand Up @@ -477,7 +476,7 @@ def delete_user(

@api_router.get(
"/users/{username}/role",
response_model=rest_models.UserRole,
response_model=rest_models.UserOptionalRole,
tags=["users"],
)
def get_user_role(
Expand Down Expand Up @@ -732,7 +731,7 @@ def post_channel(
detail="Cannot use both `includelist` and `excludelist` together.",
)

user_attrs = new_channel.dict(exclude_unset=True)
user_attrs = new_channel.model_dump(exclude_unset=True)

if "size_limit" in user_attrs:
auth.assert_set_channel_size_limit()
Expand Down Expand Up @@ -789,7 +788,7 @@ def patch_channel(
):
auth.assert_update_channel_info(channel.name)

user_attrs = channel_data.dict(exclude_unset=True)
user_attrs = channel_data.model_dump(exclude_unset=True)

if "size_limit" in user_attrs:
auth.assert_set_channel_size_limit()
Expand Down Expand Up @@ -1064,7 +1063,7 @@ def get_package_versions(
version_list = []

for version, profile, api_key_profile in version_profile_list:
version_data = rest_models.PackageVersion.from_orm(version)
version_data = rest_models.PackageVersion.model_validate(version)
version_list.append(version_data)

return version_list
Expand All @@ -1089,7 +1088,7 @@ def get_paginated_package_versions(
version_list = []

for version, profile, api_key_profile in version_profile_list['result']:
version_data = rest_models.PackageVersion.from_orm(version)
version_data = rest_models.PackageVersion.model_validate(version)
version_list.append(version_data)

return {
Expand Down Expand Up @@ -1643,7 +1642,7 @@ def _delete_file(condainfo, filename):
summary=str(condainfo.about.get("summary", "n/a")),
description=str(condainfo.about.get("description", "n/a")),
)
except pydantic.error_wrappers.ValidationError as err:
except pydantic.ValidationError as err:
_delete_file(condainfo, file.filename)
raise errors.ValidationError(
"Validation Error for package: "
Expand Down
6 changes: 2 additions & 4 deletions quetz/metrics/rest_models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from datetime import datetime
from typing import Dict, List

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from quetz.metrics.db_models import IntervalType


class PackageVersionMetricItem(BaseModel):
timestamp: datetime
count: int

class Config:
orm_mode = True
model_config = ConfigDict(from_attributes=True)


class PackageVersionMetricSeries(BaseModel):
Expand Down
Loading
Loading