From a5e506a6dba5423b5c7c74de04b270de8b66abc0 Mon Sep 17 00:00:00 2001 From: AntObi Date: Fri, 10 Nov 2023 21:29:24 +0000 Subject: [PATCH] Remove unused lines in composition module --- src/elementembeddings/composition.py | 14 -------------- src/elementembeddings/tests/test_composition.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/elementembeddings/composition.py b/src/elementembeddings/composition.py index 3c05c2b..436ce24 100644 --- a/src/elementembeddings/composition.py +++ b/src/elementembeddings/composition.py @@ -145,11 +145,6 @@ def num_atoms(self) -> float: """Total number of atoms in Composition.""" return self._natoms - @property - def embedding_dim(self) -> int: - """Dimension of the embedding.""" - return self.embedding.dim - def as_dict(self) -> dict: # TO-DO: Need to create a dict representation for the embedding class """Return the CompositionalEmbedding class as a dict.""" @@ -158,9 +153,6 @@ def as_dict(self) -> dict: "composition": self.composition, "fractional_composition": self.fractional_composition, } - # Se - - # Set an attribute def _mean_feature_vector(self) -> np.ndarray: """Compute a weighted mean feature vector based of the embedding. @@ -241,12 +233,6 @@ def feature_vector(self, stats: Union[str, list] = "mean"): ] if isinstance(stats, str): stats = [stats] - if not isinstance(stats, list): - msg = "Stats argument must be a list of strings" - raise ValueError(msg) - if not all(isinstance(s, str) for s in stats): - msg = "Stats argument must be a list of strings" - raise ValueError(msg) if not all(s in implemented_stats for s in stats): msg = ( f" {[stat for stat in stats if stat not in implemented_stats]} " diff --git a/src/elementembeddings/tests/test_composition.py b/src/elementembeddings/tests/test_composition.py index e1ef0ed..9e5dc7d 100644 --- a/src/elementembeddings/tests/test_composition.py +++ b/src/elementembeddings/tests/test_composition.py @@ -182,3 +182,13 @@ def test_composition_distance(self): ) == 0 ) + + assert self.valid_magpie_compositions[0].distance( + self.formulas[1], + stats=["mean", "variance"], + distance_metric="cosine_distance", + ) == self.valid_magpie_compositions[1].distance( + self.valid_magpie_compositions[0], + stats=["mean", "variance"], + distance_metric="cosine_distance", + )