Skip to content

Commit 55e68e6

Browse files
Change sqlmodel to sqlalchemy
1 parent ea78a3b commit 55e68e6

22 files changed

+142
-163
lines changed

amt/api/routes/project.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def get_root(
2121
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
2222
tasks_service: Annotated[TasksService, Depends(TasksService)],
2323
) -> HTMLResponse:
24+
logger.info(f"getting project with id {project_id}")
2425
project = projects_service.get(project_id)
2526

2627
context = {

amt/core/db.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
22

3+
from sqlalchemy import create_engine, select
34
from sqlalchemy.engine import Engine
4-
from sqlmodel import Session, SQLModel, create_engine, select
5+
from sqlalchemy.orm import Session
56

67
from amt.core.config import get_settings
8+
from amt.models.base import Base
79

810
logger = logging.getLogger(__name__)
911

@@ -22,7 +24,7 @@ def get_engine() -> Engine:
2224
def check_db() -> None:
2325
logger.info("Checking database connection")
2426
with Session(get_engine()) as session:
25-
session.exec(select(1))
27+
session.execute(select(1))
2628

2729
logger.info("Finish Checking database connection")
2830

@@ -32,6 +34,6 @@ def init_db() -> None:
3234

3335
if get_settings().AUTO_CREATE_SCHEMA: # pragma: no cover
3436
logger.info("Creating database schema")
35-
SQLModel.metadata.create_all(get_engine())
37+
Base.metadata.create_all(get_engine())
3638

3739
logger.info("Finished initializing database")

amt/migrations/env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from alembic import context
55
from amt.models import * # noqa
66
from sqlalchemy import engine_from_config, pool
7-
from sqlmodel import SQLModel
7+
from sqlalchemy.schema import MetaData
88

99
config = context.config
1010

1111
if config.config_file_name is not None:
1212
fileConfig(config.config_file_name)
1313

14-
target_metadata = SQLModel.metadata
14+
target_metadata = MetaData()
1515

1616

1717
def get_url() -> str:

amt/migrations/versions/9ce2341f2922_remove_the_status_table.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import Sequence
1010

1111
import sqlalchemy as sa
12-
import sqlmodel.sql.sqltypes
1312
from alembic import op
1413

1514
# revision identifiers, used by Alembic.
@@ -45,7 +44,7 @@ def downgrade() -> None:
4544
op.create_table(
4645
"status",
4746
sa.Column("id", sa.Integer(), nullable=False),
48-
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
47+
sa.Column("name", sa.String(255), nullable=False),
4948
sa.Column("sort_order", sa.Float(), nullable=False),
5049
sa.PrimaryKeyConstraint("id"),
5150
)

amt/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import Sequence
1010

1111
import sqlalchemy as sa
12-
import sqlmodel.sql.sqltypes
1312
from alembic import op
1413

1514
# revision identifiers, used by Alembic.
@@ -24,23 +23,23 @@ def upgrade() -> None:
2423
status = op.create_table(
2524
"status",
2625
sa.Column("id", sa.Integer(), nullable=False),
27-
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
26+
sa.Column("name", sa.String(255), nullable=False),
2827
sa.Column("sort_order", sa.Float(), nullable=False),
2928
sa.PrimaryKeyConstraint("id"),
3029
)
3130

3231
op.create_table(
3332
"user",
3433
sa.Column("id", sa.Integer(), nullable=False),
35-
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
36-
sa.Column("avatar", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
34+
sa.Column("name", sa.String(255), nullable=False),
35+
sa.Column("avatar", sa.String(255), nullable=True),
3736
sa.PrimaryKeyConstraint("id"),
3837
)
3938
op.create_table(
4039
"task",
4140
sa.Column("id", sa.Integer(), nullable=False),
42-
sa.Column("title", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
43-
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
41+
sa.Column("title", sa.String(255), nullable=False),
42+
sa.Column("description", sa.String(255), nullable=False),
4443
sa.Column("sort_order", sa.Float(), nullable=False),
4544
sa.Column("status_id", sa.Integer(), nullable=True),
4645
sa.Column("user_id", sa.Integer(), nullable=True),

amt/migrations/versions/c5254dc6083f_add_project.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from collections.abc import Sequence
1010

1111
import sqlalchemy as sa
12-
import sqlmodel.sql.sqltypes
1312
from alembic import op
1413

1514
# revision identifiers, used by Alembic.
@@ -24,8 +23,8 @@ def upgrade() -> None:
2423
op.create_table(
2524
"project",
2625
sa.Column("id", sa.Integer(), nullable=False),
27-
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
28-
sa.Column("model_card", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
26+
sa.Column("name",sa.String(255), nullable=False),
27+
sa.Column("model_card", sa.String(255), nullable=False),
2928
sa.PrimaryKeyConstraint("id"),
3029
)
3130
# ### end Alembic commands ###

amt/models/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from sqlalchemy.orm import DeclarativeBase
2+
3+
4+
class Base(DeclarativeBase):
5+
pass

amt/models/project.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
1-
from pathlib import Path
21
from typing import TypeVar
32

4-
from pydantic import field_validator
5-
from sqlmodel import Field, SQLModel # pyright: ignore [reportUnknownVariableType]
3+
from sqlalchemy import String
4+
from sqlalchemy.orm import Mapped, mapped_column
65

7-
T = TypeVar("T", bound="Project")
6+
from amt.models.base import Base
87

8+
T = TypeVar("T", bound="Project")
99

10-
class Project(SQLModel, table=True):
11-
id: int | None = Field(default=None, primary_key=True)
12-
name: str = Field(max_length=255, min_items=3)
13-
model_card: str | None = Field(description="Model card storage location", default=None)
1410

15-
@field_validator("model_card")
16-
@classmethod
17-
def validate_model_card(cls: type[T], model_card: str) -> str:
18-
if not Path(model_card).is_file():
19-
raise ValueError("Model card must be a file")
11+
class Project(Base):
12+
__tablename__ = "project"
2013

21-
return model_card
14+
id: Mapped[int] = mapped_column(primary_key=True)
15+
name: Mapped[str] = mapped_column(String(255)) # TODO: (Christopher) how to set min_length?
16+
model_card: Mapped[str | None] = mapped_column(default=None)

amt/models/task.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
from sqlmodel import Field as SQLField # pyright: ignore [reportUnknownVariableType]
2-
from sqlmodel import SQLModel
1+
from sqlalchemy import ForeignKey
2+
from sqlalchemy.orm import Mapped, mapped_column
33

4+
from amt.models.base import Base
45

5-
class Task(SQLModel, table=True):
6-
id: int | None = SQLField(default=None, primary_key=True)
7-
title: str
8-
description: str
9-
sort_order: float
10-
status_id: int | None = SQLField(default=None)
11-
user_id: int | None = SQLField(default=None, foreign_key="user.id")
6+
7+
class Task(Base):
8+
__tablename__ = "task"
9+
10+
id: Mapped[int] = mapped_column(primary_key=True)
11+
title: Mapped[str]
12+
description: Mapped[str]
13+
sort_order: Mapped[float]
14+
status_id: Mapped[int | None] = mapped_column(default=None)
15+
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
1216
# TODO: (Christopher) SQLModel does not allow to give the below restraint an name
1317
# which is needed for alembic. This results in changing the migration file
1418
# manually to give the restrain a name.
15-
project_id: int | None = SQLField(default=None, foreign_key="project.id")
19+
project_id: Mapped[int | None] = mapped_column(ForeignKey("project.id"))
1620
# todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id

amt/models/user.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from sqlmodel import Field, SQLModel # pyright: ignore [reportUnknownVariableType]
1+
from sqlalchemy.orm import Mapped, mapped_column
22

3+
from amt.models.base import Base
34

4-
class User(SQLModel, table=True):
5-
id: int = Field(default=None, primary_key=True)
6-
name: str
7-
avatar: str | None
5+
6+
class User(Base):
7+
__tablename__ = "user"
8+
9+
id: Mapped[int] = mapped_column(primary_key=True)
10+
name: Mapped[str]
11+
avatar: Mapped[str | None] = mapped_column(default=None)

0 commit comments

Comments
 (0)