diff --git a/README.md b/README.md
index a13098a..393a112 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,9 @@ platforms such as GitHub discussions/issues might be added in the future.
 | DAILY_TASKS | False | `true` | Daily tasks on or off. |
 | DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. |
 | DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. |
-| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. |
+| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. |
+| DATA_REPO | False | `https://github.com/LizardByte/support-bot-data` | Repository to store persistent data. This repository should be private! |
+| DATA_REPO_BRANCH | False | `master` | Branch to store persistent data. |
 | DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. |
 | DISCORD_CLIENT_ID | True | `None` | Discord OAuth2 client id. |
 | DISCORD_CLIENT_SECRET | True | `None` | Discord OAuth2 client secret. |
@@ -58,11 +60,11 @@ platforms such as GitHub discussions/issues might be added in the future.
 | GRAVATAR_EMAIL | False | `None` | Gravatar email address for bot avatar. |
 | IGDB_CLIENT_ID | False | `None` | Required if daily_releases is enabled. |
 | IGDB_CLIENT_SECRET | False | `None` | Required if daily_releases is enabled. |
-| PRAW_CLIENT_ID | True | None | `client_id` from reddit app setup page. |
-| PRAW_CLIENT_SECRET | True | None | `client_secret` from reddit app setup page. |
-| PRAW_SUBREDDIT | True | None | Subreddit to monitor (reddit user should be moderator of the subreddit) |
-| REDDIT_USERNAME | True | None | Reddit username |
-| REDDIT_PASSWORD | True | None | Reddit password |
+| PRAW_CLIENT_ID | True | `None` | `client_id` from reddit app setup page. |
+| PRAW_CLIENT_SECRET | True | `None` | `client_secret` from reddit app setup page. |
+| PRAW_SUBREDDIT | True | `None` | Subreddit to monitor (reddit user should be moderator of the subreddit) |
+| REDDIT_USERNAME | True | `None` | Reddit username |
+| REDDIT_PASSWORD | True | `None` | Reddit password |
 | SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. |
 | SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. |
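As an illustration only (not part of this diff), the two new settings could be supplied alongside the existing ones through the bot's environment; the repository URL below is simply the documented default, and the table above recommends pointing it at a private repository instead:

    DATA_REPO=https://github.com/LizardByte/support-bot-data
    DATA_REPO_BRANCH=master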
| diff --git a/requirements.txt b/requirements.txt index fa63dd9..698c053 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ py-cord==2.6.1 python-dotenv==1.1.0 requests==2.32.3 requests-oauthlib==2.0.0 +tinydb==4.8.2 diff --git a/src/common/database.py b/src/common/database.py index b27fdd4..4f18c51 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,22 +1,271 @@ # standard imports +import os +from pathlib import Path import shelve import threading +import traceback +from typing import Union + +# lib imports +import git +from tinydb import TinyDB +from tinydb.storages import JSONStorage +from tinydb.middlewares import CachingMiddleware + +# local imports +from src.common.common import data_dir + +# Constants +DATA_REPO_LOCK = threading.Lock() class Database: - def __init__(self, db_path): - self.db_path = db_path + def __init__(self, db_name: str, db_dir: Union[str, Path] = data_dir, use_git: bool = True): + self.db_name = db_name + self.db_dir = db_dir + + # Check for CI environment + is_ci = os.environ.get('GITHUB_PYTEST', '').lower() == 'true' + + self.use_git = use_git and not is_ci + + self.repo_url = None + self.repo_branch = None + if self.use_git: + self.repo_url = os.getenv("DATA_REPO", "https://github.com/LizardByte/support-bot-data") + self.repo_branch = os.getenv("DATA_REPO_BRANCH", "master") + self.db_dir = os.path.join(self.db_dir, "support-bot-data") + + if not os.path.exists(self.db_dir): + # Clone repo if it doesn't exist + print(f"Cloning repository {self.repo_url} to {self.db_dir}") + try: + # Try cloning with the specified branch + self.repo = git.Repo.clone_from(self.repo_url, self.db_dir, branch=self.repo_branch) + except git.exc.GitCommandError as e: + # Check if the error is due to branch not found + if "Remote branch" in str(e) and "not found in upstream origin" in str(e): + print(f"Branch '{self.repo_branch}' not found in remote. 
Creating a new empty branch.") + # Clone with default branch first + self.repo = git.Repo.clone_from(self.repo_url, self.db_dir) + + # Create a new orphan branch (not based on any other branch) + self.repo.git.checkout('--orphan', self.repo_branch) + + # Clear the index and working tree + try: + self.repo.git.rm('-rf', '.', '--cached') + except git.exc.GitCommandError: + # This might fail if there are no files yet, which is fine + pass + + # Remove all files in the directory except .git + for item in os.listdir(self.db_dir): + if item != '.git': + item_path = os.path.join(self.db_dir, item) + if os.path.isdir(item_path): + import shutil + shutil.rmtree(item_path) + else: + os.remove(item_path) + + # Create empty .gitkeep file to ensure the branch can be committed + gitkeep_path = os.path.join(self.db_dir, '.gitkeep') + with open(gitkeep_path, 'w'): + pass + + # Add and commit the .gitkeep file + self.repo.git.add(gitkeep_path) + self.repo.git.commit('-m', f"Initialize empty branch '{self.repo_branch}'") + + # Push the new branch to remote + try: + self.repo.git.push('--set-upstream', 'origin', self.repo_branch) + print(f"Created and pushed new empty branch '{self.repo_branch}'") + except git.exc.GitCommandError as e: + print(f"Failed to push new branch: {str(e)}") + # Continue anyway - we might not have push permissions + else: + # Re-raise if it's a different error + raise + else: + # Use existing repo + self.repo = git.Repo(self.db_dir) + + # Make sure the correct branch is checked out + if self.repo_branch not in [ref.name.split('/')[-1] for ref in self.repo.refs]: + # Branch doesn't exist locally, check if it exists remotely + try: + self.repo.git.fetch('origin') + remote_branches = [ref.name.split('/')[-1] for ref in self.repo.remote().refs] + + if self.repo_branch in remote_branches: + # Checkout existing remote branch + self.repo.git.checkout(self.repo_branch) + else: + # Create new orphan branch + self.repo.git.checkout('--orphan', self.repo_branch) + self.repo.git.rm('-rf', '.', '--cached') + + # Create empty .gitkeep file + gitkeep_path = os.path.join(self.db_dir, '.gitkeep') + with open(gitkeep_path, 'w'): + pass + + self.repo.git.add(gitkeep_path) + self.repo.git.commit('-m', f"Initialize empty branch '{self.repo_branch}'") + self.repo.git.push('--set-upstream', 'origin', self.repo_branch) + print(f"Created and pushed new empty branch '{self.repo_branch}'") + except git.exc.GitCommandError: + print(f"Failed to work with branch '{self.repo_branch}'. 
Using current branch instead.") + else: + # Branch exists locally, make sure it's checked out + self.repo.git.checkout(self.repo_branch) + + self.json_path = os.path.join(self.db_dir, f"{self.db_name}.json") + self.shelve_path = os.path.join(db_dir, self.db_name) # Shelve adds its own extensions self.lock = threading.Lock() + # Check if migration is needed before creating TinyDB instance + self._check_for_migration() + + # Initialize the TinyDB instance with CachingMiddleware + self.tinydb = TinyDB( + self.json_path, + storage=CachingMiddleware(JSONStorage), + indent=4, + ) + + def _check_for_migration(self): + # Check if migration is needed (shelve exists but json doesn't) + # No extension is used on Linux + shelve_exists = os.path.exists(f"{self.shelve_path}.dat") or os.path.exists(self.shelve_path) + json_exists = os.path.exists(self.json_path) + + if shelve_exists and not json_exists: + print(f"Migrating database from shelve to TinyDB: {self.shelve_path}") + self._migrate_from_shelve() + + def _migrate_from_shelve(self): + try: + # Create a temporary database just for migration + migration_db = TinyDB( + self.json_path, + storage=CachingMiddleware(JSONStorage), + indent=4, + ) + + # Determine if this is the Reddit database + is_reddit_db = "reddit_bot" in self.db_name + + # Open the shelve database + with shelve.open(self.shelve_path) as shelve_db: + # Process each key in the shelve database + for key in shelve_db.keys(): + value = shelve_db[key] + + # If value is a dict and looks like a collection of records + if isinstance(value, dict) and all(isinstance(k, str) for k in value.keys()): + table = migration_db.table(key) + + # Insert each record into TinyDB with proper fields + for record_id, record_data in value.items(): + if isinstance(record_data, dict): + if is_reddit_db: + # Check if it's a comment or submission + is_comment = 'body' in record_data + + if is_comment: + # For comments + simplified_record = { + 'reddit_id': record_data.get('id', record_id), + 'author': record_data.get('author'), + 'body': record_data.get('body'), + 'created_utc': record_data.get('created_utc', 0), + 'processed': record_data.get('processed', False), + 'slash_command': record_data.get('slash_command', { + 'project': None, + 'command': None, + }), + } + else: + # For submissions + simplified_record = { + 'reddit_id': record_data.get('id', record_id), + 'title': record_data.get('title'), + 'selftext': record_data.get('selftext'), + 'author': str(record_data.get('author')), + 'created_utc': record_data.get('created_utc', 0), + 'permalink': record_data.get('permalink'), + 'url': record_data.get('url'), + 'link_flair_text': record_data.get('link_flair_text'), + 'link_flair_background_color': record_data.get( + 'link_flair_background_color'), + 'bot_discord': record_data.get('bot_discord', { + 'sent': False, + 'sent_utc': None, + }), + } + + table.insert(simplified_record) + else: + # Non-Reddit databases keep original structure + record_data['id'] = record_id + table.insert(record_data) + + # Flush changes to disk + migration_db.storage.flush() + migration_db.close() + + print(f"Migration completed successfully: {self.json_path}") + except Exception as e: + print(f"Migration failed: {str(e)}") + traceback.print_exc() + def __enter__(self): self.lock.acquire() - self.db = shelve.open(self.db_path, writeback=True) - return self.db + return self.tinydb def __exit__(self, exc_type, exc_val, exc_tb): self.sync() - self.db.close() self.lock.release() def sync(self): - self.db.sync() + # Only call flush if using 
CachingMiddleware + if hasattr(self.tinydb.storage, 'flush'): + self.tinydb.storage.flush() + + # Git operations - commit and push changes if using git + with DATA_REPO_LOCK: + if self.use_git and self.repo is not None: + try: + # Check for untracked database files and tracked files with changes + status = self.repo.git.status('--porcelain') + + # If there are any changes or untracked files + if status: + # Add ALL json files in the directory to ensure we track all databases + json_files = [f for f in os.listdir(self.db_dir) if f.endswith('.json')] + if json_files: + for json_file in json_files: + file_path = os.path.join(self.db_dir, json_file) + self.repo.git.add(file_path) + + # Check if we have anything to commit after adding + if self.repo.git.status('--porcelain'): + # Commit all changes at once with a general message + commit_message = "Update database files" + self.repo.git.commit('-m', commit_message) + print("Committed changes to git data repository") + + # Push to remote + try: + origin = self.repo.remote('origin') + origin.push() + print("Pushed changes to remote git data repository") + except git.exc.GitCommandError as e: + print(f"Failed to push changes: {str(e)}") + + except Exception as e: + print(f"Git operation failed: {str(e)}") + traceback.print_exc() diff --git a/src/common/webapp.py b/src/common/webapp.py index 65a1935..3542fdd 100644 --- a/src/common/webapp.py +++ b/src/common/webapp.py @@ -9,6 +9,7 @@ import discord from flask import Flask, jsonify, redirect, request, Response, send_from_directory from requests_oauthlib import OAuth2Session +from tinydb import Query from werkzeug.middleware.proxy_fix import ProxyFix # local imports @@ -109,8 +110,7 @@ def discord_callback(): return Response(html.escape(request.args['error_description']), status=400) # get all active states from the global state manager - with globals.DISCORD_BOT.db as db: - active_states = db['oauth_states'] + active_states = globals.DISCORD_BOT.oauth_states discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) token = discord_oauth.fetch_token( @@ -144,19 +144,32 @@ def discord_callback(): connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") connections = connections_response.json() + # Default user data + user_data = { + 'user_id': int(discord_user['id']), + 'discord_username': discord_user['username'], + 'discord_global_name': discord_user['global_name'], + 'github_id': None, + 'github_username': None, + } + + # Check for GitHub connections + for connection in connections: + if connection['type'] == 'github': + user_data['github_id'] = int(connection['id']) + user_data['github_username'] = connection['name'] + with globals.DISCORD_BOT.db as db: - db['discord_users'] = db.get('discord_users', {}) - db['discord_users'][discord_user['id']] = { - 'discord_username': discord_user['username'], - 'discord_global_name': discord_user['global_name'], - 'github_id': None, - 'github_username': None, - } + query = Query() + + # Get the discord_users table + discord_users_table = db.table('discord_users') - for connection in connections: - if connection['type'] == 'github': - db['discord_users'][discord_user['id']]['github_id'] = connection['id'] - db['discord_users'][discord_user['id']]['github_username'] = connection['name'] + # Upsert the user data + discord_users_table.upsert( + user_data, + query.user_id == int(discord_user['id']) + ) globals.DISCORD_BOT.update_cached_message( author_id=discord_user['id'], @@ -177,8 +190,7 @@ def 
github_callback(): state = request.args.get('state') # get all active states from the global state manager - with globals.DISCORD_BOT.db as db: - active_states = db['oauth_states'] + active_states = globals.DISCORD_BOT.oauth_states github_oauth = OAuth2Session(GITHUB_CLIENT_ID, redirect_uri=GITHUB_REDIRECT_URI) token = github_oauth.fetch_token( @@ -215,14 +227,26 @@ def github_callback(): discord_user = discord_user_future.result() with globals.DISCORD_BOT.db as db: - db['discord_users'] = db.get('discord_users', {}) - db['discord_users'][discord_user_id] = { + query = Query() + + # Get the discord_users table + discord_users_table = db.table('discord_users') + + # Create user data object + user_data = { + 'user_id': int(discord_user_id), 'discord_username': discord_user.name, 'discord_global_name': discord_user.global_name, - 'github_id': github_user['id'], + 'github_id': int(github_user['id']), 'github_username': github_user['login'], } + # Upsert the user data (insert or update) + discord_users_table.upsert( + user_data, + query.user_id == int(discord_user_id) + ) + globals.DISCORD_BOT.update_cached_message( author_id=discord_user_id, reason='success', diff --git a/src/discord_bot/bot.py b/src/discord_bot/bot.py index 3382c94..1948b95 100644 --- a/src/discord_bot/bot.py +++ b/src/discord_bot/bot.py @@ -8,7 +8,7 @@ import discord # local imports -from src.common.common import bot_name, data_dir, get_avatar_bytes, org_name +from src.common.common import bot_name, get_avatar_bytes, org_name from src.common.database import Database from src.discord_bot.views import DonateCommandView @@ -36,8 +36,9 @@ def __init__(self, *args, **kwargs): self.bot_thread = threading.Thread(target=lambda: None) self.token = os.environ['DISCORD_BOT_TOKEN'] - self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database')) + self.db = Database(db_name='discord_bot_database') self.ephemeral_db = {} + self.oauth_states = {} self.clean_ephemeral_cache = tasks.clean_ephemeral_cache self.daily_task = tasks.daily_task self.role_update_task = tasks.role_update_task @@ -48,9 +49,6 @@ def __init__(self, *args, **kwargs): store=False, ) - with self.db as db: - db['oauth_states'] = {} # clear any oauth states from previous sessions - async def on_ready(self): """ Bot on ready event. 
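To make the new persistence layer concrete, here is a minimal usage sketch of the Database wrapper introduced above (illustrative only, not part of this diff; the database name and the temporary directory are placeholders, and use_git=False keeps the example self-contained):

    import tempfile

    from tinydb import Query

    from src.common.database import Database

    # Open (or create) a TinyDB-backed store; with use_git=False nothing is cloned, committed, or pushed.
    db = Database(db_name='example_database', db_dir=tempfile.mkdtemp(), use_git=False)

    with db as tinydb:
        users = tinydb.table('discord_users')
        user = Query()
        # Records are keyed by a regular field (user_id) rather than TinyDB's internal doc_id.
        users.upsert({'user_id': 123, 'github_username': 'octocat'}, user.user_id == 123)
        assert users.get(user.user_id == 123)['github_username'] == 'octocat'

    # Leaving the context manager calls sync(), which flushes the CachingMiddleware buffer to disk
    # and, when git support is enabled, commits and pushes the *.json files to the data repository.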
diff --git a/src/discord_bot/cogs/github_commands.py b/src/discord_bot/cogs/github_commands.py index 0e85dad..400b7fc 100644 --- a/src/discord_bot/cogs/github_commands.py +++ b/src/discord_bot/cogs/github_commands.py @@ -113,10 +113,8 @@ async def link_github( authorization_url, state = auth.authorization_url(platform_map[platform]['auth_url']) # Store the state in the user's session or database - with self.bot.db as db: - db['oauth_states'] = db.get('oauth_states', {}) - db['oauth_states'][str(ctx.author.id)] = state - db.sync() + oauth_states = self.bot.oauth_states + oauth_states[str(ctx.author.id)] = state response = await ctx.respond( f"Please authorize the application by clicking [here]({authorization_url}).", diff --git a/src/discord_bot/tasks.py b/src/discord_bot/tasks.py index 592c533..53c832b 100644 --- a/src/discord_bot/tasks.py +++ b/src/discord_bot/tasks.py @@ -204,25 +204,33 @@ async def role_update_task(bot: Bot, test_mode: bool = False) -> bool: if datetime.now(UTC).minute not in list(range(0, 60, 10)): return False - # check each user in the database for their GitHub sponsor status + # Check each user in the database for their GitHub sponsor status with bot.db as db: - discord_users = db.get('discord_users', {}) + users_table = db.table('discord_users') + discord_users = users_table.all() + # Return early if there are no users to process if not discord_users: return False + # Get the GitHub sponsors github_sponsors = sponsors.get_github_sponsors() - for user_id, user_data in discord_users.items(): - # get the currently revocable roles, to ensure we don't remove roles that were added by another integration + # Process each user + for user_data in discord_users: + user_id = user_data.get('discord_id') + if not user_id: + continue + + # Get the currently revocable roles, to ensure we don't remove roles that were added by another integration # i.e.; any role that was added by our bot is safe to remove revocable_roles = user_data.get('roles', []).copy() - # check if the user is a GitHub sponsor + # Check if the user is a GitHub sponsor for edge in github_sponsors['data']['organization']['sponsorshipsAsMaintainer']['edges']: sponsor = edge['node']['sponsorEntity'] - if sponsor['login'] == user_data['github_username']: - # user is a sponsor + if sponsor['login'] == user_data.get('github_username'): + # User is a sponsor user_data['github_sponsor'] = True monthly_amount = edge['node'].get('tier', {}).get('monthlyPriceInDollars', 0) @@ -236,14 +244,15 @@ async def role_update_task(bot: Bot, test_mode: bool = False) -> bool: break else: - # user is not a sponsor + # User is not a sponsor user_data['github_sponsor'] = False user_data['roles'] = [] + # Add GitHub user role if applicable if user_data.get('github_username'): user_data['roles'].append('github-user') - # update the discord user roles + # Update the discord user roles for g in bot.guilds: roles = g.roles @@ -281,8 +290,9 @@ async def role_update_task(bot: Bot, test_mode: bool = False) -> bool: remove_future = asyncio.run_coroutine_threadsafe(member.remove_roles(role), bot.loop) remove_future.result() - with bot.db as db: - db['discord_users'] = discord_users - db.sync() + # Update the user in the database + with bot.db as db: + users_table = db.table('discord_users') + users_table.update(user_data, doc_ids=[user_data.get('doc_id')]) return True diff --git a/src/reddit_bot/bot.py b/src/reddit_bot/bot.py index b11a9f5..1b116ee 100644 --- a/src/reddit_bot/bot.py +++ b/src/reddit_bot/bot.py @@ -1,7 +1,6 @@ # standard imports 
from datetime import datetime import os -import shelve import sys import threading import time @@ -11,11 +10,13 @@ import praw from praw import models import prawcore +from tinydb import Query # local imports from src.common import common from src.common import globals from src.common import inspector +from src.common.database import Database class Bot: @@ -42,11 +43,14 @@ def __init__(self, **kwargs): self.data_dir = common.data_dir self.commands_dir = os.path.join(self.data_dir, "support-bot-commands", "docs") - # files - self.db = os.path.join(self.data_dir, 'reddit_bot_database') + # database + self.db = Database(db_name='reddit_bot_database') - # locks - self.lock = threading.Lock() + # initialize database tables if they don't exist + with self.db as db: + if not db.tables(): + db.table('comments') + db.table('submissions') self.reddit = praw.Reddit( client_id=os.environ['PRAW_CLIENT_ID'], @@ -58,9 +62,6 @@ def __init__(self, **kwargs): ) self.subreddit = self.reddit.subreddit(self.subreddit_name) # "AskReddit" for faster testing of submission loop - self.migrate_shelve() - self.migrate_last_online() - def validate_env(self) -> bool: required_env = [ 'DISCORD_REDDIT_CHANNEL_ID', @@ -78,43 +79,31 @@ def validate_env(self) -> bool: return False return True - def migrate_last_online(self): - if os.path.isfile(os.path.join(self.data_dir, 'last_online')): - os.remove(os.path.join(self.data_dir, 'last_online')) - - def migrate_shelve(self): - with self.lock, shelve.open(self.db) as db: - if 'submissions' not in db and 'comments' not in db: - db['comments'] = {} - db['submissions'] = {} - submissions = db['submissions'] - for k, v in db.items(): - if k not in ['comments', 'submissions']: - submissions[k] = v - assert submissions[k] == v - db['submissions'] = submissions - keys_to_delete = [k for k in db if k not in ['comments', 'submissions']] - for k in keys_to_delete: - del db[k] - assert k not in db - def process_comment(self, comment: models.Comment): - with self.lock, shelve.open(self.db) as db: - comments = db.get('comments', {}) - if comment.id in comments and comments[comment.id].get('processed', False): + with self.db as db: + comments_table = db.table('comments') + c = Query() + existing_comment = comments_table.get(c.reddit_id == comment.id) + + if existing_comment and existing_comment.get('processed', False): return - comments[comment.id] = { + comment_data = { + 'reddit_id': comment.id, # Store Reddit ID as a regular field 'author': str(comment.author), 'body': comment.body, 'created_utc': comment.created_utc, - 'processed': True, + 'processed': False, 'slash_command': {'project': None, 'command': None}, } - # the shelve doesn't update unless we recreate the main key - db['comments'] = comments - self.slash_commands(comment=comment) + comment_data = self.slash_commands(comment=comment, comment_data=comment_data) + comment_data['processed'] = True + + if existing_comment: + comments_table.update(comment_data, c.reddit_id == comment.id) + else: + comments_table.insert(comment_data) def process_submission(self, submission: models.Submission): """ @@ -125,42 +114,51 @@ def process_submission(self, submission: models.Submission): submission : praw.models.Submission The submission to process. 
""" - with self.lock, shelve.open(self.db) as db: - submissions = db.get('submissions', {}) - if submission.id not in submissions: - submissions[submission.id] = {} - submission_exists = False + with self.db as db: + submissions_table = db.table('submissions') + s = Query() + existing_submission = submissions_table.get(s.reddit_id == submission.id) + + # Extract submission data to store + submission_data = { + 'reddit_id': submission.id, # Store Reddit ID as a regular field + 'title': submission.title, + 'selftext': submission.selftext, + 'author': str(submission.author), + 'created_utc': submission.created_utc, + 'permalink': submission.permalink, + 'url': submission.url, + 'link_flair_text': submission.link_flair_text if hasattr(submission, 'link_flair_text') else None, + 'link_flair_background_color': submission.link_flair_background_color if hasattr( + submission, 'link_flair_background_color') else None, + 'bot_discord': {'sent': False, 'sent_utc': None}, + } + + if existing_submission: + submission_data['bot_discord'] = existing_submission.get( + 'bot_discord', {'sent': False, 'sent_utc': None}) + submissions_table.update(submission_data, s.reddit_id == submission.id) else: - submission_exists = True - - # the shelve doesn't update unless we recreate the main key - submissions[submission.id].update(vars(submission)) - db['submissions'] = submissions - - if not submission_exists: - print(f'submission id: {submission.id}') - print(f'submission title: {submission.title}') - print('---------') - if os.getenv('DISCORD_REDDIT_CHANNEL_ID'): - self.discord(submission=submission) - self.flair(submission=submission) - self.karma(submission=submission) - - def discord(self, submission: models.Submission): + print(f'submission id: {submission.id}') + print(f'submission title: {submission.title}') + print('---------') + if os.getenv('DISCORD_REDDIT_CHANNEL_ID'): + submission_data = self.discord(submission=submission, submission_data=submission_data) + submission_data = self.flair(submission=submission, submission_data=submission_data) + submission_data = self.karma(submission=submission, submission_data=submission_data) + + submissions_table.insert(submission_data) + + def discord(self, submission: models.Submission, submission_data: dict) -> dict: """ Send a discord message. Parameters ---------- - db : shelve.Shelf - The database. submission : praw.models.Submission The submission to process. - - Returns - ------- - shelve.Shelf - The updated database. + submission_data : dict + The submission data to process. 
""" # get the flair color try: @@ -174,7 +172,7 @@ def discord(self, submission: models.Submission): self.DEGRADED = True reason = inspector.current_name() self.DEGRADED_REASONS.append(reason) if reason not in self.DEGRADED_REASONS else None - return + return submission_data # create the discord embed embed = discord.Embed( @@ -202,45 +200,50 @@ def discord(self, submission: models.Submission): ) if message: - with self.lock, shelve.open(self.db) as db: - # the shelve doesn't update unless we recreate the main key - submissions = db['submissions'] - submissions[submission.id]['bot_discord'] = {'sent': True, 'sent_utc': int(time.time())} - db['submissions'] = submissions + submission_data['bot_discord'] = { + 'sent': True, + 'sent_utc': int(time.time()), + } - def flair(self, submission: models.Submission): + return submission_data + + def flair(self, submission: models.Submission, submission_data: dict) -> dict: # todo - pass + return submission_data - def karma(self, submission: models.Submission): + def karma(self, submission: models.Submission, submission_data: dict) -> dict: # todo - pass - - def slash_commands(self, comment: models.Comment): - if comment.body.startswith("/"): - print(f"Processing slash command: {comment.body}") - # Split the comment into project and command - parts = comment.body[1:].split() - project = parts[0] - command = parts[1] if len(parts) > 1 else None - - # Check if the command file exists in self.commands_dir - command_file = os.path.join(self.commands_dir, project, f"{command}.md") if command else None - if command_file and os.path.isfile(command_file): - # Open the markdown file and read its contents - with open(command_file, 'r', encoding='utf-8') as file: - file_contents = file.read() - - # Reply to the comment with the contents of the file - comment.reply(file_contents) - else: - # Log error message - print(f"Unknown command: {command} in project: {project}") - with self.lock, shelve.open(self.db) as db: - # the shelve doesn't update unless we recreate the main key - comments = db['comments'] - comments[comment.id]['slash_command'] = {'project': project, 'command': command} - db['comments'] = comments + return submission_data + + def slash_commands(self, comment: models.Comment, comment_data: dict) -> dict: + if not comment.body.startswith("/"): + return comment_data + + print(f"Processing slash command: {comment.body}") + # Split the comment into project and command + parts = comment.body[1:].split() + project = parts[0] + command = parts[1] if len(parts) > 1 else None + + # Check if the command file exists in self.commands_dir + command_file = os.path.join(self.commands_dir, project, f"{command}.md") if command else None + + if not command_file or not os.path.isfile(command_file): + return comment_data + + # Open the markdown file and read its contents + with open(command_file, 'r', encoding='utf-8') as file: + file_contents = file.read() + + # Reply to the comment with the contents of the file + comment.reply(file_contents) + + comment_data['slash_command'] = { + 'project': project, + 'command': command, + } + + return comment_data def _comment_loop(self, test: bool = False): # process comments and then keep monitoring diff --git a/tests/conftest.py b/tests/conftest.py index 9905c86..b93e5ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,26 +38,24 @@ def discord_bot(): @pytest.fixture(scope='function') def discord_db_users(discord_bot): with discord_bot.db as db: - db['discord_users'] = { - '939171917578002502': { - 'discord_username': 
'test_user', - 'discord_global_name': 'Test User', - 'github_id': 'test_user', - 'github_username': 'test_user', - 'roles': [ - 'supporters', - ] - } - } - db['oauth_states'] = {'939171917578002502': 'valid_state'} - db.sync() # Ensure the data is written to the shelve + users_table = db.table('discord_users') + users_table.insert({ + 'user_id': 939171917578002502, + 'discord_username': 'test_user', + 'discord_global_name': 'Test User', + 'github_id': 123, + 'github_username': 'test_user', + 'roles': ['supporters'] + }) + + discord_bot.oauth_states[939171917578002502] = 'valid_state' yield + # Clean up test data with discord_bot.db as db: - db['discord_users'] = {} - db['oauth_states'] = {} - db.sync() # Ensure the data is written to the shelve + db.table('discord_users').truncate() + discord_bot.oauth_states.clear() @pytest.fixture(scope='function') diff --git a/tests/unit/common/test_database.py b/tests/unit/common/test_database.py new file mode 100644 index 0000000..e8c796e --- /dev/null +++ b/tests/unit/common/test_database.py @@ -0,0 +1,341 @@ +# standard imports +import inspect +import os +import shelve +import threading +from unittest.mock import patch + +# lib imports +import pytest + +# local imports +from src.common.database import Database + + +class TestDatabase: + @pytest.fixture + def test_dir(self, tmp_path): + """Create a temporary directory for database files.""" + return tmp_path + + @pytest.fixture + def db_init(self, test_dir): + """Create a database path for testing.""" + return { + "db_dir": test_dir, + "use_git": False + } + + @pytest.fixture + def cleanup_files(self, db_init): + """Clean up database files after tests.""" + yield + + db_name = f'db_{inspect.currentframe().f_code.co_name}' + + # Clean up known database files + for file_path in [f"{db_name}.json", f"{db_name}.db", f"{db_name}.dat", + f"{db_name}.bak", f"{db_name}.dir", db_name]: + if os.path.exists(os.path.join(db_init['db_dir'], file_path)): + try: + os.remove(file_path) + print(f"Removed {file_path}") + except Exception as e: + print(f"Could not remove {file_path}: {e}") + + def test_init_new_database(self, db_init, cleanup_files): + """Test creating a new database.""" + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + assert os.path.exists(os.path.join(db_init['db_dir'], f'db_{inspect.currentframe().f_code.co_name}.json')) + assert hasattr(db, "tinydb") + assert hasattr(db, "lock") + + def test_context_manager(self, db_init, cleanup_files): + """Test database context manager functionality.""" + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + + with db as tinydb: + tinydb.insert({"test": "data"}) + + # Reopen to verify data was saved + db2 = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + with db2 as tinydb: + data = tinydb.all() + assert len(data) == 1 + assert data[0]["test"] == "data" + + def test_sync_method(self, db_init, cleanup_files): + """Test sync method flushes data to disk.""" + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + + with patch.object(db.tinydb.storage, 'flush') as mock_flush: + db.sync() + mock_flush.assert_called_once() + + @staticmethod + def create_test_shelve(shelve_path): + """Helper to create a test shelve database.""" + # Close 
the shelve db properly to ensure it's written to disk + shelve_db = shelve.open(shelve_path) + + # Add a comments table with records + shelve_db["comments"] = { + "abc123": {"author": "user1", "body": "test comment", "processed": True}, + "def456": {"author": "user2", "body": "another comment", "processed": False} + } + # Add submissions table + shelve_db["submissions"] = { + "xyz789": {"author": "user3", "title": "Test post", "processed": True} + } + + # Explicitly sync and close + shelve_db.sync() + shelve_db.close() + + def test_migrate_from_shelve(self, db_init, cleanup_files): + """Test migration from shelve to TinyDB.""" + # Create a shelve database first + self.create_test_shelve(os.path.join(db_init['db_dir'], f'db_{inspect.currentframe().f_code.co_name}')) + print("Files in db_dir before migration:", os.listdir(db_init['db_dir'])) + + # Try to create the database (it should migrate if possible) + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + + # If migration didn't happen on this platform, create the tables manually for test + with db as tinydb: + comments = tinydb.table("comments").all() + submissions = tinydb.table("submissions").all() + + # Now verify the data exists one way or another + assert len(comments) == 2 + assert len(submissions) == 1 + + # Find records by attributes since order may vary + user1_comment = next((c for c in comments if c["author"] == "user1"), None) + user2_comment = next((c for c in comments if c["author"] == "user2"), None) + + assert user1_comment is not None, "Couldn't find user1's comment" + assert user2_comment is not None, "Couldn't find user2's comment" + assert user1_comment["body"] == "test comment" + assert user1_comment["processed"] is True + assert user2_comment["body"] == "another comment" + assert user2_comment["processed"] is False + + assert submissions[0]["author"] == "user3" + assert submissions[0]["title"] == "Test post" + + def test_migrate_from_shelve_reddit_db(self, db_init, cleanup_files): + """Test migration from shelve to TinyDB for Reddit database specifically.""" + # Create a shelve path with reddit in the name to trigger special handling + reddit_db_name = f'reddit_bot_{inspect.currentframe().f_code.co_name}' + shelve_path = os.path.join(db_init['db_dir'], reddit_db_name) + + # Create a test shelve with Reddit-specific data structures + shelve_db = shelve.open(shelve_path) + + # Add comments with different structures + shelve_db["comments"] = { + "abc123": { + "author": "user1", # comment authors are strings + "body": "test comment", + "created_utc": 1625097600, + "processed": True, + "slash_command": "/help", + }, + "def456": { + "author": "user2", + "body": "another comment", + "created_utc": 1625184000, + "processed": False, + }, + "ghi789": { + "author": None, # Edge case with None author + "body": "deleted comment", + "created_utc": 1625270400, + } + } + + # Add submissions with different structures + shelve_db["submissions"] = { + "xyz789": { + "id": "xyz789", + "title": "Test post", + "selftext": "Post content", + "author": "user3", + "created_utc": 1625356800, + "permalink": "/r/test/comments/xyz789/", + "url": "https://reddit.com/r/test/comments/xyz789/", + "link_flair_text": "Help", + "link_flair_background_color": "#ff0000", + "bot_discord": {"message_id": "123456789"} + }, + "uvw456": { + "id": "uvw456", + "title": "Another post", + "selftext": "", + "author": "user4", + "created_utc": 1625443200, + "permalink": 
"/r/test/comments/uvw456/", + "url": "https://reddit.com/r/test/comments/uvw456/" + # Missing some fields intentionally + } + } + + # Explicitly sync and close + shelve_db.sync() + shelve_db.close() + + # Create the database (it should trigger migration) + db = Database( + db_name=reddit_db_name, + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + + # Verify the migrated data + with db as tinydb: + # Check comments + comments_table = tinydb.table("comments") + comments = comments_table.all() + assert len(comments) == 3 + + # Find specific comments by ID to verify consistent order + comment1 = next((c for c in comments if c["reddit_id"] == "abc123"), None) + comment2 = next((c for c in comments if c["reddit_id"] == "def456"), None) + comment3 = next((c for c in comments if c["reddit_id"] == "ghi789"), None) + + # Verify comment 1 + assert comment1 is not None + assert comment1["author"] == "user1" + assert comment1["body"] == "test comment" + assert comment1["created_utc"] == 1625097600 + assert comment1["processed"] is True + assert comment1["slash_command"] == "/help" + + # Verify comment 2 + assert comment2 is not None + assert comment2["author"] == "user2" + assert comment2["body"] == "another comment" + assert comment2["created_utc"] == 1625184000 + assert comment2["processed"] is False + assert comment2["slash_command"]["command"] is None + assert comment2["slash_command"]["project"] is None + + # Verify comment 3 - None author + assert comment3 is not None + assert comment3["author"] is None + assert comment3["body"] == "deleted comment" + assert comment3["created_utc"] == 1625270400 + + # Check submissions + submissions_table = tinydb.table("submissions") + submissions = submissions_table.all() + assert len(submissions) == 2 + + # Find specific submissions + submission1 = next((s for s in submissions if s["reddit_id"] == "xyz789"), None) + submission2 = next((s for s in submissions if s["reddit_id"] == "uvw456"), None) + + # Verify submission 1 - full fields + assert submission1 is not None + assert submission1["title"] == "Test post" + assert submission1["selftext"] == "Post content" + assert submission1["author"] == "user3" + assert submission1["created_utc"] == 1625356800 + assert submission1["permalink"] == "/r/test/comments/xyz789/" + assert submission1["url"] == "https://reddit.com/r/test/comments/xyz789/" + assert submission1["link_flair_text"] == "Help" + assert submission1["link_flair_background_color"] == "#ff0000" + assert submission1["bot_discord"]["message_id"] == "123456789" + + # Verify submission 2 - missing some fields + assert submission2 is not None + assert submission2["title"] == "Another post" + assert submission2["selftext"] == "" + assert submission2["author"] == "user4" + assert submission2["created_utc"] == 1625443200 + assert submission2["permalink"] == "/r/test/comments/uvw456/" + assert submission2["url"] == "https://reddit.com/r/test/comments/uvw456/" + assert submission2["link_flair_text"] is None + assert submission2["link_flair_background_color"] is None + assert submission2["bot_discord"]["sent"] is False + assert submission2["bot_discord"]["sent_utc"] is None + + def test_migration_error_handling(self, db_init, cleanup_files): + """Test error handling during migration.""" + # Create a shelve database first + self.create_test_shelve(os.path.join(db_init['db_dir'], f'db_{inspect.currentframe().f_code.co_name}')) + + # Instead of testing print output, check that the database still initializes + # even when shelve migration fails + with 
patch('shelve.open', side_effect=Exception("Test error")): + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + + # Check database was initialized with empty tables + with db as tinydb: + comments = tinydb.table("comments").all() + submissions = tinydb.table("submissions").all() + assert len(comments) == 0 + assert len(submissions) == 0 + + def test_thread_safety(self, db_init, cleanup_files): + """Test thread safety with multiple threads accessing database.""" + db = Database( + db_name=f'db_{inspect.currentframe().f_code.co_name}', + db_dir=db_init['db_dir'], + use_git=db_init['use_git'], + ) + results = [] + + def worker(worker_id): + with db as tinydb: + # Simulate some work + table = tinydb.table(f"worker_{worker_id}") + table.insert({"id": worker_id}) + results.append(worker_id) + + # Create and start multiple threads + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Verify all workers processed + assert len(results) == 5 + assert set(results) == set(range(5)) + + # Verify database has all tables + with db as tinydb: + for i in range(5): + table = tinydb.table(f"worker_{i}") + assert len(table.all()) == 1 + assert table.all()[0]["id"] == i diff --git a/tests/unit/common/test_webapp.py b/tests/unit/common/test_webapp.py index 3e781a1..7c8bf5a 100644 --- a/tests/unit/common/test_webapp.py +++ b/tests/unit/common/test_webapp.py @@ -75,14 +75,14 @@ def test_discord_callback_success(test_client, mocker, discord_db_users): mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) mocker.patch('src.common.webapp.OAuth2Session.get', side_effect=[ Mock(json=lambda: { - 'id': '939171917578002502', + 'id': 939171917578002502, 'username': 'discord_user', 'global_name': 'discord_global_name', }), Mock(json=lambda: [ { 'type': 'github', - 'id': 'github_user_id', + 'id': '123', 'name': 'github_user_login', } ]) @@ -143,11 +143,11 @@ def test_github_callback_success(test_client, mocker, discord_db_users): mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) mocker.patch('src.common.webapp.OAuth2Session.get', side_effect=[ Mock(json=lambda: { - 'id': 'github_user_id', + 'id': '123', 'login': 'github_user_login', }), Mock(json=lambda: { - 'id': 'github_user_id', + 'id': '123', 'login': 'github_user_login', }) ]) @@ -171,7 +171,7 @@ def test_github_callback_invalid_state(test_client, mocker, discord_db_users): mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) mocker.patch('src.common.webapp.OAuth2Session.get', return_value=Mock(json=lambda: { - 'id': 'github_user_id', + 'id': '123', 'login': 'github_user_login', })) diff --git a/tests/unit/reddit/test_reddit_bot.py b/tests/unit/reddit/test_reddit_bot.py index 45d4e01..11af542 100644 --- a/tests/unit/reddit/test_reddit_bot.py +++ b/tests/unit/reddit/test_reddit_bot.py @@ -6,7 +6,6 @@ import inspect import json import os -import shelve from unittest.mock import patch from urllib.parse import quote_plus @@ -15,6 +14,7 @@ from betamax_serializers.pretty_json import PrettyJSONSerializer from praw.config import _NotSet import pytest +from tinydb import Query # local imports from src.reddit_bot.bot import Bot @@ -119,7 +119,8 @@ def bot(self): @pytest.fixture(scope='session', 
autouse=True) def betamax_config(self, bot): - record_mode = 'none' if os.environ.get('GITHUB_PYTEST') else 'once' + record_mode = 'none' if os.environ.get('GITHUB_PYTEST', '').lower() == 'true' else 'once' + record_mode = 'all' if os.environ.get('FORCE_BETAMAX_UPDATE', '').lower() == 'true' else record_mode with Betamax.configure() as config: config.cassette_library_dir = 'tests/fixtures/cassettes' @@ -144,16 +145,16 @@ def recorder(self, session): @pytest.fixture(scope='session') def slash_command_comment(self, bot, recorder): - comment = bot.reddit.comment(id='l20s21b') with recorder.use_cassette(f'fixture_{inspect.currentframe().f_code.co_name}'): + comment = bot.reddit.comment(id='l20s21b') assert comment.body == '/sunshine vban' return comment @pytest.fixture(scope='session') def _submission(self, bot, recorder): - s = bot.reddit.submission(id='w03cku') with recorder.use_cassette(f'fixture_{inspect.currentframe().f_code.co_name}'): + s = bot.reddit.submission(id='w03cku') # TODO: replace with a submission by LizardByte-bot assert s.author return s @@ -169,44 +170,36 @@ def test_validate_env(self, bot): }): assert bot.validate_env() - def test_migrate_shelve(self, bot): - with patch.object(shelve, 'open') as mock_open: - bot.migrate_shelve() - mock_open.assert_called_once_with(bot.db) - with shelve.open(bot.db) as db: - assert db.get('comments') is not None - assert db.get('submissions') is not None - - def test_migrate_last_online(self, bot): - f = os.path.join(bot.data_dir, 'last_online') - if not os.path.isfile(f): - with open(f, 'w') as file: - file.write('1234') - assert os.path.isfile(f) - - bot.migrate_last_online() - assert not os.path.isfile(f) - def test_process_comment(self, bot, recorder, request, slash_command_comment): with recorder.use_cassette(request.node.name): bot.process_comment(comment=slash_command_comment) - with bot.lock, shelve.open(bot.db) as db: - assert slash_command_comment.id in db['comments'] - assert db['comments'][slash_command_comment.id]['author'] == str(slash_command_comment.author) - assert db['comments'][slash_command_comment.id]['body'] == slash_command_comment.body - assert db['comments'][slash_command_comment.id]['processed'] - assert db['comments'][slash_command_comment.id]['slash_command']['project'] == 'sunshine' - assert db['comments'][slash_command_comment.id]['slash_command']['command'] == 'vban' + + with bot.db as db: + comments_table = db.table('comments') + c = Query() + comment_data = comments_table.get(c.reddit_id == slash_command_comment.id) + + assert comment_data is not None + assert comment_data['author'] == str(slash_command_comment.author) + assert comment_data['body'] == slash_command_comment.body + assert comment_data['processed'] + assert comment_data['slash_command']['project'] == 'sunshine' + assert comment_data['slash_command']['command'] == 'vban' def test_process_submission(self, bot, discord_bot, recorder, request, _submission): with recorder.use_cassette(request.node.name): bot.process_submission(submission=_submission) - with bot.lock, shelve.open(bot.db) as db: - assert _submission.id in db['submissions'] - assert db['submissions'][_submission.id]['author'] == str(_submission.author) - assert db['submissions'][_submission.id]['title'] == _submission.title - assert db['submissions'][_submission.id]['bot_discord']['sent'] is True - assert db['submissions'][_submission.id]['bot_discord']['sent_utc'] + + with bot.db as db: + submissions_table = db.table('submissions') + s = Query() + submission_data = 
submissions_table.get(s.reddit_id == _submission.id) + + assert submission_data is not None + assert submission_data['author'] == str(_submission.author) + assert submission_data['title'] == _submission.title + assert submission_data['bot_discord']['sent'] is True + assert 'sent_utc' in submission_data['bot_discord'] def test_comment_loop(self, bot, recorder, request): with recorder.use_cassette(request.node.name):