Skip to content

Commit

Permalink
Add caching support for multiple backends
Browse files Browse the repository at this point in the history
remove __all__ assignment from graph.py
  • Loading branch information
Jnelen committed Mar 4, 2024
1 parent 4697eb1 commit 5b44a54
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
11 changes: 0 additions & 11 deletions spyrmsd/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,6 @@ def set_backend(backend):
_current_backend = backend


__all__ = [
"graph_from_adjacency_matrix",
"match_graphs",
"vertex_property",
"num_vertices",
"num_edges",
"lattice",
"cycle",
"adjacency_matrix_from_atomic_coordinates",
]

if len(_available_backends) == 0:
raise ImportError(
"No valid backends found. Please ensure that either graph-tool or NetworkX are installed."
Expand Down
18 changes: 10 additions & 8 deletions spyrmsd/molecule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
self.adjacency_matrix: np.ndarray = np.asarray(adjacency_matrix, dtype=int)

# Molecular graph
self.G = None
self.G: Dict[str, object] = {}

self.masses: Optional[List[float]] = None

Expand Down Expand Up @@ -182,7 +182,7 @@ def strip(self) -> None:
self.adjacency_matrix = self.adjacency_matrix[np.ix_(idx, idx)]

# Reset molecular graph when stripping
self.G = None
self.G = {}

self.stripped = True

Expand All @@ -200,11 +200,13 @@ def to_graph(self):
If the molecule does not have an associated adjacency matrix, a simple
bond perception is used.
The molecular graph is cached.
The molecular graph is cached per backend.
"""
if self.G is None:
_current_backend = graph._current_backend

if _current_backend not in self.G.keys():
try:
self.G = graph.graph_from_adjacency_matrix(
self.G[_current_backend] = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)
except AttributeError:
Expand All @@ -218,11 +220,11 @@ def to_graph(self):
self.atomicnums, self.coordinates
)

self.G = graph.graph_from_adjacency_matrix(
self.G[_current_backend] = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)

return self.G
return self.G[_current_backend]

def __len__(self) -> int:
"""
Expand Down

0 comments on commit 5b44a54

Please sign in to comment.