Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions src/anemoi/utils/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import sys
import sysconfig
from functools import cache
from importlib.metadata import PackageNotFoundError
from importlib.metadata import distribution
from typing import Any

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,8 +119,52 @@ def _package_version(name: str) -> str | None:
return None


def _module_versions() -> tuple[dict[str, Any], set]:
def _get_package_source_url(package_name: str) -> dict[str, Any] | None:
"""Extract the source URL from package metadata if installed from git or other VCS. This reads PEP 610 direct_url.json files created by pip when installing from git URLs.

Parameters
----------
package_name : str
The name of the package to check.

Returns
-------
dict or None
Dictionary with 'url' and optionally 'vcs_info' (commit hash, branch, etc.) if available.
"""
try:
dist = distribution(package_name)
except PackageNotFoundError as e:
LOG.debug(f"Could not get source URL for {package_name}: {e}")
return None

try:
direct_url_text = dist.read_text("direct_url.json")
except FileNotFoundError as e:
LOG.debug(f"No direct_url.json found for {package_name}: {e}")
return None

try:
direct_url = json.loads(str(direct_url_text))
result = {"url": direct_url.get("url")}
except json.JSONDecodeError as e:
LOG.debug(f"Invalid direct_url.json for {package_name}: {e}")
return None

# Add VCS info if available (commit hash, requested revision, etc.)
if "vcs_info" in direct_url:
result["vcs_info"] = direct_url["vcs_info"]

# Add subdirectory info if present (e.g., for monorepos)
if "subdirectory" in direct_url:
result["subdirectory"] = direct_url["subdirectory"]

return result


def _module_versions() -> tuple[dict[str, Any], list[tuple[str, str]]]:
"""Collect version information for all loaded modules.
Include source URL information from PEP 610 direct_url.json files.

Returns
-------
Expand All @@ -140,11 +186,16 @@ def _module_versions() -> tuple[dict[str, Any], set]:
if version is None:
continue

versions[name] = version
# Store dict with source info
source_url = _get_package_source_url(name)
versions[name] = {"version": version}
if source_url: # Package contains source info
versions[name]["source"] = source_url

if hasattr(module, "__file__") and module.__file__ is not None:
paths.add((name, os.path.realpath(module.__file__)))

return versions, paths
return versions, list(paths)


@cache
Expand Down Expand Up @@ -248,7 +299,7 @@ def _paths(path_or_object: None | str | list[str] | tuple[str] | Any) -> list[tu
The list of paths.
"""
if path_or_object is None:
_, paths = _module_versions(full=False)
_, paths = _module_versions()
return paths

if isinstance(path_or_object, (list, tuple, set)):
Expand Down Expand Up @@ -463,7 +514,7 @@ def gather_provenance_info(assets: list[str] = [], full: bool = False) -> dict[s
time=datetime.datetime.utcnow().isoformat(),
python=f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
module_versions=versions,
distribution_names=import_name_to_distribution_name(versions.keys()),
distribution_names=import_name_to_distribution_name(list(versions.keys())),
git_versions=git_versions,
)
else:
Expand All @@ -474,7 +525,7 @@ def gather_provenance_info(assets: list[str] = [], full: bool = False) -> dict[s
python_path=sys.path,
config_paths=sysconfig.get_paths(),
module_versions=versions,
distribution_names=import_name_to_distribution_name(versions.keys()),
distribution_names=import_name_to_distribution_name(list(versions.keys())),
git_versions=git_versions,
platform=platform_info(),
gpus=gpu_info(),
Expand Down
68 changes: 68 additions & 0 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,77 @@
# nor does it submit to any jurisdiction.


import json
import sys
import tempfile
from pathlib import Path

from anemoi.utils import provenance


def test_gather() -> None:
"""Test success of gather_provenance_info."""
provenance.gather_provenance_info()


def test_get_package_source_url_with_git_integration():
"""Test _get_package_source_url on a synthetic .dist-info structure.

Create a temporary package metadata directory that Python's importlib.metadata
can read
"""
with tempfile.TemporaryDirectory() as tmpdir:
# This is what pip creates when installing a package
dist_info = Path(tmpdir) / "test_git_package-1.0.0.dist-info"
dist_info.mkdir()

# Create METADATA file (required for importlib.metadata to recognize it)
metadata_content = """Metadata-Version: 2.1
Name: test-git-package
Version: 1.0.0
"""
(dist_info / "METADATA").write_text(metadata_content)

# Create direct_url.json - this is what pip creates for git installs
direct_url_content = {
"url": "git+https://github.com/ways/[email protected]",
"vcs_info": {
"commit_id": "abc123def456",
"requested_revision": "models-0.12.0",
"vcs": "git",
},
"subdirectory": "models",
}
(dist_info / "direct_url.json").write_text(json.dumps(direct_url_content))

# Add tmpdir to sys.path so importlib.metadata can discover it
sys.path.insert(0, tmpdir)

try:
result = provenance._get_package_source_url("test-git-package")

assert result is not None, "Should return source info for git package"
assert result["url"] == "git+https://github.com/ways/[email protected]"
assert result["vcs_info"]["commit_id"] == "abc123def456"
assert result["vcs_info"]["requested_revision"] == "models-0.12.0"
assert result["subdirectory"] == "models"

finally:
# Clean up
sys.path.remove(tmpdir)


def test_get_package_source_url_regular_package_integration():
"""Test with a real regular package installed from PyPI."""
# pip is always installed and almost always from PyPI, not git
result = provenance._get_package_source_url("pip")

# For a regular PyPI package, should return None (no direct_url.json)
# In rare dev environments pip might be from git, but that's also valid
assert result is None or isinstance(result, dict)


def test_get_package_source_url_nonexistent():
"""Test with a package that doesn't exist."""
result = provenance._get_package_source_url("nonexistent-package-xyz-12345")
assert result is None