Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #1180: Only load enabled extension packages #1204

Open
wants to merge 2 commits into
base: 1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 46 additions & 37 deletions jupyter_server/extension/manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import importlib

from tornado.gen import multi
from traitlets import Any, Bool, Dict, HasTraits, Instance, Unicode, default, observe
from traitlets import (
Any,
Bool,
Dict,
HasTraits,
Instance,
List,
Unicode,
default,
observe,
)
from traitlets import validate as validate_trait
from traitlets.config import LoggingConfigurable

Expand Down Expand Up @@ -158,52 +168,51 @@ class ExtensionPackage(HasTraits):
"""

name = Unicode(help="Name of the an importable Python package.")
enabled = Bool(False).tag(config=True)
enabled = Bool(False, help="Whether the extension package is enabled.")

_linked_points = Dict()
extension_points = Dict()
module = Any(allow_none=True, help="The module for this extension package. None if not enabled")
metadata = List(Dict(), help="Extension metadata loaded from the extension package.")
version = Unicode(
help="""
The version of this extension package, if it can be found.
Otherwise, an empty string.
""",
)

@default("version")
def _load_version(self):
if not self.enabled:
return ""
return getattr(self.module, "__version__", "")

def __init__(self, *args, **kwargs):
# Store extension points that have been linked.
self._linked_points = {}
super().__init__(*args, **kwargs)
def __init__(self, **kwargs):
"""Initialize an extension package."""
super().__init__(**kwargs)
if self.enabled:
self._load_metadata()

_linked_points: dict = {}
def _load_metadata(self):
"""Import package and load metadata

@validate_trait("name")
def _validate_name(self, proposed):
name = proposed["value"]
self._extension_points = {}
Only used if extension package is enabled
"""
name = self.name
try:
self._module, self._metadata = get_metadata(name)
self.module, self.metadata = get_metadata(name, logger=self.log)
except ImportError as e:
raise ExtensionModuleNotFound(
"The module '{name}' could not be found ({e}). Are you "
"sure the extension is installed?".format(name=name, e=e)
msg = (
f"The module '{name}' could not be found ({e}). Are you "
"sure the extension is installed?"
)
raise ExtensionModuleNotFound(msg) from None
# Create extension point interfaces for each extension path.
for m in self._metadata:
for m in self.metadata:
point = ExtensionPoint(metadata=m)
self._extension_points[point.name] = point
self.extension_points[point.name] = point
return name

@property
def module(self):
"""Extension metadata loaded from the extension package."""
return self._module

@property
def version(self):
"""Get the version of this package, if it's given. Otherwise, return an empty string"""
return getattr(self._module, "__version__", "")

@property
def metadata(self):
"""Extension metadata loaded from the extension package."""
return self._metadata

@property
def extension_points(self):
"""A dictionary of extension points."""
return self._extension_points

def validate(self):
"""Validate all extension points in this package."""
for extension in self.extension_points.values():
Expand Down
27 changes: 25 additions & 2 deletions tests/extension/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import unittest.mock as mock

import pytest
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_extension_package_api():
path1 = metadata_list[0]
app = path1["app"]

e = ExtensionPackage(name="tests.extension.mockextensions")
e = ExtensionPackage(name="tests.extension.mockextensions", enabled=True)
e.extension_points
assert hasattr(e, "extension_points")
assert len(e.extension_points) == len(metadata_list)
Expand All @@ -70,7 +71,9 @@ def test_extension_package_api():

def test_extension_package_notfound_error():
with pytest.raises(ExtensionModuleNotFound):
ExtensionPackage(name="nonexistent")
ExtensionPackage(name="nonexistent", enabled=True)
# no raise if not enabled
ExtensionPackage(name="nonexistent", enabled=False)


def _normalize_path(path_list):
Expand Down Expand Up @@ -132,3 +135,23 @@ def test_extension_manager_fail_load(jp_serverapp):
jp_serverapp.reraise_server_extension_failures = True
with pytest.raises(RuntimeError):
manager.load_extension(name)


@pytest.mark.parametrize("has_app", [True, False])
def test_disable_no_import(jp_serverapp, has_app):
# de-import modules so we can detect if they are re-imported
disabled_ext = "tests.extension.mockextensions.mock1"
enabled_ext = "tests.extension.mockextensions.mock2"
sys.modules.pop(disabled_ext, None)
sys.modules.pop(enabled_ext, None)

manager = ExtensionManager(serverapp=jp_serverapp if has_app else None)
manager.add_extension(disabled_ext, enabled=False)
manager.add_extension(enabled_ext, enabled=True)
assert disabled_ext not in sys.modules
assert enabled_ext in sys.modules

ext_pkg = manager.extensions[disabled_ext]
assert ext_pkg.extension_points == {}
assert ext_pkg.version == ""
assert ext_pkg.metadata == []