Skip to content

Commit 9104666

Browse files
committed
Module loading, unloading, reloading
1 parent c89b6a8 commit 9104666

File tree

3 files changed

+242
-1
lines changed

3 files changed

+242
-1
lines changed

twitchio/ext/commands/bot.py

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,24 @@
2424

2525
from __future__ import annotations
2626

27+
import asyncio
28+
import importlib.util
2729
import logging
30+
import sys
31+
import types
2832
from typing import TYPE_CHECKING, Any, TypeAlias, Unpack
2933

3034
from twitchio.client import Client
3135

36+
from ...utils import _is_submodule
3237
from .context import Context
3338
from .converters import _BaseConverter
3439
from .core import Command, CommandErrorPayload, Group, Mixin
3540
from .exceptions import *
3641

3742

3843
if TYPE_CHECKING:
39-
from collections.abc import Callable, Coroutine, Iterable
44+
from collections.abc import Callable, Coroutine, Iterable, Mapping
4045

4146
from twitchio.eventsub.subscriptions import SubscriptionPayload
4247
from twitchio.models.eventsub_ import ChatMessage
@@ -170,6 +175,7 @@ def __init__(
170175
self._get_prefix: PrefixT = prefix
171176
self._components: dict[str, Component] = {}
172177
self._base_converter: _BaseConverter = _BaseConverter(self)
178+
self.__modules: dict[str, types.ModuleType] = {}
173179

174180
@property
175181
def bot_id(self) -> str:
@@ -451,3 +457,203 @@ async def subscribe_websocket(
451457
socket_id: str | None = None,
452458
) -> SubscriptionResponse | None:
453459
return await super().subscribe_websocket(payload=payload, as_bot=as_bot, token_for=token_for, socket_id=socket_id)
460+
461+
def _get_module_name(self, name: str, package: str | None) -> str:
462+
try:
463+
return importlib.util.resolve_name(name, package)
464+
except ImportError as e:
465+
raise ModuleNotFoundError(f'The module "{name}" was not found.') from e
466+
467+
async def _remove_module_remnants(self, name: str) -> None:
468+
for component_name, component in self._components.copy().items():
469+
if component.__module__ == name or component.__module__.startswith(f"{name}."):
470+
await self.remove_component(component_name)
471+
472+
async def _module_finalizers(self, name: str, module: types.ModuleType) -> None:
473+
try:
474+
func = getattr(module, "teardown")
475+
except AttributeError:
476+
pass
477+
else:
478+
try:
479+
await func(self)
480+
except Exception:
481+
pass
482+
finally:
483+
self.__modules.pop(name, None)
484+
sys.modules.pop(name, None)
485+
486+
name = module.__name__
487+
for m in list(sys.modules.keys()):
488+
if _is_submodule(name, m):
489+
del sys.modules[m]
490+
491+
async def load_module(self, name: str, *, package: str | None = None) -> None:
492+
"""|coro|
493+
494+
Loads a module.
495+
496+
A module is a python module that contains commands, cogs, or listeners.
497+
498+
A module must have a global coroutine, ``setup`` defined as the entry point on what to do when the module is loaded.
499+
The coroutine takes a single argument, the ``bot``.
500+
501+
.. versionchanged:: 3.0
502+
This method is now a :term:`coroutine`.
503+
504+
Parameters
505+
----------
506+
name: str
507+
The module to load. It must be dot separated like regular Python imports accessing a sub-module.
508+
e.g. ``foo.bar`` if you want to import ``foo/bar.py``.
509+
package: str | None
510+
The package name to resolve relative imports with.
511+
This is required when loading an extension using a relative path.
512+
e.g. ``.foo.bar``. Defaults to ``None``.
513+
514+
Raises
515+
------
516+
ModuleAlreadyLoadedError
517+
The module is already loaded.
518+
ModuleNotFoundError
519+
The module could not be imported. Also raised if module could not be resolved using the `package` parameter.
520+
ModuleLoadFailure
521+
There was an error loading the module.
522+
NoModuleEntryPoint
523+
The module does not have a setup coroutine.
524+
TypeError
525+
The module's setup function is not a coroutine.
526+
"""
527+
528+
name = self._get_module_name(name, package)
529+
530+
if name in self.__modules:
531+
raise ModuleAlreadyLoadedError(f"The module {name} has already been loaded.")
532+
533+
spec = importlib.util.find_spec(name)
534+
if spec is None:
535+
raise ModuleNotFoundError(name)
536+
537+
module = importlib.util.module_from_spec(spec)
538+
sys.modules[name] = module
539+
540+
try:
541+
spec.loader.exec_module(module) # type: ignore
542+
except Exception as e:
543+
del sys.modules[name]
544+
raise ModuleLoadFailure(e) from e
545+
546+
try:
547+
entry = getattr(module, "setup")
548+
except AttributeError as exc:
549+
del sys.modules[name]
550+
raise NoModuleEntryPoint(f'The module "{module}" has no setup coroutine.') from exc
551+
552+
if not asyncio.iscoroutinefunction(entry):
553+
del sys.modules[name]
554+
raise TypeError(f'The module "{module}"\'s setup function is not a coroutine.')
555+
556+
try:
557+
await entry(self)
558+
except Exception as e:
559+
del sys.modules[name]
560+
await self._remove_module_remnants(module.__name__)
561+
raise ModuleLoadFailure(e) from e
562+
563+
self.__modules[name] = module
564+
565+
async def unload_module(self, name: str, *, package: str | None = None) -> None:
566+
"""|coro|
567+
568+
Unloads a module.
569+
570+
When the module is unloaded, all commands, listeners and components are removed from the bot, and the module is un-imported.
571+
572+
You can add an optional global coroutine of ``teardown`` to the module to do miscellaneous clean-up if necessary.
573+
This also takes a single paramter of the ``bot``, similar to ``setup``.
574+
575+
.. versionchanged:: 3.0
576+
This method is now a :term:`coroutine`.
577+
578+
Parameters
579+
----------
580+
name: str
581+
The module to unload. It must be dot separated like regular Python imports accessing a sub-module.
582+
e.g. ``foo.bar`` if you want to import ``foo/bar.py``.
583+
package: str | None
584+
The package name to resolve relative imports with.
585+
This is required when unloading an extension using a relative path.
586+
e.g. ``.foo.bar``. Defaults to ``None``.
587+
588+
Raises
589+
------
590+
ModuleNotLoaded
591+
The module was not loaded.
592+
"""
593+
594+
name = self._get_module_name(name, package)
595+
module = self.__modules.get(name)
596+
597+
if module is None:
598+
raise ModuleNotLoaded(name)
599+
600+
await self._remove_module_remnants(module.__name__)
601+
await self._module_finalizers(name, module)
602+
603+
async def reload_module(self, name: str, *, package: str | None = None) -> None:
604+
"""|coro|
605+
606+
Atomically reloads a module.
607+
608+
This attempts to unload and then load the module again, in an atomic way.
609+
If an operation fails mid reload then the bot will revert back to the prior working state.
610+
611+
.. versionchanged:: 3.0
612+
This method is now a :term:`coroutine`.
613+
614+
Parameters
615+
----------
616+
name: str
617+
The module to unload. It must be dot separated like regular Python imports accessing a sub-module.
618+
e.g. ``foo.bar`` if you want to import ``foo/bar.py``.
619+
package: str | None
620+
The package name to resolve relative imports with.
621+
This is required when unloading an extension using a relative path.
622+
e.g. ``.foo.bar``. Defaults to ``None``.
623+
624+
Raises
625+
------
626+
ModuleNotLoaded
627+
The module was not loaded.
628+
ModuleNotFoundError
629+
The module could not be imported. Also raised if module could not be resolved using the `package` parameter.
630+
ModuleLoadFailure
631+
There was an error loading the module.
632+
NoModuleEntryPoint
633+
The module does not have a setup coroutine.
634+
TypeError
635+
The module's setup function is not a coroutine.
636+
"""
637+
638+
name = self._get_module_name(name, package)
639+
module = self.__modules.get(name)
640+
641+
if module is None:
642+
raise ModuleNotLoaded(name)
643+
644+
modules = {name: module for name, module in sys.modules.items() if _is_submodule(module.__name__, name)}
645+
646+
try:
647+
await self._remove_module_remnants(module.__name__)
648+
await self._module_finalizers(name, module)
649+
await self.load_module(name)
650+
except Exception as e:
651+
await module.setup(self)
652+
self.__modules[name] = module
653+
sys.modules.update(modules)
654+
raise e
655+
656+
@property
657+
def modules(self) -> Mapping[str, types.ModuleType]:
658+
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
659+
return types.MappingProxyType(self.__modules)

twitchio/ext/commands/exceptions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
"ConversionError",
4343
"BadArgument",
4444
"MissingRequiredArgument",
45+
"ModuleNotFoundError",
46+
"ModuleAlreadyLoadedError",
47+
"ModuleLoadFailure",
48+
"ModuleNotLoaded",
49+
"NoModuleEntryPoint",
4550
)
4651

