Skip to content

Commit

Permalink
Change sqlmodel to sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Aug 9, 2024
1 parent ea78a3b commit 55e68e6
Show file tree
Hide file tree
Showing 22 changed files with 142 additions and 163 deletions.
1 change: 1 addition & 0 deletions amt/api/routes/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def get_root(
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
tasks_service: Annotated[TasksService, Depends(TasksService)],
) -> HTMLResponse:
logger.info(f"getting project with id {project_id}")
project = projects_service.get(project_id)

context = {
Expand Down
8 changes: 5 additions & 3 deletions amt/core/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging

from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine, select
from sqlalchemy.orm import Session

from amt.core.config import get_settings
from amt.models.base import Base

logger = logging.getLogger(__name__)

Expand All @@ -22,7 +24,7 @@ def get_engine() -> Engine:
def check_db() -> None:
logger.info("Checking database connection")
with Session(get_engine()) as session:
session.exec(select(1))
session.execute(select(1))

logger.info("Finish Checking database connection")

Expand All @@ -32,6 +34,6 @@ def init_db() -> None:

if get_settings().AUTO_CREATE_SCHEMA: # pragma: no cover
logger.info("Creating database schema")
SQLModel.metadata.create_all(get_engine())
Base.metadata.create_all(get_engine())

logger.info("Finished initializing database")
4 changes: 2 additions & 2 deletions amt/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from alembic import context
from amt.models import * # noqa
from sqlalchemy import engine_from_config, pool
from sqlmodel import SQLModel
from sqlalchemy.schema import MetaData

config = context.config

if config.config_file_name is not None:
fileConfig(config.config_file_name)

target_metadata = SQLModel.metadata
target_metadata = MetaData()


def get_url() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from collections.abc import Sequence

import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from alembic import op

# revision identifiers, used by Alembic.
Expand Down Expand Up @@ -45,7 +44,7 @@ def downgrade() -> None:
op.create_table(
"status",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("sort_order", sa.Float(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from collections.abc import Sequence

import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from alembic import op

# revision identifiers, used by Alembic.
Expand All @@ -24,23 +23,23 @@ def upgrade() -> None:
status = op.create_table(
"status",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("sort_order", sa.Float(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)

op.create_table(
"user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("avatar", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("avatar", sa.String(255), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"task",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("title", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("title", sa.String(255), nullable=False),
sa.Column("description", sa.String(255), nullable=False),
sa.Column("sort_order", sa.Float(), nullable=False),
sa.Column("status_id", sa.Integer(), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=True),
Expand Down
5 changes: 2 additions & 3 deletions amt/migrations/versions/c5254dc6083f_add_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from collections.abc import Sequence

import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from alembic import op

# revision identifiers, used by Alembic.
Expand All @@ -24,8 +23,8 @@ def upgrade() -> None:
op.create_table(
"project",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
sa.Column("model_card", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("name",sa.String(255), nullable=False),
sa.Column("model_card", sa.String(255), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
Expand Down
5 changes: 5 additions & 0 deletions amt/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
pass
23 changes: 9 additions & 14 deletions amt/models/project.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from pathlib import Path
from typing import TypeVar

from pydantic import field_validator
from sqlmodel import Field, SQLModel # pyright: ignore [reportUnknownVariableType]
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

T = TypeVar("T", bound="Project")
from amt.models.base import Base

T = TypeVar("T", bound="Project")

class Project(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str = Field(max_length=255, min_items=3)
model_card: str | None = Field(description="Model card storage location", default=None)

@field_validator("model_card")
@classmethod
def validate_model_card(cls: type[T], model_card: str) -> str:
if not Path(model_card).is_file():
raise ValueError("Model card must be a file")
class Project(Base):
__tablename__ = "project"

return model_card
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255)) # TODO: (Christopher) how to set min_length?
model_card: Mapped[str | None] = mapped_column(default=None)
24 changes: 14 additions & 10 deletions amt/models/task.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from sqlmodel import Field as SQLField # pyright: ignore [reportUnknownVariableType]
from sqlmodel import SQLModel
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column

from amt.models.base import Base

class Task(SQLModel, table=True):
id: int | None = SQLField(default=None, primary_key=True)
title: str
description: str
sort_order: float
status_id: int | None = SQLField(default=None)
user_id: int | None = SQLField(default=None, foreign_key="user.id")

class Task(Base):
__tablename__ = "task"

id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str]
description: Mapped[str]
sort_order: Mapped[float]
status_id: Mapped[int | None] = mapped_column(default=None)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
# TODO: (Christopher) SQLModel does not allow to give the below restraint an name
# which is needed for alembic. This results in changing the migration file
# manually to give the restrain a name.
project_id: int | None = SQLField(default=None, foreign_key="project.id")
project_id: Mapped[int | None] = mapped_column(ForeignKey("project.id"))
# todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id
14 changes: 9 additions & 5 deletions amt/models/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from sqlmodel import Field, SQLModel # pyright: ignore [reportUnknownVariableType]
from sqlalchemy.orm import Mapped, mapped_column

from amt.models.base import Base

class User(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
name: str
avatar: str | None

class User(Base):
__tablename__ = "user"

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
avatar: Mapped[str | None] = mapped_column(default=None)
2 changes: 1 addition & 1 deletion amt/repositories/deps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Generator

from sqlmodel import Session
from sqlalchemy.orm import Session

from amt.core.db import get_engine

Expand Down
11 changes: 5 additions & 6 deletions amt/repositories/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Annotated

from fastapi import Depends
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlmodel import Session, select
from sqlalchemy.orm import Session

from amt.core.exceptions import RepositoryError, RepositoryNoResultFound
from amt.models import Project
Expand All @@ -19,7 +19,7 @@ def __init__(self, session: Annotated[Session, Depends(get_session)]) -> None:
self.session = session

def find_all(self) -> Sequence[Project]:
return self.session.exec(select(Project)).all()
return self.session.execute(select(Project)).scalars().all()

def delete(self, project: Project) -> None:
"""
Expand All @@ -41,21 +41,20 @@ def save(self, project: Project) -> Project:
self.session.commit()
self.session.refresh(project)
except SQLAlchemyError as e:
logger.debug(f"Error saving project: {project}")
self.session.rollback()
raise RepositoryError from e
return project

def find_by_id(self, project_id: int) -> Project:
try:
statement = select(Project).where(Project.id == project_id)
return self.session.exec(statement).one()
return self.session.execute(statement).scalars().one()
except NoResultFound as e:
raise RepositoryError from e

def paginate(self, skip: int, limit: int) -> list[Project]:
try:
statement = select(Project).order_by(func.lower(Project.name)).offset(skip).limit(limit)
return list(self.session.exec(statement).all())
return list(self.session.execute(statement).scalars())
except Exception as e:
raise RepositoryNoResultFound from e
17 changes: 8 additions & 9 deletions amt/repositories/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Annotated

from fastapi import Depends
from sqlalchemy import and_, select
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, and_, select
from sqlalchemy.orm import Session

from amt.core.exceptions import RepositoryError
from amt.models import Task
Expand All @@ -26,31 +27,29 @@ def find_all(self) -> Sequence[Task]:
Returns all tasks in the repository.
:return: all tasks in the repository
"""
return self.session.exec(select(Task)).all()
return self.session.execute(select(Task)).scalars().all()

def find_by_status_id(self, status_id: int) -> Sequence[Task]:
"""
Returns all tasks in the repository for the given status_id.
:param status_id: the status_id to filter on
:return: a list of tasks in the repository for the given status_id
"""
# todo (Robbert): we 'type ignore' Task.sort_order because it works correctly, but pyright does not agree
statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order) # pyright: ignore [reportUnknownMemberType, reportCallIssue, reportUnknownVariableType, reportArgumentType]
return self.session.exec(statement).all()
statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order)
return self.session.execute(statement).scalars().all()

def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]:
"""
Returns all tasks in the repository for the given project_id.
:param project_id: the project_id to filter on
:return: a list of tasks in the repository for the given project_id
"""
# todo (Robbert): we 'type ignore' Task.sort_order because it works correctly, but pyright does not agree
statement = (
select(Task)
.where(and_(Task.status_id == status_id, Task.project_id == project_id))
.order_by(Task.sort_order) # pyright: ignore [reportUnknownMemberType, reportCallIssue, reportUnknownVariableType, reportArgumentType]
.order_by(Task.sort_order)
)
return self.session.exec(statement).all()
return self.session.execute(statement).scalars().all()

def save(self, task: Task) -> Task:
"""
Expand Down Expand Up @@ -102,6 +101,6 @@ def find_by_id(self, task_id: int) -> Task:
"""
statement = select(Task).where(Task.id == task_id)
try:
return self.session.exec(statement).one()
return self.session.execute(statement).scalars().one()
except NoResultFound as e:
raise RepositoryError from e
1 change: 1 addition & 0 deletions amt/services/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def fetch_github_content(self, url: str) -> Instrument:
return Instrument(**data)

def fetch_instruments(self, urns: Sequence[str] | None = None) -> list[Instrument]:
# todo (Robbert): we 'type ignore' Task.sort_order because it works correctly, but pyright does not agree
content_list = self.fetch_github_content_list()

instruments: list[Instrument] = []
Expand Down
Loading

0 comments on commit 55e68e6

Please sign in to comment.