Skip to content

Commit

Permalink
chore: refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nuffin committed Oct 20, 2024
1 parent 006e4cb commit c1cfb76
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 44 deletions.
4 changes: 1 addition & 3 deletions alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

from alembic import context

from db import db, init_schemas

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
Expand All @@ -26,7 +24,7 @@


## to init db.metadata
init_schemas()
from db import db

# add your model's MetaData object here
# for 'autogenerate' support
Expand Down
2 changes: 1 addition & 1 deletion llmpa/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def insert_path():
insert_path()
del insert_path

from schemas.file import FileInfo
from schemas import FileInfo


JsonType = Dict[str, Any]
Expand Down
17 changes: 17 additions & 0 deletions llmpa/backends/local/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from . import LocalBackend


def main():
import os
import sys

## if len(sys.argv) < 2:
## print(f"Usage: {os.path.basename(__file__)} <filepath>")
## sys.exit(1)

backend = LocalBackend()
print(backend.embedding_text("gpt2", "Hello, world!"))


if __name__ == "__main__":
main()
5 changes: 0 additions & 5 deletions llmpa/backends/local/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@

from transformers import AutoModel, AutoFeatureExtractor

# from tensorflow.keras.applications import (
# EfficientNetV2B0,
# ResNet50,
# )


# EmbeddingExtractor Class with model name as a parameter
class EmbeddingExtractor:
Expand Down
3 changes: 2 additions & 1 deletion llmpa/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .postgres import db, async_engine, async_session, init_db, init_schemas
from .postgres import db, async_engine, async_session, init_db
from . import schemas
6 changes: 0 additions & 6 deletions llmpa/db/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@
async_session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)


def init_schemas():
from schemas import File, Task, User

__models = [File, Task, User]


def init_db(app: Flask):
global async_engine, async_session

Expand Down
3 changes: 3 additions & 0 deletions llmpa/db/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .file import File
from .task import Task
from .user import User
2 changes: 1 addition & 1 deletion llmpa/schemas/base.py → llmpa/db/schemas/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from sqlalchemy.ext.declarative import declared_attr

from db import db
from ..postgres import db


## for backup
Expand Down
22 changes: 1 addition & 21 deletions llmpa/schemas/file.py → llmpa/db/schemas/file.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from sqlalchemy.dialects.postgresql import UUID

from db import db
from ..postgres import db
from .base import timestamped, soft_deletable


Expand All @@ -27,23 +27,3 @@ def __repr__(self):
return f"<File {self.name}>"


class FileInfo:
def __init__(
self,
fileId: str,
name: str,
path: str,
originName: str,
size: int,
type: str,
mimetype: str,
):
self.name = name
self.path = path
self.originName = originName
self.size = size
self.type = type
self.mimetype = mimetype

def __repr__(self):
return f"<FileInfo {self.name} ({self.fileType})>"
2 changes: 1 addition & 1 deletion llmpa/schemas/task.py → llmpa/db/schemas/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.dialects.postgresql import JSON, UUID

from agents.task import TaskStatus
from db import db
from ..postgres import db
from .base import timestamped, soft_deletable


Expand Down
2 changes: 1 addition & 1 deletion llmpa/schemas/user.py → llmpa/db/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy import event
from sqlalchemy.dialects.postgresql import UUID

from db import db
from ..postgres import db
from .base import timestamped, soft_deletable


Expand Down
5 changes: 1 addition & 4 deletions llmpa/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
from .file import File
from .task import Task
from .user import User
from .fileinfo import FileInfo

__all__ = ["File", "Task", "User"]
22 changes: 22 additions & 0 deletions llmpa/schemas/fileinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class FileInfo:
def __init__(
self,
fileId: str,
name: str,
path: str,
originName: str,
size: int,
type: str,
mimetype: str,
):
self.name = name
self.path = path
self.originName = originName
self.size = size
self.type = type
self.mimetype = mimetype

def __repr__(self):
return f"<FileInfo {self.name} ({self.type})>"


42 changes: 42 additions & 0 deletions llmpa/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

from .fileparser.mimetype import detect
from .schemas.fileinfo import FileInfo

def load_file_info(filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError(f"File '{filepath}' does not exist")

# Get file information
name = os.path.basename(filepath)
path = os.path.abspath(filepath)
originName = name
size = os.path.getsize(filepath)
file_type = os.path.splitext(filepath)[1][1:] # File extension as type
mimetype = detect(filepath)

return FileInfo(
fileId=str(os.path.getctime(filepath)), # Using creation time as fileId (or customize)
name=name,
path=path,
originName=originName,
size=size,
type=file_type,
mimetype=mimetype or "unknown"
)

def main():
import os
import sys

if len(sys.argv) < 2:
print(f"Usage: {os.path.basename(__file__)} <filepath>")
sys.exit(1)

filepath = sys.argv[1]
file_info = load_file_info(filepath)
print(file_info)


if __name__ == "__main__":
main()

0 comments on commit c1cfb76

Please sign in to comment.