Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added extension management commands #157

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions bot/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from datetime import datetime
from typing import Optional

Expand All @@ -11,6 +10,7 @@
from bot.postgres import create_tables

from . import constants
from .utils.extensions import EXTENSIONS, walk_extensions


class Bot(commands.Bot):
Expand Down Expand Up @@ -85,12 +85,10 @@ async def _db_setup(self) -> None:
def load_extensions(self) -> None:
"""Load all the extensions in the exts/ folder."""
logger.info("Start loading extensions from ./exts/")
for extension in constants.EXTENSIONS.glob("*/*.py"):
if extension.name.startswith("_"):
continue # ignore files starting with _
dot_path = str(extension).replace(os.sep, ".")[:-3] # remove the .py
self.load_extension(dot_path)
logger.info(f"Successfully loaded extension: {dot_path}")
for ext in walk_extensions():
self.load_extension(ext)
EXTENSIONS.append(ext)
logger.info(f"Successfully loaded extension: {ext}")

def run(self) -> None:
"""Run the bot with the token in constants.py/.env ."""
Expand Down
295 changes: 295 additions & 0 deletions bot/exts/backend/extension_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
import functools
import typing as t
from collections import defaultdict
from enum import Enum

from disnake import AllowedMentions, Colour, Embed
from disnake.ext import commands
from loguru import logger

from bot import constants
from bot.bot import Bot
from bot.utils.extensions import EXTENSIONS, unqualify
from bot.utils.pagination import LinePaginator

BLACKLIST = [__name__]


class Action(Enum):
"""Represents an action to perform on an extension."""

# Need to be partial otherwise they are considered to be function definitions.
LOAD = functools.partial(Bot.load_extension)
UNLOAD = functools.partial(Bot.unload_extension)
RELOAD = functools.partial(Bot.reload_extension)


class ExtensionConverter(commands.Converter):
"""
Fully qualify the name of an extension and ensure it exists.

The * value bypasses this when used with an extension manager command.
"""

source_list = EXTENSIONS
type = "extension"

async def convert(self, _: commands.Context, argument: str) -> str:
"""Fully qualify the name of an extension and ensure it exists."""
# Special values to reload all extensions
if argument == "*":
return argument

argument = argument.lower()

if argument in self.source_list:
return argument

qualified_arg = f"modmail.{self.type}s.{argument}"
if qualified_arg in self.source_list:
return qualified_arg

matches = []
for ext in self.source_list:
if argument == unqualify(ext):
matches.append(ext)

if not matches:
raise commands.BadArgument(
f":x: Could not find the {self.type} `{argument}`."
)

if len(matches) > 1:
names = "\n".join(sorted(matches))
raise commands.BadArgument(
f":x: `{argument}` is an ambiguous {self.type} name. "
f"Please use one of the following fully-qualified names.```\n{names}```"
)

return matches[0]


class ExtensionManager(commands.Cog, name="Extension Manager"):
"""
Extension management.

Commands to load, reload, unload, and list extensions.
"""

type = "extension"
module_name = constants.EXTENSIONS.name

def __init__(self, bot: Bot):
self.bot = bot
self.all_extensions = EXTENSIONS

@commands.group("ext", aliases=("extensions", "exts"), invoke_without_command=True)
async def extensions_group(self, ctx: commands.Context) -> None:
"""Load, unload, reload, and list loaded extensions."""
await ctx.send_help(ctx.command)

@extensions_group.command(name="load", aliases=("l",))
async def load_extensions(
self, ctx: commands.Context, *extensions: ExtensionConverter
) -> None:
r"""
Load extensions given their fully qualified or unqualified names.

If '*' is given as the name, all unloaded extensions will be loaded.
"""
if not extensions:
await ctx.send_help(ctx.command)
return

if "*" in extensions:
extensions = sorted(
ext
for ext in self.all_extensions
if ext not in self.bot.extensions.keys()
)

msg = self.batch_manage(Action.LOAD, *extensions)
await ctx.send(msg)

@extensions_group.command(name="unload", aliases=("ul",))
async def unload_extensions(
self, ctx: commands.Context, *extensions: ExtensionConverter
) -> None:
r"""
Unload currently loaded extensions given their fully qualified or unqualified names.

If '*' is given as the name, all loaded extensions will be unloaded.
"""
if not extensions:
await ctx.send_help(ctx.command)
return

blacklisted = [ext for ext in BLACKLIST if ext in extensions]

if blacklisted:
bl_msg = "\n".join(blacklisted)
await ctx.send(
f":x: The following {self.type}(s) may not be unloaded:```\n{bl_msg}```"
)
return

if "*" in extensions:
extensions = sorted(
ext
for ext in self.bot.extensions.keys() & self.all_extensions
if ext not in BLACKLIST
)

if "*" in extensions:
extensions = sorted(
ext for ext in self.bot.extensions.keys() & self.all_extensions
)

await ctx.send(self.batch_manage(Action.UNLOAD, *extensions))

@extensions_group.command(name="reload", aliases=("r", "rl"))
async def reload_extensions(
self, ctx: commands.Context, *extensions: ExtensionConverter
) -> None:
r"""
Reload extensions given their fully qualified or unqualified names.

If an extension fails to be reloaded, it will be rolled-back to the prior working state.

If '*' is given as the name, all currently loaded extensions will be reloaded.
"""
if not extensions:
await ctx.send_help(ctx.command)
return