4752

@@ -120,3 +125,28 @@ class MissingRequiredArgument(ArgumentError):
120125
def __init__(self, param: inspect.Parameter) -> None:
121126
self.param: inspect.Parameter = param
122127
super().__init__(f'"{param.name}" is a required argument which is missing.')
128+
129+
130+
class ModuleNotFoundError(TwitchioException):
131+
def __init__(self, msg: str) -> None:
132+
super().__init__(msg)
133+
134+
135+
class ModuleLoadFailure(TwitchioException):
136+
def __init__(self, exc: Exception) -> None:
137+
super().__init__(exc)
138+
139+
140+
class NoModuleEntryPoint(TwitchioException):
141+
def __init__(self, msg: str) -> None:
142+
super().__init__(msg)
143+
144+
145+
class ModuleAlreadyLoadedError(TwitchioException):
146+
def __init__(self, msg: str) -> None:
147+
super().__init__(msg)
148+
149+
150+
class ModuleNotLoaded(TwitchioException):
151+
def __init__(self, msg: str) -> None:
152+
super().__init__(msg)

twitchio/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"url_encode_datetime",
4040
"MISSING",
4141
"handle_user_ids",
42+
"_is_submodule",
4243
)
4344

4445
T_co = TypeVar("T_co", covariant=True)
@@ -837,3 +838,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
837838
return cast(F, wrapper)
838839

839840
return decorator
841+
842+
843+
def _is_submodule(parent: str, child: str) -> bool:
844+
return parent == child or child.startswith(parent + ".")

0 commit comments

Comments
 (0)