Skip to content

Commit 77bb052

Browse files
Add database support with migrations
1 parent c7492d9 commit 77bb052

File tree

21 files changed

+361
-3
lines changed

21 files changed

+361
-3
lines changed

cookiecutter.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"machine_learn_model_path": "./ml/model/",
66
"machine_learn_model_name": "model.pkl",
77
"input_example_path": "./ml/model/examples/example.json",
8+
"database_url": "sqlite:///./app.db",
89
"full_name": "Your name",
910
"email": "[email protected]",
1011
"release_date": "{% now 'local' %}",
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
SECRET_KEY=secret
22
DEBUG=True
33
MODEL_PATH=/Users/arthur.dasilva/repos/arthurhenrique/n
4-
MODEL_NAME=pregnancy_model_local.joblib
4+
MODEL_NAME=pregnancy_model_local.joblib
5+
MEMOIZATION_FLAG=False
6+
DATABASE_URL=sqlite:///./app.db

sample/pregnancy-model/README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@ MODEL_PATH=./ml/model/
1414
MODEL_NAME=model.pkl
1515
```
1616

17+
### Database Configuration
18+
19+
Set your database url in `.env` using `DATABASE_URL`. The default uses SQLite:
20+
21+
```sh
22+
DATABASE_URL=sqlite:///./app.db
23+
```
24+
25+
### Migrations
26+
27+
Create a new migration with:
28+
29+
```sh
30+
alembic revision --autogenerate -m "message"
31+
```
32+
33+
Apply migrations with:
34+
35+
```sh
36+
alembic upgrade head
37+
```
38+
1739
### Update `/predict`
1840

1941
To update your machine learning model, add your `load` and `method` [change here](app/api/routes/predictor.py#L19) at `predictor.py`

sample/pregnancy-model/alembic.ini

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[alembic]
2+
script_location = migrations
3+
sqlalchemy.url = sqlite:///./app.db
4+
5+
[loggers]
6+
keys = root,sqlalchemy,alembic
7+
8+
[handlers]
9+
keys = console
10+
11+
[formatters]
12+
keys = generic
13+
14+
[logger_root]
15+
level = WARN
16+
handlers = console
17+
qualname =
18+
19+
[logger_sqlalchemy]
20+
level = WARN
21+
handlers =
22+
qualname = sqlalchemy.engine
23+
24+
[logger_alembic]
25+
level = INFO
26+
handlers =
27+
qualname = alembic
28+
29+
[handler_console]
30+
class = StreamHandler
31+
args = (sys.stderr,)
32+
level = NOTSET
33+
formatter = generic
34+
35+
[formatter_generic]
36+
format = %(levelname)-5.5s [%(name)s] %(message)s

sample/pregnancy-model/app/core/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from starlette.config import Config
66
from starlette.datastructures import Secret
77

8-
from core.logging import InterceptHandler
8+
from .logging import InterceptHandler
99

1010
config = Config(".env")
1111

@@ -29,3 +29,4 @@
2929
MODEL_PATH = config("MODEL_PATH", default="/Users/arthur.dasilva/repos/arthurhenrique/cookiecutter-fastapi/sample/pregnancy-model")
3030
MODEL_NAME = config("MODEL_NAME", default="pregnancy_model_local.joblib")
3131
INPUT_EXAMPLE = config("INPUT_EXAMPLE", default="./ml/model/examples/example.json")
32+
DATABASE_URL = config("DATABASE_URL", default="sqlite:///./app.db")
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from sqlalchemy import create_engine
2+
from sqlalchemy.orm import sessionmaker, declarative_base
3+
4+
from .config import DATABASE_URL
5+
6+
connect_args = {}
7+
if DATABASE_URL.startswith("sqlite"):
8+
connect_args["check_same_thread"] = False
9+
10+
engine = create_engine(DATABASE_URL, connect_args=connect_args)
11+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
12+
13+
Base = declarative_base()
14+
15+
def get_db():
16+
db = SessionLocal()
17+
try:
18+
yield db
19+
finally:
20+
db.close()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from sqlalchemy import Column, Integer, String
2+
3+
from app.core.database import Base
4+
5+
6+
class Item(Base):
7+
__tablename__ = "items"
8+
9+
id = Column(Integer, primary_key=True, index=True)
10+
name = Column(String, index=True)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from logging.config import fileConfig
2+
3+
from sqlalchemy import engine_from_config
4+
from sqlalchemy import pool
5+
6+
from alembic import context
7+
8+
from core.database import Base
9+
from models import item
10+
from core.config import DATABASE_URL
11+
12+
config = context.config
13+
14+
if config.get_main_option("sqlalchemy.url") is None:
15+
config.set_main_option("sqlalchemy.url", DATABASE_URL)
16+
17+
fileConfig(config.config_file_name)
18+
19+
target_metadata = Base.metadata
20+
21+
22+
def run_migrations_offline() -> None:
23+
context.configure(
24+
url=config.get_main_option("sqlalchemy.url"),
25+
target_metadata=target_metadata,
26+
literal_binds=True,
27+
dialect_opts={"paramstyle": "named"},
28+
)
29+
30+
with context.begin_transaction():
31+
context.run_migrations()
32+
33+
34+
def run_migrations_online() -> None:
35+
connectable = engine_from_config(
36+
config.get_section(config.config_ini_section),
37+
prefix="sqlalchemy.",
38+
poolclass=pool.NullPool,
39+
)
40+
41+
with connectable.connect() as connection:
42+
context.configure(connection=connection, target_metadata=target_metadata)
43+
44+
with context.begin_transaction():
45+
context.run_migrations()
46+
47+
48+
if context.is_offline_mode():
49+
run_migrations_offline()
50+
else:
51+
run_migrations_online()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""create items table
2+
3+
Revision ID: 0001
4+
Revises:
5+
Create Date: 2020-01-01 00:00:00.000000
6+
"""
7+
from alembic import op
8+
import sqlalchemy as sa
9+
10+
revision = "0001"
11+
down_revision = None
12+
branch_labels = None
13+
depends_on = None
14+
15+
16+
def upgrade() -> None:
17+
op.create_table(
18+
"items",
19+
sa.Column("id", sa.Integer(), primary_key=True),
20+
sa.Column("name", sa.String(), nullable=True),
21+
)
22+
23+
24+
def downgrade() -> None:
25+
op.drop_table("items")

sample/pregnancy-model/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ dependencies = [
1414
"loguru>=0.7.0",
1515
"joblib>=1.2.0",
1616
"scikit-learn>=1.1.3",
17+
"SQLAlchemy>=2.0.0",
18+
"alembic>=1.12.0",
1719
"pandas>=2.2.3",
1820
]
1921

0 commit comments

Comments
 (0)