Skip to content

Commit 2715b8d

Browse files
authored
Merge pull request #23 from bit-bots/feature/database_reader
feature/database_reader
2 parents 8454d06 + ba296a4 commit 2715b8d

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ env/
22
logs/
33
*.csv
44
*.sqlite
5+
*.sqlite3
56

67
# Created by .ignore support plugin (hsz.mobi)
78
### JetBrains template
@@ -205,4 +206,4 @@ ENV/
205206
.ruff_cache
206207

207208
# Torch models
208-
*.pth
209+
*.pth

ddlitlab2024/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@
3838
LOGGING_PATH: str = _logging_path
3939

4040
SESSION_ID: UUID = uuid4()
41+
42+
DB_PATH: str = os.path.join(os.path.dirname(__file__), "dataset", "db.sqlite3")

ddlitlab2024/dataset/reader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import sqlite3
2+
3+
from ddlitlab2024 import DB_PATH
4+
5+
6+
def get_connection(path: str = DB_PATH) -> sqlite3.Connection:
7+
conn = sqlite3.connect(path)
8+
return conn

ddlitlab2024/dataset/schema.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
from datetime import datetime
23
from enum import Enum
34
from typing import List, Optional
@@ -6,6 +7,7 @@
67
from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship, sessionmaker
78
from sqlalchemy.types import LargeBinary
89

10+
from ddlitlab2024 import DB_PATH
911
from ddlitlab2024.dataset import logger
1012

1113
Base = declarative_base()
@@ -229,9 +231,16 @@ class GameState(Base):
229231
__table_args__ = (CheckConstraint(state.in_(RobotState.values())),)
230232

231233

234+
def parse_args():
235+
parser = argparse.ArgumentParser(description="Create the database schema")
236+
parser.add_argument("--db-path", type=str, default=DB_PATH, help="Path to the database file")
237+
return parser.parse_args()
238+
239+
232240
def main():
233241
logger.info("Creating database schema")
234-
engine = create_engine("sqlite:///data.sqlite")
242+
args = parse_args()
243+
engine = create_engine(f"sqlite:///{args.db_path}")
235244
Base.metadata.create_all(engine)
236245
sessionmaker(bind=engine)()
237246
logger.info("Database schema created")

0 commit comments

Comments
 (0)