Skip to content

Commit

Permalink
feat: added websocket controller
Browse files Browse the repository at this point in the history
refactor: some of state transition and event looping
fix!: refactor items api
  • Loading branch information
Gaisberg authored and Gaisberg committed Aug 4, 2024
1 parent 3aea8a4 commit b91710a
Show file tree
Hide file tree
Showing 19 changed files with 635 additions and 346 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker-build-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Docker Build and Push Dev
on:
push:
branches:
- main
- dev

jobs:
build-and-push-dev:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
data/
logs/
settings.json
ignore.txt
.vscode
.git
makefile
Expand Down
28 changes: 1 addition & 27 deletions src/controllers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,30 +134,4 @@ async def get_stats(_: Request):
payload["incomplete_retries"] = incomplete_retries
payload["states"] = states

return {"success": True, "data": payload}

@router.get("/scrape/{item_id:path}")
async def scrape_item(item_id: str, request: Request):
with db.Session() as session:
item = DB._get_item_from_db(session, MediaItem({"imdb_id":str(item_id)}))
if item is None:
raise HTTPException(status_code=404, detail="Item not found")

scraper = request.app.program.services.get(Scraping)
if scraper is None:
raise HTTPException(status_code=404, detail="Scraping service not found")

time_now = time.time()
scraped_results = scraper.scrape(item, log=False)
time_end = time.time()
duration = time_end - time_now

results = {}
for hash, torrent in scraped_results.items():
results[hash] = {
"title": torrent.data.parsed_title,
"raw_title": torrent.raw_title,
"rank": torrent.rank,
}

return {"success": True, "total": len(results), "duration": round(duration, 3), "results": results}
return {"success": True, "data": payload}
188 changes: 119 additions & 69 deletions src/controllers/items.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import List, Optional
from copy import copy
from typing import Optional

import Levenshtein
import program.db.db_functions as DB
from fastapi import APIRouter, HTTPException, Request
from program.db.db import db
from program.media.item import Episode, MediaItem, Season
from program.downloaders import Downloader
from program.media.item import MediaItem
from program.media.state import States
from program.symlink import Symlinker
from pydantic import BaseModel
from sqlalchemy import func, select
from program.media.stream import Stream
from program.scrapers import Scraping
from utils.logger import logger

router = APIRouter(
Expand All @@ -17,19 +19,13 @@
responses={404: {"description": "Not found"}},
)


class IMDbIDs(BaseModel):
imdb_ids: Optional[List[str]] = None


@router.get("/states")
async def get_states():
return {
"success": True,
"states": [state for state in States],
}


@router.get(
"",
summary="Retrieve Media Items",
Expand All @@ -43,6 +39,7 @@ async def get_items(
state: Optional[str] = None,
sort: Optional[str] = "desc",
search: Optional[str] = None,
extended: Optional[bool] = False,
):
if page < 1:
raise HTTPException(status_code=400, detail="Page number must be 1 or greater.")
Expand Down Expand Up @@ -85,9 +82,10 @@ async def get_items(
if type not in ["movie", "show", "season", "episode"]:
raise HTTPException(
status_code=400,
detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']",
)
query = query.where(MediaItem.type.in_(types))
detail=f"Invalid type: {type}. Valid types are: ['movie', 'show', 'season', 'episode']")
else:
types=[type]
query = query.where(MediaItem.type.in_(types))

if sort and not search:
if sort.lower() == "asc":
Expand All @@ -108,24 +106,19 @@ async def get_items(

return {
"success": True,
"items": [item.to_dict() for item in items],
"items": [item.to_extended_dict() if extended else item.to_dict() for item in items],
"page": page,
"limit": limit,
"total_items": total_items,
"total_pages": total_pages,
}


@router.get("/extended/{item_id}")
async def get_extended_item_info(_: Request, item_id: str):
with db.Session() as session:
item = session.execute(select(MediaItem).where(MediaItem.imdb_id == item_id)).unique().scalar_one_or_none()
if item is None:
raise HTTPException(status_code=404, detail="Item not found")
return {"success": True, "item": item.to_extended_dict()}


@router.post("/add")
@router.post(
"/add",
summary="Add Media Items",
description="Add media items with bases on imdb IDs",
)
async def add_items(
request: Request, imdb_ids: str = None
):
Expand All @@ -151,51 +144,108 @@ async def add_items(

return {"success": True, "message": f"Added {len(valid_ids)} item(s) to the queue"}


@router.delete("/remove")
async def remove_item(
_: Request, imdb_id: str
@router.post(
"/reset",
summary="Reset Media Items",
description="Reset media items with bases on item IDs",
)
async def reset_items(
request: Request, ids: str
):
if not imdb_id:
raise HTTPException(status_code=400, detail="No IMDb ID provided")
if DB._remove_item_from_db(imdb_id):
return {"success": True, "message": f"Removed item with imdb_id {imdb_id}"}
return {"success": False, "message": f"No item with imdb_id ({imdb_id}) found"}


@router.get("/imdb/{imdb_id}")
async def get_imdb_info(
_: Request,
imdb_id: str,
season: Optional[int] = None,
episode: Optional[int] = None,
ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)]
if not ids:
raise HTTPException(status_code=400, detail="No item ID provided")
with db.Session() as session:
items = []
for id in ids:
item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
item.streams = session.execute(select(Stream).where(Stream.parent_id == item._id)).scalars().all()
items.append(item)
for item in items:
request.app.program._remove_from_running_events(item)
if item.type == "show":
for season in item.seasons:
for episode in season.episodes:
episode.reset()
season.reset()
elif item.type == "season":
for episode in item.episodes:
episode.reset()
item.reset()

session.commit()
return {"success": True, "message": f"Reset items with id {ids}"}

@router.post(
"/retry",
summary="Retry Media Items",
description="Retry media items with bases on item IDs",
)
async def retry_items(
request: Request, ids: str
):
"""
Get the item with the given IMDb ID.
If the season and episode are provided, get the item with the given season and episode.
"""
ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)]
if not ids:
raise HTTPException(status_code=400, detail="No item ID provided")
with db.Session() as session:
if season is not None and episode is not None:
item = session.execute(
select(Episode).where(
(Episode.imdb_id == imdb_id) &
(Episode.season_number == season) &
(Episode.episode_number == episode)
)
).scalar_one_or_none()
elif season is not None:
item = session.execute(
select(Season).where(
(Season.imdb_id == imdb_id) &
(Season.season_number == season)
)
).scalar_one_or_none()
else:
item = session.execute(
select(MediaItem).where(MediaItem.imdb_id == imdb_id)
).scalar_one_or_none()

if item is None:
raise HTTPException(status_code=404, detail="Item not found")

return {"success": True, "item": item.to_extended_dict()}
items = []
for id in ids:
items.append(session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one())
for item in items:
request.app.program._remove_from_running_events(item)
request.app.program.add_to_queue(item)

return {"success": True, "message": f"Retried items with id {ids}"}

@router.delete(
"",
summary="Remove Media Items",
description="Remove media items with bases on item IDs",)
async def remove_item(
_: Request, ids: str
):
ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)]
if not ids:
raise HTTPException(status_code=400, detail="No item ID provided")
for id in ids:
DB._remove_item_from_db(id)
return {"success": True, "message": f"Removed item with id {id}"}

