Skip to content

Commit

Permalink
Check loaded module names before importing a new one from a file (#119)
Browse files Browse the repository at this point in the history
Fixes a bug that prevents in-file benchmark execution due to a type
mismatch when calling runner.run() with __file__ as argument (or just
the name).

The reason was that on encountering a file name, the Python module that
the file pointed to was imported under its absolute path name regardless
of whether it was already present in sys.modules (in this case, it was,
because it is also __main__).

Adds a check in `import_file_as_module` whether a module already exists
in sys.modules for the given file name, in which case that module is
returned.
  • Loading branch information
nicholasjng authored Mar 19, 2024
1 parent bfaec65 commit 8954ba1
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/nnbench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,27 @@ def ismodule(name: str | os.PathLike[str]) -> bool:

def modulename(file: str | os.PathLike[str]) -> str:
"""Convert a file name to its corresponding Python module name."""
fpath = Path(file)
fpath = Path(file).with_suffix("")
if len(fpath.parts) == 1:
return str(fpath)

filename = fpath.with_suffix("").as_posix()
filename = fpath.as_posix()
return filename.replace("/", ".")


def import_file_as_module(file: str | os.PathLike[str]) -> ModuleType:
fpath = Path(file)
fpath = Path(file).resolve() # Python module __file__ paths are absolute.
if not fpath.is_file() or fpath.suffix != ".py":
raise ValueError(f"path {str(file)!r} is not a Python file")

# TODO: For absolute paths, the resulting module name will be horrifying
# -> find a sensible cutoff point (project root)
# TODO: Recomputing this map in a loop can be expensive if many modules are loaded.
modmap = {m.__file__: m for m in sys.modules.values() if getattr(m, "__file__", None)}
spath = str(fpath)
if spath in modmap:
# if the module under "file" has already been loaded, return it,
# otherwise we get nasty type errors in collection.
return modmap[spath]

modname = modulename(fpath)
if modname in sys.modules:
# return already loaded module
Expand Down

0 comments on commit 8954ba1

Please sign in to comment.