From 19cedc843382acb837c9cd23ddec522d342ed9f5 Mon Sep 17 00:00:00 2001
From: Gaisberg <93206976+Gaisberg@users.noreply.github.com>
Date: Sat, 26 Oct 2024 14:11:30 +0300
Subject: [PATCH] fix: future cancellation resulted in reset and retry
 endpoints failing (#817)

* fix: future cancellation resulted in reset and retry endpoints failing (see the first sketch below)

* fix: update reset() to check whether the item was already indexed (see the second sketch below)

---------

Co-authored-by: Gaisberg <None>
Co-authored-by: Spoked <dreu.lavelle@gmail.com>
---
 src/program/db/db_functions.py | 36 ++++++++----------------
 src/program/media/item.py      | 51 +++++++++++++++++-----------------
 src/program/symlink.py         |  3 +-
 src/routers/secure/items.py    |  2 +-
 src/utils/event_manager.py     | 25 +++++++++--------
 5 files changed, 53 insertions(+), 64 deletions(-)
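
Reviewer note: the core of this patch is cooperative cancellation. submit_job()
attaches a threading.Event to each Future, cancel_job() sets that event before
calling Future.cancel(), and the worker checks the event before committing its
database work. A minimal sketch of the pattern, with an illustrative run_job()
standing in for run_thread_with_db_item():

    import threading
    from concurrent.futures import ThreadPoolExecutor

    def run_job(cancellation_event: threading.Event):
        result = "item-id"  # stand-in for the id a service run would emit
        if cancellation_event.is_set():
            return None  # cancelled mid-run: skip the commit, discard the work
        return result    # still wanted: the real worker commits, then returns

    executor = ThreadPoolExecutor(max_workers=1)
    cancellation_event = threading.Event()
    future = executor.submit(run_job, cancellation_event)
    future.cancellation_event = cancellation_event  # read back in _process_future

    # cancel_job() flags first, then cancels: a job that already started cannot
    # be stopped by Future.cancel(), but it sees the flag and discards its work.
    cancellation_event.set()
    future.cancel()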

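Reviewer note: reset() no longer takes a soft_reset flag; it now pins the stored
state depending on whether the item has already been indexed. A hedged,
self-contained sketch of that check (the Item class and its fields are
simplified stand-ins; only the States values come from this diff):

    from enum import Enum

    class States(Enum):
        Requested = "Requested"
        Indexed = "Indexed"

    class Item:
        def __init__(self, title=None):
            self.title = title       # set once the item has been indexed
            self.last_state = None

        def _determine_state(self):
            return States.Requested  # stub for the real state machine

        def store_state(self, given_state=None):
            # given_state lets reset() pin a state instead of re-deriving it
            self.last_state = given_state if given_state else self._determine_state()

        def reset(self):
            # an already-indexed item falls back to Indexed, not Requested
            self.store_state(States.Indexed if self.title else States.Requested)

    item = Item(title="Some Show")
    item.reset()
    assert item.last_state is States.Indexed
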
diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py
index 72803fc5..551a87a5 100644
--- a/src/program/db/db_functions.py
+++ b/src/program/db/db_functions.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+from threading import Event
 from typing import TYPE_CHECKING
 
 import alembic
@@ -171,15 +172,9 @@ def reset_media_item(item: "MediaItem"):
         item.reset()
         session.commit()
 
-def reset_streams(item: "MediaItem", active_stream_hash: str = None):
+def reset_streams(item: "MediaItem"):
     """Reset streams associated with a MediaItem."""
     with db.Session() as session:
-        item.store_state()
-        item = session.merge(item)
-        if active_stream_hash:
-            stream = session.query(Stream).filter(Stream.infohash == active_stream_hash).first()
-            if stream:
-                blacklist_stream(item, stream, session)
 
         session.execute(
             delete(StreamRelation).where(StreamRelation.parent_id == item._id)
@@ -188,20 +183,11 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None):
         session.execute(
             delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
         )
-        item.active_stream = {}
         session.commit()
 
 def clear_streams(item: "MediaItem"):
     """Clear all streams for a media item."""
-    with db.Session() as session:
-        item = session.merge(item)
-        session.execute(
-            delete(StreamRelation).where(StreamRelation.parent_id == item._id)
-        )
-        session.execute(
-            delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
-        )
-        session.commit()
+    reset_streams(item)
 
 def clear_streams_by_id(media_item_id: int):
     """Clear all streams for a media item by the MediaItem _id."""
@@ -358,7 +344,7 @@ def store_item(item: "MediaItem"):
         finally:
             session.close()
 
-def run_thread_with_db_item(fn, service, program, input_id: int = None):
+def run_thread_with_db_item(fn, service, program, input_id: int, cancellation_event: Event):
     from program.media.item import MediaItem
     if input_id:
         with db.Session() as session:
@@ -378,11 +364,12 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
                         logger.log("PROGRAM", f"Service {service.__name__} emitted {item} from input item {input_item} of type {type(item).__name__}, backing off.")
                         program.em.remove_id_from_queues(input_item._id)
 
-                    input_item.store_state()
-                    session.commit()
+                    if not cancellation_event.is_set():
+                        input_item.store_state()
+                        session.commit()
 
                     session.expunge_all()
-                    yield res
+                    return res
             else:
                 # Indexing returns a copy of the item, was too lazy to create a copy attr func so this will do for now
                 indexed_item = next(fn(input_item), None)
@@ -393,9 +380,10 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
                     indexed_item.store_state()
                     session.delete(input_item)
                     indexed_item = session.merge(indexed_item)
-                    session.commit()
-                    logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
-                    yield indexed_item._id
+                    if not cancellation_event.is_set():
+                        session.commit()
+                        logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
+                    return indexed_item._id
         return
     else:
         # Content services
diff --git a/src/program/media/item.py b/src/program/media/item.py
index e00ffd3a..5f17b13d 100644
--- a/src/program/media/item.py
+++ b/src/program/media/item.py
@@ -132,8 +132,8 @@ def __init__(self, item: dict | None) -> None:
         #Post processing
         self.subtitles = item.get("subtitles", [])
 
-    def store_state(self) -> None:
-        new_state = self._determine_state()
+    def store_state(self, given_state=None) -> None:
+        new_state = given_state if given_state else self._determine_state()
         if self.last_state and self.last_state != new_state:
             sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id})
         self.last_state = new_state
@@ -145,6 +145,10 @@ def is_stream_blacklisted(self, stream: Stream):
             session.refresh(self, attribute_names=['blacklisted_streams'])
         return stream in self.blacklisted_streams
 
+    def blacklist_active_stream(self):
+        stream = next((stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]), None)
+        if stream: self.blacklist_stream(stream)
+
     def blacklist_stream(self, stream: Stream):
         value = blacklist_stream(self, stream)
         if value:
@@ -321,20 +325,23 @@ def get_aliases(self) -> dict:
     def __hash__(self):
         return hash(self._id)
 
-    def reset(self, soft_reset: bool = False):
+    def reset(self):
         """Reset item attributes."""
         if self.type == "show":
             for season in self.seasons:
                 for episode in season.episodes:
-                    episode._reset(soft_reset)
-                season._reset(soft_reset)
+                    episode._reset()
+                season._reset()
         elif self.type == "season":
             for episode in self.episodes:
-                episode._reset(soft_reset)
-        self._reset(soft_reset)
-        self.store_state()
+                episode._reset()
+        self._reset()
+        if self.title:
+            self.store_state(States.Indexed)
+        else:
+            self.store_state(States.Requested)
 
-    def _reset(self, soft_reset):
+    def _reset(self):
         """Reset item attributes for rescraping."""
         if self.symlink_path:
             if Path(self.symlink_path).exists():
@@ -351,16 +358,8 @@ def _reset(self, soft_reset):
         self.set("folder", None)
         self.set("alternative_folder", None)
 
-        if not self.active_stream:
-            self.active_stream = {}
-        if not soft_reset:
-            if self.active_stream.get("infohash", False):
-                reset_streams(self, self.active_stream["infohash"])
-        else:
-            if self.active_stream.get("infohash", False):
-                stream = next((stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]), None)
-                if stream:
-                    self.blacklist_stream(stream)
+        reset_streams(self)
+        self.active_stream = {}
 
         self.set("active_stream", {})
         self.set("symlinked", False)
@@ -371,7 +370,7 @@ def _reset(self, soft_reset):
         self.set("symlinked_times", 0)
         self.set("scraped_times", 0)
 
-        logger.debug(f"Item {self.log_string} reset for rescraping")
+        logger.debug(f"Item {self.log_string} has been reset")
 
     @property
     def log_string(self):
@@ -456,10 +455,10 @@ def _determine_state(self):
             return States.Requested
         return States.Unknown
 
-    def store_state(self) -> None:
+    def store_state(self, given_state: States = None) -> None:
         for season in self.seasons:
-            season.store_state()
-        super().store_state()
+            season.store_state(given_state)
+        super().store_state(given_state)
 
     def __repr__(self):
         return f"Show:{self.log_string}:{self.state.name}"
@@ -527,10 +526,10 @@ class Season(MediaItem):
         "polymorphic_load": "inline",
     }
 
-    def store_state(self) -> None:
+    def store_state(self, given_state: States = None) -> None:
         for episode in self.episodes:
-            episode.store_state()
-        super().store_state()
+            episode.store_state(given_state)
+        super().store_state(given_state)
 
     def __init__(self, item):
         self.type = "season"
diff --git a/src/program/symlink.py b/src/program/symlink.py
index 0123ab0f..8f3cf7bb 100644
--- a/src/program/symlink.py
+++ b/src/program/symlink.py
@@ -94,7 +94,8 @@ def run(self, item: Union[Movie, Show, Season, Episode]):
         if not self._should_submit(items):
             if item.symlinked_times == 5:
                 logger.debug(f"Soft resetting {item.log_string} because required files were not found")
-                item.reset(True)
+                item.blacklist_active_stream()
+                item.reset()
                 yield item
             next_attempt = self._calculate_next_attempt(item)
             logger.debug(f"Waiting for {item.log_string} to become available, next attempt in {round((next_attempt - datetime.now()).total_seconds())} seconds")
diff --git a/src/routers/secure/items.py b/src/routers/secure/items.py
index 91347400..4b8ee20e 100644
--- a/src/routers/secure/items.py
+++ b/src/routers/secure/items.py
@@ -533,7 +533,7 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDRe
 #     downloader = request.app.program.services.get(Downloader).service
 #     with db.Session() as session:
 #         item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
-#         item.reset(True)
+#         item.reset()
 #         downloader.download_cached(item, hash)
 #         request.app.program.add_to_queue(item)
 #         return {"success": True, "message": f"Downloading {item.title} with hash {hash}"}
diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py
index 32c1c442..ab3e19b4 100644
--- a/src/utils/event_manager.py
+++ b/src/utils/event_manager.py
@@ -1,4 +1,5 @@
 import os
+import threading
 import traceback
 
 from datetime import datetime
@@ -8,8 +9,7 @@
 
 from loguru import logger
 from pydantic import BaseModel
-from sqlalchemy.orm.exc import StaleDataError
-from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 
 from utils.sse_manager import sse_manager
 from program.db.db import db
@@ -37,6 +37,7 @@ def __init__(self):
         self._futures: list[Future] = []
         self._queued_events: list[Event] = []
         self._running_events: list[Event] = []
+        self._canceled_futures: list[Future] = []
         self.mutex = Lock()
 
     def _find_or_create_executor(self, service_cls) -> ThreadPoolExecutor:
@@ -71,7 +72,7 @@ def _process_future(self, future, service):
             service (type): The service class associated with the future.
         """
         try:
-            result = next(future.result(), None)
+            result = future.result()
             if future in self._futures:
                 self._futures.remove(future)
             sse_manager.publish_event("event_update", self.get_event_updates())
@@ -81,10 +82,10 @@ def _process_future(self, future, service):
                 item_id, timestamp = result, datetime.now()
             if item_id:
                 self.remove_event_from_running(item_id)
+                if future.cancellation_event.is_set():
+                    logger.debug(f"Future with Item ID: {item_id} was cancelled discarding results...")
+                    return
                 self.add_event(Event(emitted_by=service, item_id=item_id, run_at=timestamp))
-        except (StaleDataError, CancelledError):
-            # Expected behavior when cancelling tasks or when the item was removed
-            return
         except Exception as e:
             logger.error(f"Error in future for {future}: {e}")
             logger.exception(traceback.format_exc())
@@ -166,8 +167,10 @@ def submit_job(self, service, program, event=None):
             log_message += f" with Item ID: {item_id}"
         logger.debug(log_message)
 
+        cancellation_event = threading.Event()
         executor = self._find_or_create_executor(service)
-        future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id)
+        future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id, cancellation_event)
+        future.cancellation_event = cancellation_event
         if event:
             future.event = event
         self._futures.append(future)
@@ -186,27 +189,25 @@ def cancel_job(self, item_id: int, suppress_logs=False):
             item_id, related_ids = get_item_ids(session, item_id)
             ids_to_cancel = set([item_id] + related_ids)
 
-            futures_to_remove = []
             for future in self._futures:
                 future_item_id = None
                 future_related_ids = []
 
-                if hasattr(future, 'event') and hasattr(future.event, 'item'):
+                if hasattr(future, 'event') and hasattr(future.event, 'item_id'):
                     future_item = future.event.item_id
                     future_item_id, future_related_ids = get_item_ids(session, future_item)
 
                 if future_item_id in ids_to_cancel or any(rid in ids_to_cancel for rid in future_related_ids):
                     self.remove_id_from_queues(future_item)
-                    futures_to_remove.append(future)
                     if not future.done() and not future.cancelled():
                         try:
+                            future.cancellation_event.set()
                             future.cancel()
+                            self._canceled_futures.append(future)
                         except Exception as e:
                             if not suppress_logs:
                                 logger.error(f"Error cancelling future for {future_item.log_string}: {str(e)}")
 
-            for future in futures_to_remove:
-                self._futures.remove(future)
 
         logger.debug(f"Canceled jobs for Item ID {item_id} and its children.")