if "*" in extensions:
extensions = self.bot.extensions.keys() & self.all_extensions.keys()

await ctx.send(self.batch_manage(Action.RELOAD, *extensions))

@extensions_group.command(name="list", aliases=("all", "ls"))
async def list_extensions(self, ctx: commands.Context) -> None:
"""
Get a list of all extensions, including their loaded status.

Red indicates that the extension is unloaded.
Green indicates that the extension is currently loaded.
"""
embed = Embed(colour=Colour.blurple())
embed.set_author(
name=f"{self.type.capitalize()} List",
)

lines = []
categories = self.group_extension_statuses()
for category, extensions in sorted(categories.items()):
# Treat each category as a single line by concatenating everything.
# This ensures the paginator will not cut off a page in the middle of a category.
logger.trace(f"Extensions in category {category}: {extensions}")
category = category.replace("_", " ").title()
extensions = "\n".join(sorted(extensions))
lines.append(f"**{category}**\n{extensions}\n")

logger.debug(
f"{ctx.author} requested a list of all {self.type}s. "
"Returning a paginated list."
)

await LinePaginator.paginate(
lines or f"There are no {self.type}s installed.", ctx, embed=embed
)

def group_extension_statuses(self) -> t.Mapping[str, str]:
"""Return a mapping of extension names and statuses to their categories."""
categories = defaultdict(list)

for ext in self.all_extensions:
if ext in self.bot.extensions:
status = ":green_circle:"
else:
status = ":red_circle:"

root, name = ext.rsplit(".", 1)
if root.split(".", 1)[1] == self.module_name:
category = f"General {self.type}s"
else:
category = " - ".join(root.split(".")[2:])
categories[category].append(f"{status} {name}")

return dict(categories)

def batch_manage(self, action: Action, *extensions: str) -> str:
"""
Apply an action to multiple extensions and return a message with the results.

If only one extension is given, it is deferred to `manage()`.
"""
if len(extensions) == 1:
msg, _ = self.manage(action, extensions[0])
return msg

verb = action.name.lower()
failures = {}

for extension in sorted(extensions):
_, error = self.manage(action, extension)
if error:
failures[extension] = error

emoji = ":x:" if failures else ":thumbsup:"
msg = f"{emoji} {len(extensions) - len(failures)} / {len(extensions)} {self.type}s {verb}ed."

if failures:
failures = "\n".join(f"{ext}\n {err}" for ext, err in failures.items())
msg += f"\nFailures:```\n{failures}```"

logger.debug(f"Batch {verb}ed {self.type}s.")

return msg

def manage(self, action: Action, ext: str) -> t.Tuple[str, t.Optional[str]]:
"""Apply an action to an extension and return the status message and any error message."""
verb = action.name.lower()
error_msg = None

try:
action.value(self.bot, ext)
except (commands.ExtensionAlreadyLoaded, commands.ExtensionNotLoaded):
if action is Action.RELOAD:
# When reloading, have a special error.
msg = f":x: {self.type.capitalize()} `{ext}` is not loaded, so it was not {verb}ed."
else:
msg = f":x: {self.type.capitalize()} `{ext}` is already {verb}ed."
except Exception as e:
if hasattr(e, "original"):
# If original exception is present, then utilize it
e = e.original

logger.exception(f"{self.type.capitalize()} '{ext}' failed to {verb}.")

error_msg = f"{e.__class__.__name__}: {e}"
msg = f":x: Failed to {verb} {self.type} `{ext}`:\n```\n{error_msg}```"
else:
msg = f":thumbsup: {self.type.capitalize()} successfully {verb}ed: `{ext}`."

logger.debug(error_msg or msg)
return msg, error_msg

async def cog_check(self, ctx: commands.Context) -> bool:
"""Only allow lords to invoke the commands in this cog."""
# ctx.guild should always be full cache at this point
await self.bot.wait_until_ready()
role = ctx.guild.get_role(constants.Roles.steering_council)
return role in ctx.author.roles

# This cannot be static (must have a __func__ attribute).
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
"""Handle BadArgument errors locally to prevent the help command from showing."""
if isinstance(error, commands.BadArgument):
await ctx.send(str(error), allowed_mentions=AllowedMentions.none())
error.handled = True


def setup(bot: Bot) -> None:
"""Load the Extension Manager cog."""
bot.add_cog(ExtensionManager(bot))
20 changes: 20 additions & 0 deletions bot/utils/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
from typing import Generator

from bot import constants

EXTENSIONS: list[str] = [] # All extensions.


def unqualify(name: str) -> str:
"""Return an unqualified name given a qualified module/package `name`."""
return name.rsplit(".", maxsplit=1)[-1]


def walk_extensions() -> Generator[str, None, None]:
"""Yield extensions from the configured constants.EXTENSIONS subpackage."""
for extension in constants.EXTENSIONS.glob("*/*.py"):
if extension.name.startswith("_"):
continue # ignore files starting with _
dot_path = str(extension).replace(os.sep, ".")[:-3] # remove the .py
yield dot_path