# These require downloaders to be refactored

# @router.get("/cached")
# async def manual_scrape(request: Request, ids: str):
# scraper = request.app.program.services.get(Scraping)
# downloader = request.app.program.services.get(Downloader).service
# if downloader.__class__.__name__ not in ["RealDebridDownloader", "TorBoxDownloader"]:
# raise HTTPException(status_code=400, detail="Only Real-Debrid is supported for manual scraping currently")
# ids = [int(id) for id in ids.split(",")] if "," in ids else [int(ids)]
# if not ids:
# raise HTTPException(status_code=400, detail="No item ID provided")
# with db.Session() as session:
# items = []
# return_dict = {}
# for id in ids:
# items.append(session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one())
# if any(item for item in items if item.type in ["Season", "Episode"]):
# raise HTTPException(status_code=400, detail="Only shows and movies can be manually scraped currently")
# for item in items:
# new_item = item.__class__({})
# # new_item.parent = item.parent
# new_item.copy(item)
# new_item.copy_other_media_attr(item)
# scraped_results = scraper.scrape(new_item, log=False)
# cached_hashes = downloader.get_cached_hashes(new_item, scraped_results)
# for hash, stream in scraped_results.items():
# return_dict[hash] = {"cached": hash in cached_hashes, "name": stream.raw_title}
# return {"success": True, "data": return_dict}

# @router.post("/download")
# async def download(request: Request, id: str, hash: str):
# downloader = request.app.program.services.get(Downloader).service
# with db.Session() as session:
# item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
# item.reset(True)
# downloader.download_cached(item, hash)
# request.app.program.add_to_queue(item)
# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"}
27 changes: 26 additions & 1 deletion src/controllers/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy
from typing import Any, List
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException
from program.settings.manager import settings_manager
Expand Down Expand Up @@ -65,6 +65,31 @@ async def get_settings(paths: str):
}


@router.post("/set/all")
async def set_all_settings(new_settings: Dict[str, Any]):
current_settings = settings_manager.settings.model_dump()

def update_settings(current_obj, new_obj):
for key, value in new_obj.items():
if isinstance(value, dict) and key in current_obj:
update_settings(current_obj[key], value)
else:
current_obj[key] = value

update_settings(current_settings, new_settings)

# Validate and save the updated settings
try:
settings_manager.settings = settings_manager.settings.model_validate(current_settings)
await save_settings()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

return {
"success": True,
"message": "All settings updated successfully!",
}

@router.post("/set")
async def set_settings(settings: List[SetSettings]):
current_settings = settings_manager.settings.model_dump()
Expand Down
53 changes: 53 additions & 0 deletions src/controllers/ws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json
from loguru import logger
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

router = APIRouter(
prefix="/ws",
tags=["websocket"],
responses={404: {"description": "Not found"}})

class ConnectionManager:
def __init__(self):
self.active_connections: list[WebSocket] = []

async def connect(self, websocket: WebSocket):
await websocket.accept()
logger.debug("Frontend connected!")
self.active_connections.append(websocket)
await websocket.send_json({"type": "health", "status": "running"})

def disconnect(self, websocket: WebSocket):
logger.debug("Frontend disconnected!")
self.active_connections.remove(websocket)

async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)

async def send_log_message(self, message: str):
await self.broadcast({"type": "log", "message": message})

async def send_item_update(self, item: json):
await self.broadcast({"type": "item_update", "item": item})

async def broadcast(self, message: json):
for connection in self.active_connections:
try:
await connection.send_json(message)
except RuntimeError:
self.active_connections.remove(connection)


manager = ConnectionManager()


@router.websocket("")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
manager.disconnect(websocket)
except RuntimeError:
manager.disconnect(websocket)
Loading

0 comments on commit b91710a

Please sign in to comment.