From 2c62c16911e7496a63065b12f9d20c86eeb28d57 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 12 Nov 2025 16:51:51 -0500 Subject: [PATCH 1/8] wip --- src/pymatgen/core/sites.py | 10 +-- src/pymatgen/entries/__init__.py | 14 ++-- src/pymatgen/entries/computed_entries.py | 15 ++--- src/pymatgen/io/optimade.py | 3 +- tests/analysis/test_phase_diagram.py | 86 +++++++++++++++++++++--- 5 files changed, 99 insertions(+), 29 deletions(-) diff --git a/src/pymatgen/core/sites.py b/src/pymatgen/core/sites.py index 6144d612c8e..4127ff9d6cf 100644 --- a/src/pymatgen/core/sites.py +++ b/src/pymatgen/core/sites.py @@ -269,10 +269,11 @@ def as_dict(self) -> dict: @classmethod def from_dict(cls, dct: dict) -> Self: """Create Site from dict representation.""" - atoms_n_occu = {} + atoms_n_occu: dict[SpeciesLike, float] = {} for sp_occu in dct["species"]: + sp: SpeciesLike if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]): - sp: Species | DummySpecies | Element = Species.from_dict(sp_occu) + sp = Species.from_dict(sp_occu) elif "oxidation_state" in sp_occu: sp = DummySpecies.from_dict(sp_occu) else: @@ -636,10 +637,11 @@ def from_dict(cls, dct: dict, lattice: Lattice | None = None) -> Self: Returns: PeriodicSite """ - species = {} + species: dict[SpeciesLike, float] = {} for sp_occu in dct["species"]: + sp: SpeciesLike if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]): - sp: Species | DummySpecies | Element = Species.from_dict(sp_occu) + sp = Species.from_dict(sp_occu) elif "oxidation_state" in sp_occu: sp = DummySpecies.from_dict(sp_occu) else: diff --git a/src/pymatgen/entries/__init__.py b/src/pymatgen/entries/__init__.py index af5da36fff5..d07b6879202 100644 --- a/src/pymatgen/entries/__init__.py +++ b/src/pymatgen/entries/__init__.py @@ -19,6 +19,7 @@ from typing import Literal from pymatgen.core import DummySpecies, Element, Species + from pymatgen.util.typing import CompositionLike __author__ = "Shyue Ping Ong, 
Anubhav Jain, Ayush Gupta" @@ -37,20 +38,19 @@ class Entry(MSONable, ABC): which inherit from Entry must define a .energy property. """ - def __init__(self, composition: Composition | str | dict[str, float], energy: float) -> None: + def __init__(self, composition: CompositionLike, energy: float) -> None: """Initialize an Entry. Args: - composition (Composition): Composition of the entry. For + composition (CompositionLike): Composition of the entry. For flexibility, this can take the form of all the typical input taken by a Composition, including a {symbol: amt} dict, a string formula, and others. energy (float): Energy of the entry. """ - if isinstance(composition, Composition): - self._composition = composition - else: - self._composition = Composition(composition) - # self._composition = Composition(composition) + if not isinstance(composition, Composition): + composition = Composition(composition) + + self._composition = composition self._energy = energy @property diff --git a/src/pymatgen/entries/computed_entries.py b/src/pymatgen/entries/computed_entries.py index 19f68048d7a..aebbebc64c4 100644 --- a/src/pymatgen/entries/computed_entries.py +++ b/src/pymatgen/entries/computed_entries.py @@ -32,6 +32,7 @@ from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Structure + from pymatgen.util.typing import CompositionLike __author__ = "Ryan Kingsbury, Matt McDermott, Shyue Ping Ong, Anubhav Jain" __copyright__ = "Copyright 2011-2020, The Materials Project" @@ -292,7 +293,7 @@ class ComputedEntry(Entry): def __init__( self, - composition: Composition | str | dict[str, float], + composition: CompositionLike, energy: float, correction: float = 0.0, energy_adjustments: list | None = None, @@ -558,7 +559,7 @@ def __init__( structure: Structure, energy: float, correction: float = 0.0, - composition: Composition | str | dict[str, float] | None = None, + composition: CompositionLike | None = None, energy_adjustments: list | None = None, 
parameters: dict | None = None, data: dict | None = None, @@ -585,12 +586,10 @@ def __init__( with the entry. Defaults to None. entry_id: An optional id to uniquely identify the entry. """ - if composition: - if isinstance(composition, Composition): - pass - else: + if composition is not None: + if not isinstance(composition, Composition): composition = Composition(composition) - # composition = Composition(composition) + if ( composition.get_integer_formula_and_factor()[0] != structure.composition.get_integer_formula_and_factor()[0] @@ -706,7 +705,7 @@ def __init__( formation_enthalpy_per_atom: float, temp: float = 300, gibbs_model: Literal["SISSO"] = "SISSO", - composition: Composition | None = None, + composition: CompositionLike | None = None, correction: float = 0.0, energy_adjustments: list | None = None, parameters: dict | None = None, diff --git a/src/pymatgen/io/optimade.py b/src/pymatgen/io/optimade.py index 14490f34a54..2bf9c7fec40 100644 --- a/src/pymatgen/io/optimade.py +++ b/src/pymatgen/io/optimade.py @@ -26,6 +26,7 @@ from typing import Any from pymatgen.core.structure import IStructure + from pymatgen.util.typing import CompositionLike __author__ = "Matthew Evans" @@ -33,7 +34,7 @@ def _pymatgen_species( nsites: int, species_at_sites: list[str], -) -> list[dict[str, float]]: +) -> list[CompositionLike]: """Create list of {"symbol": "concentration"} per site for constructing pymatgen Species objects. Removes vacancies, if they are present. 
diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index adeba0c77de..50bf268f47b 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -9,6 +9,7 @@ import numpy as np import plotly.graph_objects as go import pytest +from monty.json import MontyDecoder from monty.serialization import dumpfn, loadfn from numpy.testing import assert_allclose from pytest import approx @@ -636,12 +637,21 @@ def test_as_from_dict(self): pd_dict = pd.as_dict() pd_roundtrip = PhaseDiagram.from_dict(pd_dict) assert pd.all_entries[0].entry_id == pd_roundtrip.all_entries[0].entry_id - dd = self.pd.as_dict() - new_pd = PhaseDiagram.from_dict(dd) - new_pd_dict = new_pd.as_dict() - assert new_pd_dict == dd + + pd_dict = self.pd.as_dict() + reconstructed_pd = PhaseDiagram.from_dict(pd_dict) + reconstructed_pd_dict = reconstructed_pd.as_dict() + assert reconstructed_pd_dict == pd_dict assert isinstance(pd.to_json(), str) + assert MontyDecoder().process_decoded(pd_dict).as_dict() == pd_dict + + for entry in self.pd.all_entries: + _decomp_rpd, e_above_hull_rppd = reconstructed_pd.get_decomp_and_e_above_hull(entry) + _decomp_pd, e_above_hull_ppd = self.pd.get_decomp_and_e_above_hull(entry) + # assert decomp_rpd == decomp_pd, f"Decomposition for {entry} is not correct!" 
+ assert np.isclose(e_above_hull_rppd, e_above_hull_ppd) + def test_read_json(self): dumpfn(self.pd, f"{self.tmp_path}/pd.json") pd = loadfn(f"{self.tmp_path}/pd.json") @@ -809,9 +819,23 @@ def test_get_hull_energy(self): assert np.isclose(e_hull_pd, e_hull_ppd) def test_get_decomp_and_e_above_hull(self): - for entry in self.pd.stable_entries: + for entry in self.pd.all_entries: decomp_pd, e_above_hull_pd = self.pd.get_decomp_and_e_above_hull(entry) decomp_ppd, e_above_hull_ppd = self.ppd.get_decomp_and_e_above_hull(entry, check_stable=True) + + # Check that decompositions sum to the original composition + for decomp, name in [(decomp_pd, "pd"), (decomp_ppd, "ppd")]: + decomp_comp = Composition({}) + for e, amount in decomp.items(): + comp_scaled = e.composition.fractional_composition * amount + decomp_comp += comp_scaled + assert decomp_comp.almost_equals( + entry.composition.fractional_composition, rtol=0, atol=Composition.amount_tolerance + ), ( + f"Decomposition for {entry} from {name} does not sum to original composition! 
" + f"Expected {entry.composition}, got {decomp_comp}" + ) + assert decomp_pd == decomp_ppd assert np.isclose(e_above_hull_pd, e_above_hull_ppd) @@ -822,10 +846,54 @@ def test_as_from_dict(self): ppd_dict = self.ppd.as_dict() assert ppd_dict["@module"] == type(self.ppd).__module__ assert ppd_dict["@class"] == type(self.ppd).__name__ - assert ppd_dict["all_entries"] == [entry.as_dict() for entry in self.ppd.all_entries] - assert ppd_dict["elements"] == [elem.as_dict() for elem in self.ppd.elements] - # test round-trip dict serialization - assert PatchedPhaseDiagram.from_dict(ppd_dict).as_dict() == ppd_dict + + # Check new format with computed_data and deduplicated entries + assert "computed_data" in ppd_dict + computed_data = ppd_dict["computed_data"] + assert "unique_entries" in computed_data + assert "all_entries" in computed_data + assert isinstance(computed_data["all_entries"], list) + assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) + + # Verify entries can be reconstructed from indices + unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] + reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] + assert len(reconstructed_entries) == len(self.ppd.all_entries) + assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] + + reconstructed_ppd = PatchedPhaseDiagram.from_dict(ppd_dict) + reconstructed_dict = reconstructed_ppd.as_dict() + assert ppd_dict == reconstructed_dict + + assert MontyDecoder().process_decoded(ppd_dict).as_dict() == ppd_dict + + for entry in self.ppd.all_entries: + decomp_pd, e_above_hull_pd = self.pd.get_decomp_and_e_above_hull(entry) + for pd, name in [ + (self.ppd, "ppd"), + (reconstructed_ppd, "rppd"), + ]: + decomp, e_above_hull = pd.get_decomp_and_e_above_hull(entry, check_stable=True) + decomp_comp = Composition({}) + for e, amount in decomp.items(): + comp_scaled = e.composition.fractional_composition * amount + 
decomp_comp += comp_scaled + assert decomp_comp.almost_equals( + entry.composition.fractional_composition, rtol=0, atol=Composition.amount_tolerance + ), ( + f"Decomposition for {entry} from {name} does not sum to original composition! " + f"Expected {entry.composition}, got {decomp_comp}" + ) + + assert decomp == approx(decomp_pd), f"Decomposition for {entry} is not correct!" + assert np.isclose(e_above_hull, e_above_hull_pd) + + def test_read_json(self, tmp_path): + dumpfn(self.ppd, f"{tmp_path}/ppd.json") + ppd = loadfn(f"{tmp_path}/ppd.json") + assert isinstance(ppd, PatchedPhaseDiagram) + assert ppd.elements == self.ppd.elements + assert {*ppd.as_dict()} == {*self.ppd.as_dict()} def test_get_pd_for_entry(self): for entry in self.ppd.all_entries: From e08599a8e520abedff3029c9681c8be65bf16417 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 12 Nov 2025 16:53:57 -0500 Subject: [PATCH 2/8] wip --- tests/analysis/test_phase_diagram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 50bf268f47b..82f0919a6c2 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -836,7 +836,7 @@ def test_get_decomp_and_e_above_hull(self): f"Expected {entry.composition}, got {decomp_comp}" ) - assert decomp_pd == decomp_ppd + assert decomp_pd == approx(decomp_ppd) assert np.isclose(e_above_hull_pd, e_above_hull_ppd) def test_repr(self): From a72390d68295c50ad98599ad226900b8c0c7f515 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 12 Nov 2025 17:19:29 -0500 Subject: [PATCH 3/8] wip --- tests/analysis/test_phase_diagram.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 82f0919a6c2..2359c1162cc 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -847,19 +847,19 @@ 
def test_as_from_dict(self): assert ppd_dict["@module"] == type(self.ppd).__module__ assert ppd_dict["@class"] == type(self.ppd).__name__ - # Check new format with computed_data and deduplicated entries - assert "computed_data" in ppd_dict - computed_data = ppd_dict["computed_data"] - assert "unique_entries" in computed_data - assert "all_entries" in computed_data - assert isinstance(computed_data["all_entries"], list) - assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) - - # Verify entries can be reconstructed from indices - unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] - reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] - assert len(reconstructed_entries) == len(self.ppd.all_entries) - assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] + # # Check new format with computed_data and deduplicated entries + # assert "computed_data" in ppd_dict + # computed_data = ppd_dict["computed_data"] + # assert "unique_entries" in computed_data + # assert "all_entries" in computed_data + # assert isinstance(computed_data["all_entries"], list) + # assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) + + # # Verify entries can be reconstructed from indices + # unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] + # reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] + # assert len(reconstructed_entries) == len(self.ppd.all_entries) + # assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] reconstructed_ppd = PatchedPhaseDiagram.from_dict(ppd_dict) reconstructed_dict = reconstructed_ppd.as_dict() From d77a45d42e7dbf419ea60579580bedbb51e1d915 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 12 Nov 2025 18:47:01 -0500 Subject: [PATCH 4/8] working? 
--- src/pymatgen/analysis/phase_diagram.py | 578 ++++++++++++++++++++----- tests/analysis/test_phase_diagram.py | 53 ++- 2 files changed, 514 insertions(+), 117 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 2610d96049b..13c4f1882fe 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -10,7 +10,7 @@ import warnings from collections import defaultdict from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt import numpy as np @@ -43,6 +43,8 @@ from numpy.typing import ArrayLike from typing_extensions import Self + from pymatgen.util.typing import CompositionLike + logger = logging.getLogger(__name__) with open( @@ -69,14 +71,14 @@ class PDEntry(Entry): def __init__( self, - composition: Composition, + composition: CompositionLike, energy: float, name: str | None = None, attribute: object = None, ): """ Args: - composition (Composition): Composition + composition (CompositionLike): Composition energy (float): Energy for composition. name (str): Optional parameter to name the entry. Defaults to the reduced chemical formula. 
@@ -341,8 +343,8 @@ class PhaseDiagram(MSONable): def __init__( self, - entries: Sequence[PDEntry] | set[PDEntry], - elements: Sequence[Element] = (), + entries: Collection[Entry], + elements: Collection[Element] = (), *, computed_data: dict[str, Any] | None = None, ) -> None: @@ -375,6 +377,7 @@ def __init__( # Update keys to be Element objects in case they are strings in pre-computed data computed_data["el_refs"] = [(Element(el_str), entry) for el_str, entry in computed_data["el_refs"]] + self.computed_data = computed_data self.facets = computed_data["facets"] self.simplexes = computed_data["simplexes"] @@ -392,14 +395,27 @@ def as_dict(self): qhull_entry_indices = [self.all_entries.index(e) for e in self.qhull_entries] + # Create a copy of computed_data to avoid modifying the original + computed_data = self.computed_data.copy() + computed_data["elements"] = [el.symbol for el in self.elements] + computed_data["el_refs"] = [(el.symbol, entry.as_dict()) for el, entry in computed_data["el_refs"]] + computed_data["all_entries"] = [e.as_dict() for e in computed_data["all_entries"]] + computed_data["qhull_entries"] = qhull_entry_indices + computed_data["qhull_data"] = ( + computed_data["qhull_data"].tolist() + if isinstance(computed_data["qhull_data"], np.ndarray) + else computed_data["qhull_data"] + ) + computed_data["facets"] = [list(facet) for facet in computed_data["facets"]] + computed_data["simplexes"] = [ + {**s.as_dict(), "coords": s.as_dict()["coords"].tolist()} for s in computed_data["simplexes"] + ] + return { "@module": type(self).__module__, "@class": type(self).__name__, - "elements": [e.as_dict() for e in self.elements], - "computed_data": self.computed_data - | { - "qhull_entries": qhull_entry_indices, - }, + "elements": [el.symbol for el in self.elements], + "computed_data": computed_data, } @classmethod @@ -412,7 +428,7 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: PhaseDiagram """ computed_data = dct.get("computed_data") - elements = 
[Element.from_dict(elem) for elem in dct["elements"]] + elements = [Element(elem) for elem in dct["elements"]] # for backwards compatibility, check for old format if "all_entries" in dct: @@ -420,9 +436,14 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: else: entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["all_entries"]] - complete_qhull_entries = [computed_data["all_entries"][i] for i in computed_data["qhull_entries"]] - - computed_data = computed_data | {"qhull_entries": complete_qhull_entries} + # Reconstruct computed_data to match _compute() format: (str, Entry) tuples for el_refs + computed_data = computed_data.copy() + computed_data["qhull_entries"] = [entries[i] for i in computed_data["qhull_entries"]] + computed_data["elements"] = [Element(el) for el in computed_data["elements"]] + # Keep el_refs as (str, Entry) format to match _compute() output + computed_data["el_refs"] = [ + (el_str, MontyDecoder().process_decoded(entry)) for el_str, entry in computed_data["el_refs"] + ] return cls(entries, elements, computed_data=computed_data) @@ -435,9 +456,9 @@ def _compute(self) -> dict[str, Any]: entries = sorted(self.entries, key=lambda e: e.composition.reduced_composition) - el_refs: dict[Element, PDEntry] = {} - min_entries: list[PDEntry] = [] - all_entries: list[PDEntry] = [] + el_refs: dict[Element, Entry] = {} + min_entries: list[Entry] = [] + all_entries: list[Entry] = [] for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition): group = list(group_iter) min_entry = min(group, key=lambda e: e.energy_per_atom) @@ -887,6 +908,9 @@ def get_decomp_and_phase_separation_energy( # Handle elemental materials if entry.is_element: + # If stable_only=True, use check_stable=True for fast path + if stable_only: + kwargs.setdefault("check_stable", True) return self.get_decomp_and_e_above_hull(entry, allow_negative=True, **kwargs) # Select space to compare against @@ -980,7 +1004,7 @@ def 
get_phase_separation_energy(self, entry, **kwargs): """ return self.get_decomp_and_phase_separation_energy(entry, **kwargs)[1] - def get_composition_chempots(self, comp): + def get_composition_chempots(self, comp: Composition) -> dict[Element, float]: """Get the chemical potentials for all elements at a given composition. Args: @@ -992,7 +1016,7 @@ def get_composition_chempots(self, comp): facet = self._get_facet_and_simplex(comp)[0] return self._get_facet_chempots(facet) - def get_all_chempots(self, comp): + def get_all_chempots(self, comp: Composition) -> dict[str, dict[Element, float]]: """Get chemical potentials at a given composition. Args: @@ -1010,7 +1034,7 @@ def get_all_chempots(self, comp): return chempots - def get_transition_chempots(self, element): + def get_transition_chempots(self, element: Element) -> tuple[float, ...]: """Get the critical chemical potentials for an element in the Phase Diagram. @@ -1029,7 +1053,7 @@ def get_transition_chempots(self, element): chempots = self._get_facet_chempots(facet) critical_chempots.append(chempots[element]) - clean_pots = [] + clean_pots: list[float] = [] for c in sorted(critical_chempots): if len(clean_pots) == 0 or not math.isclose( c, clean_pots[-1], abs_tol=PhaseDiagram.numerical_tol, rel_tol=0 @@ -1038,7 +1062,7 @@ def get_transition_chempots(self, element): clean_pots.reverse() return tuple(clean_pots) - def get_critical_compositions(self, comp1, comp2): + def get_critical_compositions(self, comp1: Composition, comp2: Composition) -> list[Composition]: """Get the critical compositions along the tieline between two compositions. I.e. where the decomposition products change. The endpoints are also returned. 
@@ -1098,7 +1122,7 @@ def get_critical_compositions(self, comp1, comp2): return [Composition((elem, val) for elem, val in zip(pd_els, m, strict=True)) for m in cs] - def get_element_profile(self, element, comp, comp_tol=1e-5): + def get_element_profile(self, element: Element, comp: Composition, comp_tol: float = 1e-5) -> list[dict[str, Any]]: """ Provides the element evolution data for a composition. For example, can be used to analyze Li conversion voltages by varying mu_Li and looking at the phases @@ -1199,7 +1223,9 @@ def get_chempot_range_map( return chempot_ranges - def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): + def getmu_vertices_stability_phase( + self, target_comp: Composition, dep_elt: Element, tol_en: float = 1e-2 + ) -> list[dict[Element, float]] | None: """Get a set of chemical potentials corresponding to the vertices of the simplex in the chemical potential phase diagram. The simplex is built using all elements in the target_composition @@ -1233,11 +1259,11 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): if elem.composition.reduced_composition == target_comp.reduced_composition: multiplier = elem.composition[dep_elt] / target_comp[dep_elt] ef = elem.energy / multiplier - all_coords = [] + all_coords: list[dict[Element, float]] = [] for simplex in chempots: for v in simplex._coords: elements = [elem for elem in self.elements if elem != dep_elt] - res = {} + res: dict[Element, float] = {} for idx, el in enumerate(elements): res[el] = v[idx] + mu_ref[idx] res[dep_elt] = (np.dot(v + mu_ref, coeff) + ef) / target_comp[dep_elt] @@ -1257,7 +1283,9 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): return all_coords return None - def get_chempot_range_stability_phase(self, target_comp, open_elt): + def get_chempot_range_stability_phase( + self, target_comp: Composition, open_elt: Element + ) -> dict[Element, tuple[float, float]]: """Get a set of chemical potentials 
corresponding to the max and min chemical potential of the open element for a given composition. It is quite common to have for instance a ternary oxide (e.g., ABO3) for @@ -1408,18 +1436,27 @@ class GrandPotentialPhaseDiagram(PhaseDiagram): doi:10.1016/j.elecom.2010.01.010 """ - def __init__(self, entries, chempots, elements=None, *, computed_data=None): + def __init__( + self, + entries: Collection[Entry], + chempots: dict[Element, float], + elements: Collection[Element] | None = None, + *, + computed_data: dict[str, Any] | None = None, + ): """Standard constructor for grand potential phase diagram. + TODO: update serialization here. + Args: - entries ([PDEntry]): A list of PDEntry-like objects having an + entries (Sequence[Entry]): A list of Entry objects having an energy, energy_per_atom and composition. - chempots ({Element: float}): Specify the chemical potentials + chempots (dict[Element, float]): Specify the chemical potentials of the open elements. - elements ([Element]): Optional list of elements in the phase + elements (Sequence[Element]): Optional list of elements in the phase diagram. If set to None, the elements are determined from the entries themselves. - computed_data (dict): A dict containing pre-computed data. This allows + computed_data (dict[str, Any]): A dict containing pre-computed data. This allows PhaseDiagram object to be reconstituted without performing the expensive convex hull computation. The dict is the output from the PhaseDiagram._compute() method and is stored in PhaseDiagram.computed_data @@ -1481,7 +1518,12 @@ class CompoundPhaseDiagram(PhaseDiagram): # Tolerance for determining if amount of a composition is positive. amount_tol = 1e-5 - def __init__(self, entries, terminal_compositions, normalize_terminal_compositions=True): + def __init__( + self, + entries: Sequence[Entry], + terminal_compositions: Sequence[Composition], + normalize_terminal_compositions: bool = True, + ): """Initialize a CompoundPhaseDiagram. 
Args: @@ -1532,7 +1574,9 @@ def num2str(num): return ret - def transform_entries(self, entries, terminal_compositions): + def transform_entries( + self, entries: Sequence[Entry], terminal_compositions: Sequence[Composition] + ) -> tuple[list[TransformedPDEntry], dict[Composition, DummySpecies]]: """ Method to transform all entries to the composition coordinate in the terminal compositions. If the entry does not fall within the space @@ -1540,6 +1584,8 @@ def transform_entries(self, entries, terminal_compositions): Li3PO4 is mapped into a Li2O:1.5, P2O5:0.5 composition. The terminal compositions are represented by DummySpecies. + TODO: update serialization here. + Args: entries: Sequence of all input entries terminal_compositions: Terminal compositions of phase space. @@ -1622,36 +1668,35 @@ class PatchedPhaseDiagram(PhaseDiagram): elements (list[Element]): List of elements in the phase diagram. """ - def __init__( + def _compute( self, entries: Sequence[Entry] | set[Entry], elements: Sequence[Element] | None = None, keep_all_spaces: bool = False, verbose: bool = False, - ) -> None: + ) -> dict[str, Any]: """ + Compute the phase diagram data for PatchedPhaseDiagram. + Args: - entries (list[PDEntry]): A list of PDEntry-like objects having an - energy, energy_per_atom and composition. - elements (list[Element], optional): Optional list of elements in the phase - diagram. If set to None, the elements are determined from - the entries themselves and are sorted alphabetically. - If specified, element ordering (e.g. for pd coordinates) - is preserved. - keep_all_spaces (bool): Pass True to keep chemical spaces that are subspaces - of other spaces. - verbose (bool): Whether to show progress bar during convex hull construction. + entries: A list of Entry objects. + elements: Optional list of elements in the phase diagram. + keep_all_spaces: Whether to keep chemical spaces that are subspaces of other spaces. 
+ verbose: Whether to show progress bar during convex hull construction. + + Returns: + dict containing computed_data with proper indexing for serialization. """ if elements is None: elements = sorted({els for entry in entries for els in entry.elements}) - self.dim = len(elements) + dim = len(elements) entries = sorted(entries, key=lambda e: e.composition.reduced_composition) - el_refs: dict[Element, PDEntry] = {} - min_entries = [] - all_entries: list[PDEntry] = [] + el_refs: dict[Element, Entry] = {} + min_entries: list[Entry] = [] + all_entries: list[Entry] = [] for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition): group = list(group_iter) min_entry = min(group, key=lambda e: e.energy_per_atom) @@ -1660,10 +1705,10 @@ def __init__( min_entries.append(min_entry) all_entries.extend(group) - if len(el_refs) < self.dim: + if len(el_refs) < dim: missing = set(elements) - set(el_refs) raise ValueError(f"Missing terminal entries for elements {sorted(map(str, missing))}") - if len(el_refs) > self.dim: + if len(el_refs) > dim: extra = set(el_refs) - set(elements) raise ValueError(f"There are more terminal elements than dimensions: {extra}") @@ -1694,16 +1739,186 @@ def __init__( spaces = {s for s in qhull_spaces if len(s) > 1} # Remove redundant chemical spaces - spaces = self.remove_redundant_spaces(spaces, keep_all_spaces) + spaces = PatchedPhaseDiagram.remove_redundant_spaces(spaces, keep_all_spaces) + + spaces_list = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last + + # Build PhaseDiagrams for each space and collect their computed_data + pds_computed_data = {} + for space in tqdm(spaces_list, disable=not verbose): + space_entries = [e for e, s in zip(qhull_entries, qhull_spaces, strict=True) if space.issuperset(s)] + pd = PhaseDiagram(space_entries) + + # Get indices into all_entries for this subspace's all_entries + # pd.all_entries are the entries that PhaseDiagram 
computed with + # They should be in all_entries, so we can use index() which uses object identity + # If that fails, entries may be equal but different objects, so fall back to equality + # IMPORTANT: We must preserve order to keep facets valid + subspace_all_entry_indices = [] + used_indices = set() # Track used indices to avoid duplicates when using equality fallback + for pd_entry in pd.all_entries: + try: + idx = all_entries.index(pd_entry) + subspace_all_entry_indices.append(idx) + used_indices.add(idx) + except ValueError: + # Entry not found by identity, try equality + # But skip indices we've already used to preserve order + for idx, global_entry in enumerate(all_entries): + if idx not in used_indices and pd_entry == global_entry: + subspace_all_entry_indices.append(idx) + used_indices.add(idx) + break + else: + raise ValueError(f"pd.all_entries entry {pd_entry} not found in all_entries") + + # Get indices into this subspace's all_entries for qhull_entries + # pd.qhull_entries are entries from pd.all_entries, so we can use index + subspace_qhull_entry_indices = [pd.all_entries.index(entry) for entry in pd.qhull_entries] + + # Convert el_refs to indices into subspace's all_entries + # el_refs entries are from pd.all_entries, so we can use index + subspace_el_refs = [(el, pd.all_entries.index(entry)) for el, entry in pd.computed_data["el_refs"]] + + pds_computed_data[space] = { + "all_entries": subspace_all_entry_indices, + "qhull_entries": subspace_qhull_entry_indices, + "facets": pd.computed_data["facets"], + "simplexes": pd.computed_data["simplexes"], + "qhull_data": pd.computed_data["qhull_data"].tolist(), + "dim": pd.computed_data["dim"], + "el_refs": subspace_el_refs, + "elements": tuple(pd.elements), + } - # TODO comprhys: refactor to have self._compute method to allow serialization - self.spaces = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last - self.qhull_entries = qhull_entries - self._qhull_spaces = 
qhull_spaces - self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=not verbose)) - self.all_entries = all_entries - self.el_refs = el_refs - self.elements = elements + return { + "all_entries": all_entries, + "elements": elements, + "dim": dim, + "el_refs": [(el.symbol, all_entries.index(entry)) for el, entry in el_refs.items()], + "qhull_entries": [all_entries.index(entry) for entry in qhull_entries], + "spaces": spaces_list, + "pds": pds_computed_data, + } + + def __init__( + self, + entries: Sequence[Entry] | set[Entry], + elements: Sequence[Element] | None = None, + keep_all_spaces: bool = False, + verbose: bool = False, + *, + computed_data: dict[str, Any] | None = None, + ) -> None: + """ + Args: + entries (Sequence[Entry] | set[Entry]): A list of Entry objects having an + energy, energy_per_atom and composition. + elements (Sequence[Element], optional): Optional list of elements in the phase + diagram. If set to None, the elements are determined from + the entries themselves and are sorted alphabetically. + If specified, element ordering (e.g. for pd coordinates) + is preserved. + keep_all_spaces (bool): Pass True to keep chemical spaces that are subspaces + of other spaces. + verbose (bool): Whether to show progress bar during convex hull construction. + computed_data (dict): A dict containing pre-computed data. This allows + PatchedPhaseDiagram object to be reconstituted without performing the + expensive convex hull computation. The dict is the output from the + PatchedPhaseDiagram._compute() method. 
+ """ + if computed_data is None: + computed_data = self._compute(entries, elements, keep_all_spaces, verbose) + else: + computed_data = MontyDecoder().process_decoded(computed_data) + if not isinstance(computed_data, dict): + raise TypeError(f"computed_data should be dict, got {type(computed_data).__name__}") + + self.computed_data = computed_data + self.all_entries = computed_data["all_entries"] + self.elements = computed_data["elements"] + self.dim = computed_data["dim"] + + # Convert el_refs from [(el_symbol, index), ...] or [(el_symbol, Entry), ...] to {Element: Entry} + el_refs_data = computed_data["el_refs"] + if el_refs_data and isinstance(el_refs_data[0][1], int): + # el_refs are stored as indices + self.el_refs = {Element(el_symbol): self.all_entries[idx] for el_symbol, idx in el_refs_data} + else: + # el_refs are already entry objects (from from_dict reconstruction) + self.el_refs = {Element(el_symbol): entry for el_symbol, entry in el_refs_data} + + # Convert qhull_entries from indices to entry objects + # When from _compute(), qhull_entries are indices; when from from_dict(), they're already entries + qhull_entries_data = computed_data["qhull_entries"] + if qhull_entries_data and isinstance(qhull_entries_data[0], int): + # qhull_entries are stored as indices into all_entries + self.qhull_entries = tuple(self.all_entries[idx] for idx in qhull_entries_data) + else: + # qhull_entries are already entry objects (from from_dict reconstruction) + self.qhull_entries = tuple(qhull_entries_data) + + self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries) + # Convert spaces from tuples (serialized) or frozensets (in-memory) to frozensets + self.spaces = [ + space if isinstance(space, frozenset) else frozenset(Element(el) for el in space) + for space in computed_data["spaces"] + ] + + # Reconstruct PhaseDiagrams from computed_data + self.pds = {} + for space_key, pd_computed_data in computed_data["pds"].items(): + # Handle both frozenset 
(in-memory) and tuple (serialized) keys + if not isinstance(space_key, frozenset): + space_key = frozenset(Element(el) for el in space_key) + stored_elements = pd_computed_data.get("elements") + if stored_elements: + subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in stored_elements] + else: + # Fallback to deterministic ordering if elements not stored (legacy data) + subspace_elements = sorted(space_key, key=lambda e: e.symbol) + + # Reconstruct entries for this subspace from indices or entry objects + subspace_all_entries_data = pd_computed_data["all_entries"] + if subspace_all_entries_data and isinstance(subspace_all_entries_data[0], int): + # all_entries are stored as indices + subspace_all_entries = [self.all_entries[idx] for idx in subspace_all_entries_data] + else: + # all_entries are already entry objects (from from_dict reconstruction) + subspace_all_entries = subspace_all_entries_data + + # Reconstruct PhaseDiagram with its computed_data + pd_computed_data_with_entries = pd_computed_data.copy() + pd_computed_data_with_entries["all_entries"] = subspace_all_entries + pd_computed_data_with_entries["elements"] = subspace_elements + + # Convert qhull_entries indices back to entries + qhull_entries_data = pd_computed_data["qhull_entries"] + if qhull_entries_data and isinstance(qhull_entries_data[0], int): + # qhull_entries are stored as indices into subspace_all_entries + pd_computed_data_with_entries["qhull_entries"] = [ + subspace_all_entries[idx] for idx in qhull_entries_data + ] + else: + # qhull_entries are already entry objects + pd_computed_data_with_entries["qhull_entries"] = qhull_entries_data + + # Convert el_refs indices back to entries + el_refs_data = pd_computed_data["el_refs"] + if el_refs_data and isinstance(el_refs_data[0][1], int): + # el_refs are stored as indices into subspace_all_entries + pd_computed_data_with_entries["el_refs"] = [ + (Element(el_symbol), subspace_all_entries[idx]) for el_symbol, idx in 
el_refs_data + ] + else: + # el_refs are already entry objects + pd_computed_data_with_entries["el_refs"] = [ + (Element(el_symbol), entry) for el_symbol, entry in el_refs_data + ] + + self.pds[space_key] = PhaseDiagram( + subspace_all_entries, elements=subspace_elements, computed_data=pd_computed_data_with_entries + ) # Add terminal elements as we may not have PD patches including them # NOTE add el_refs in case no multielement entries are present for el @@ -1736,64 +1951,239 @@ def as_dict(self) -> dict[str, Any]: """Write the entries and elements used to construct the PatchedPhaseDiagram to a dictionary. - NOTE unlike PhaseDiagram the computation involved in constructing the - PatchedPhaseDiagram is not saved on serialization. This is done because - hierarchically calling the `PhaseDiagram.as_dict()` method would break the - link in memory between entries in overlapping patches leading to a - ballooning of the amount of memory used. - - NOTE For memory efficiency the best way to store patched phase diagrams is - via pickling. As this allows all the entries in overlapping patches to share - the same id in memory when unpickling. - Returns: dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram. 
""" + unique_entry_dicts: list[dict[str, Any]] = [] + entry_dict_to_index = {} + all_entry_indices = [] + + for entry in self.all_entries: + entry_dict = entry.as_dict() + entry_key = orjson.dumps(entry_dict, option=orjson.OPT_SORT_KEYS).decode() + + if entry_key not in entry_dict_to_index: + entry_dict_to_index[entry_key] = len(unique_entry_dicts) + unique_entry_dicts.append(entry_dict) + + all_entry_indices.append(entry_dict_to_index[entry_key]) + + computed_data = self.computed_data.copy() + + computed_data["elements"] = [e.symbol for e in self.elements] + + qhull_entries_data = computed_data["qhull_entries"] + if qhull_entries_data and not isinstance(qhull_entries_data[0], int): + qhull_entry_indices = [self.all_entries.index(entry) for entry in qhull_entries_data] + else: + qhull_entry_indices = qhull_entries_data + + computed_data["all_entries"] = all_entry_indices + + qhull_entry_indices_remapped = [all_entry_indices[idx] for idx in qhull_entry_indices] + computed_data["qhull_entries"] = qhull_entry_indices_remapped + + el_refs_data = computed_data["el_refs"] + if el_refs_data and not isinstance(el_refs_data[0][1], int): + el_refs_indices = [(el_symbol, self.all_entries.index(entry)) for el_symbol, entry in el_refs_data] + else: + el_refs_indices = el_refs_data + + computed_data["el_refs"] = [ + (el_symbol.symbol if isinstance(el_symbol, Element) else el_symbol, all_entry_indices[idx]) + for el_symbol, idx in el_refs_indices + ] + + pds_remapped = {} + for space_key, pd_data in computed_data["pds"].items(): + space_key_serialized = "-".join(sorted(el.symbol if isinstance(el, Element) else el for el in space_key)) + subspace_all_entries_data = pd_data["all_entries"] + if subspace_all_entries_data and not isinstance(subspace_all_entries_data[0], int): + subspace_all_entry_indices_orig = [self.all_entries.index(entry) for entry in subspace_all_entries_data] + else: + subspace_all_entry_indices_orig = subspace_all_entries_data + + 
subspace_all_entry_indices_remapped = [all_entry_indices[idx] for idx in subspace_all_entry_indices_orig] + + subspace_qhull_entries_data = pd_data["qhull_entries"] + if subspace_qhull_entries_data and not isinstance(subspace_qhull_entries_data[0], int): + subspace_qhull_indices_orig = [ + subspace_all_entries_data.index(entry) for entry in subspace_qhull_entries_data + ] + else: + subspace_qhull_indices_orig = subspace_qhull_entries_data + + subspace_qhull_indices_remapped = [ + subspace_all_entry_indices_remapped[idx] for idx in subspace_qhull_indices_orig + ] + + subspace_el_refs_data = pd_data["el_refs"] + if subspace_el_refs_data and not isinstance(subspace_el_refs_data[0][1], int): + subspace_el_refs_indices_orig = [ + (el_symbol, subspace_all_entries_data.index(entry)) for el_symbol, entry in subspace_el_refs_data + ] + else: + subspace_el_refs_indices_orig = subspace_el_refs_data + + subspace_el_refs_remapped = [ + ( + el_symbol.symbol if isinstance(el_symbol, Element) else el_symbol, + subspace_all_entry_indices_remapped[idx], + ) + for el_symbol, idx in subspace_el_refs_indices_orig + ] + + qhull_data = pd_data["qhull_data"] + if isinstance(qhull_data, np.ndarray): + qhull_data = qhull_data.tolist() + + facets = pd_data["facets"] + facets = [facet.tolist() for facet in facets] + + simplexes = pd_data["simplexes"] + simplexes = [{**s.as_dict(), "coords": s.as_dict()["coords"].tolist()} for s in simplexes] + + elements_data = pd_data.get("elements") + if elements_data: + elements_serialized = [el.symbol if isinstance(el, Element) else el for el in elements_data] + else: + elements_serialized = None + + pds_remapped[space_key_serialized] = { + "all_entries": subspace_all_entry_indices_remapped, + "qhull_entries": subspace_qhull_indices_remapped, + "facets": facets, + "simplexes": simplexes, + "qhull_data": qhull_data, + "dim": pd_data["dim"], + "el_refs": subspace_el_refs_remapped, + **({"elements": elements_serialized} if elements_serialized else {}), + } + 
computed_data["pds"] = pds_remapped + + # Add spaces to computed_data as tuples of element symbols + computed_data["spaces"] = [ + tuple(sorted(el.symbol if isinstance(el, Element) else el for el in space)) for space in self.spaces + ] + return { "@module": type(self).__module__, "@class": type(self).__name__, - "all_entries": [entry.as_dict() for entry in self.all_entries], - "elements": [entry.as_dict() for entry in self.elements], + "elements": [e.symbol for e in self.elements], + "computed_data": computed_data | {"unique_entries": unique_entry_dicts}, } @classmethod def from_dict(cls, dct: dict) -> Self: """Reconstruct PatchedPhaseDiagram from dictionary serialization. - NOTE unlike PhaseDiagram the computation involved in constructing the - PatchedPhaseDiagram is not saved on serialization. This is done because - hierarchically calling the `PhaseDiagram.as_dict()` method would break the - link in memory between entries in overlapping patches leading to a - ballooning of the amount of memory used. - - NOTE For memory efficiency the best way to store patched phase diagrams is - via pickling. As this allows all the entries in overlapping patches to share - the same id in memory when unpickling. - Args: dct (dict): dictionary representation of PatchedPhaseDiagram. 
Returns: PatchedPhaseDiagram """ - entries = [MontyDecoder().process_decoded(entry) for entry in dct["all_entries"]] - elements = [Element.from_dict(elem) for elem in dct["elements"]] + computed_data = dct.get("computed_data") + elements = [Element(elem) for elem in dct["elements"]] + + if computed_data and "unique_entries" in computed_data: + unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] + all_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] + computed_data_reconstructed = computed_data.copy() + computed_data_reconstructed["all_entries"] = all_entries + computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]] + computed_data_reconstructed["spaces"] = computed_data["spaces"] + computed_data_reconstructed["qhull_entries"] = [all_entries[idx] for idx in computed_data["qhull_entries"]] + computed_data_reconstructed["el_refs"] = [ + (Element(elem), all_entries[idx]) for elem, idx in computed_data["el_refs"] + ] + + pds_reconstructed = {} + for space_key, pd_data in computed_data["pds"].items(): + space_key = frozenset(Element(el) for el in space_key.split("-")) + subspace_all_entries = [all_entries[idx] for idx in pd_data["all_entries"]] + # Create mapping from global index to subspace index for all_entries + global_to_subspace_idx = { + global_idx: sub_idx for sub_idx, global_idx in enumerate(pd_data["all_entries"]) + } + + # Map qhull_entries from global indices to subspace indices + # pd_data["qhull_entries"] are indices into global all_entries (after unique_entries remapping) + # These should all be in pd_data["all_entries"], so we can map directly + # IMPORTANT: Preserve order as facets are indices into this list + subspace_qhull_indices = [] + for global_idx in pd_data["qhull_entries"]: + if global_idx in global_to_subspace_idx: + subspace_qhull_indices.append(global_to_subspace_idx[global_idx]) + else: + # This shouldn't happen, but if it does, 
we need to handle it + # Find the entry in subspace_all_entries by equality + entry = all_entries[global_idx] + for sub_idx, sub_entry in enumerate(subspace_all_entries): + if entry == sub_entry: + subspace_qhull_indices.append(sub_idx) + break + else: + raise ValueError( + f"qhull_entry at global index {global_idx} not found in subspace_all_entries" + ) + + # Map el_refs from global indices to subspace indices + subspace_el_refs_indices = [] + for el_symbol, global_idx in pd_data["el_refs"]: + if global_idx in global_to_subspace_idx: + subspace_el_refs_indices.append((el_symbol, global_to_subspace_idx[global_idx])) + + facets = [np.array(facet, dtype=int) for facet in pd_data["facets"]] + simplexes = pd_data["simplexes"] + if isinstance(simplexes, list) and len(simplexes) > 0: + simplexes = [MontyDecoder().process_decoded(s) for s in simplexes] + + subspace_elements = pd_data.get("elements") + if subspace_elements is not None: + subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in subspace_elements] + + pds_reconstructed[space_key] = { + "all_entries": subspace_all_entries, + "qhull_entries": subspace_qhull_indices, # Store as indices into subspace_all_entries + "facets": facets, + "simplexes": simplexes, + "qhull_data": np.array(pd_data["qhull_data"]), + "dim": pd_data["dim"], + "el_refs": subspace_el_refs_indices, # Store as indices into subspace_all_entries + **({"elements": subspace_elements} if subspace_elements is not None else {}), + } + computed_data_reconstructed["pds"] = pds_reconstructed + + return cls(entries=all_entries, elements=elements, computed_data=computed_data_reconstructed) + + # Handle old format (backwards compatibility) + if "unique_entries" in dct: + unique_entries = [MontyDecoder().process_decoded(entry) for entry in dct["unique_entries"]] + entries = [unique_entries[idx] for idx in dct["all_entries"]] + elif "all_entries" in dct: + entries = [MontyDecoder().process_decoded(entry) for entry in 
dct["all_entries"]] + else: + raise ValueError("Invalid dictionary format: missing 'all_entries' or 'computed_data'") + return cls(entries, elements) @staticmethod - def remove_redundant_spaces(spaces, keep_all_spaces=False): + def remove_redundant_spaces( + spaces: set[frozenset[Element]], keep_all_spaces: bool = False + ) -> set[frozenset[Element]]: if keep_all_spaces or len(spaces) <= 1: return spaces # Sort spaces by size in descending order and pre-compute lengths sorted_spaces = sorted(spaces, key=len, reverse=True) - result = [] + result = list() for idx, space_i in enumerate(sorted_spaces): if not any(space_i.issubset(larger_space) for larger_space in sorted_spaces[:idx]): result.append(space_i) - return result + return set(result) # NOTE following methods are inherited unchanged from PhaseDiagram: # __repr__, @@ -1883,18 +2273,6 @@ def get_decomp_and_e_above_hull( on_error=on_error, ) - def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]: - """ - Args: - space (frozenset[Element]): chemical space of the form A-B-X. 
- - Returns: - space, PhaseDiagram for the given chemical space - """ - space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces, strict=True) if space.issuperset(s)] - - return space, PhaseDiagram(space_entries) - # NOTE the following functions are not implemented for PatchedPhaseDiagram def _get_facet_and_simplex(self): diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 2359c1162cc..1e2ccac4336 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -647,10 +647,29 @@ def test_as_from_dict(self): assert MontyDecoder().process_decoded(pd_dict).as_dict() == pd_dict for entry in self.pd.all_entries: - _decomp_rpd, e_above_hull_rppd = reconstructed_pd.get_decomp_and_e_above_hull(entry) - _decomp_pd, e_above_hull_ppd = self.pd.get_decomp_and_e_above_hull(entry) - # assert decomp_rpd == decomp_pd, f"Decomposition for {entry} is not correct!" - assert np.isclose(e_above_hull_rppd, e_above_hull_ppd) + # NOTE: allow_negative=True is necessary due to fp errors we see in the decomposition + decomp_rpd, e_above_hull_rppd = reconstructed_pd.get_decomp_and_e_above_hull(entry, allow_negative=True) + decomp_pd, e_above_hull_ppd = self.pd.get_decomp_and_e_above_hull(entry) + + # Check that decompositions sum to the original composition + for decomp, name in [(decomp_pd, "pd"), (decomp_rpd, "rpd")]: + decomp_comp = Composition({}) + for e, amount in decomp.items(): + comp_scaled = e.composition.fractional_composition * amount + decomp_comp += comp_scaled + assert decomp_comp.almost_equals( + entry.composition.fractional_composition, rtol=0, atol=Composition.amount_tolerance + ), ( + f"Decomposition for {entry} from {name} does not sum to original composition! 
" + f"Expected {entry.composition}, got {decomp_comp}" + ) + + # Compare decompositions by matching entries by composition + # since serialization creates new entry objects with potentially slightly different energies + decomp_pd_by_comp = {e.composition: amount for e, amount in decomp_pd.items()} + decomp_rpd_by_comp = {e.composition: amount for e, amount in decomp_rpd.items()} + assert decomp_pd_by_comp == approx(decomp_rpd_by_comp), f"Decomposition for {entry} is not correct!" + assert np.isclose(e_above_hull_rppd, e_above_hull_ppd, rtol=1e-4, atol=1e-4) def test_read_json(self): dumpfn(self.pd, f"{self.tmp_path}/pd.json") @@ -847,19 +866,19 @@ def test_as_from_dict(self): assert ppd_dict["@module"] == type(self.ppd).__module__ assert ppd_dict["@class"] == type(self.ppd).__name__ - # # Check new format with computed_data and deduplicated entries - # assert "computed_data" in ppd_dict - # computed_data = ppd_dict["computed_data"] - # assert "unique_entries" in computed_data - # assert "all_entries" in computed_data - # assert isinstance(computed_data["all_entries"], list) - # assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) + # Check new format with computed_data and deduplicated entries + assert "computed_data" in ppd_dict + computed_data = ppd_dict["computed_data"] + assert "unique_entries" in computed_data + assert "all_entries" in computed_data + assert isinstance(computed_data["all_entries"], list) + assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) - # # Verify entries can be reconstructed from indices - # unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] - # reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] - # assert len(reconstructed_entries) == len(self.ppd.all_entries) - # assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] + # Verify entries can be reconstructed from indices + unique_entries = 
[MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] + reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] + assert len(reconstructed_entries) == len(self.ppd.all_entries) + assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] reconstructed_ppd = PatchedPhaseDiagram.from_dict(ppd_dict) reconstructed_dict = reconstructed_ppd.as_dict() @@ -873,7 +892,7 @@ def test_as_from_dict(self): (self.ppd, "ppd"), (reconstructed_ppd, "rppd"), ]: - decomp, e_above_hull = pd.get_decomp_and_e_above_hull(entry, check_stable=True) + decomp, e_above_hull = pd.get_decomp_and_e_above_hull(entry, check_stable=True, allow_negative=True) decomp_comp = Composition({}) for e, amount in decomp.items(): comp_scaled = e.composition.fractional_composition * amount From 094e2015ffbde40f294ebaa151cb7caa83ac10c9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 12 Nov 2025 20:30:31 -0500 Subject: [PATCH 5/8] working? --- src/pymatgen/analysis/phase_diagram.py | 389 ++++++++----------------- tests/analysis/test_phase_diagram.py | 3 +- 2 files changed, 131 insertions(+), 261 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 13c4f1882fe..677c92ba494 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -1712,6 +1712,8 @@ def _compute( extra = set(el_refs) - set(elements) raise ValueError(f"There are more terminal elements than dimensions: {extra}") + entry_to_index = {entry: idx for idx, entry in enumerate(all_entries)} + data = np.array( [ [ @@ -1731,16 +1733,10 @@ def _compute( inds.extend([min_entries.index(el) for el in el_refs.values()]) qhull_entries = tuple(min_entries[idx] for idx in inds) - # make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable - # prevent repeating elements in chemical space and avoid the ordering problem (i.e. 
Fe-O == O-Fe automatically) qhull_spaces = tuple(frozenset(entry.elements) for entry in qhull_entries) - # Get all unique chemical spaces spaces = {s for s in qhull_spaces if len(s) > 1} - - # Remove redundant chemical spaces spaces = PatchedPhaseDiagram.remove_redundant_spaces(spaces, keep_all_spaces) - spaces_list = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last # Build PhaseDiagrams for each space and collect their computed_data @@ -1749,35 +1745,14 @@ def _compute( space_entries = [e for e, s in zip(qhull_entries, qhull_spaces, strict=True) if space.issuperset(s)] pd = PhaseDiagram(space_entries) - # Get indices into all_entries for this subspace's all_entries - # pd.all_entries are the entries that PhaseDiagram computed with - # They should be in all_entries, so we can use index() which uses object identity - # If that fails, entries may be equal but different objects, so fall back to equality - # IMPORTANT: We must preserve order to keep facets valid - subspace_all_entry_indices = [] - used_indices = set() # Track used indices to avoid duplicates when using equality fallback - for pd_entry in pd.all_entries: - try: - idx = all_entries.index(pd_entry) - subspace_all_entry_indices.append(idx) - used_indices.add(idx) - except ValueError: - # Entry not found by identity, try equality - # But skip indices we've already used to preserve order - for idx, global_entry in enumerate(all_entries): - if idx not in used_indices and pd_entry == global_entry: - subspace_all_entry_indices.append(idx) - used_indices.add(idx) - break - else: - raise ValueError(f"pd.all_entries entry {pd_entry} not found in all_entries") + # NOTE: We must preserve order to keep facets valid + try: + subspace_all_entry_indices = [entry_to_index[pd_entry] for pd_entry in pd.all_entries] + except KeyError as exc: + missing_entry = exc.args[0] + raise ValueError(f"pd.all_entries entry {missing_entry} not found in all_entries") from exc - # Get indices into 
this subspace's all_entries for qhull_entries - # pd.qhull_entries are entries from pd.all_entries, so we can use index subspace_qhull_entry_indices = [pd.all_entries.index(entry) for entry in pd.qhull_entries] - - # Convert el_refs to indices into subspace's all_entries - # el_refs entries are from pd.all_entries, so we can use index subspace_el_refs = [(el, pd.all_entries.index(entry)) for el, entry in pd.computed_data["el_refs"]] pds_computed_data[space] = { @@ -1841,22 +1816,20 @@ def __init__( # Convert el_refs from [(el_symbol, index), ...] or [(el_symbol, Entry), ...] to {Element: Entry} el_refs_data = computed_data["el_refs"] - if el_refs_data and isinstance(el_refs_data[0][1], int): - # el_refs are stored as indices - self.el_refs = {Element(el_symbol): self.all_entries[idx] for el_symbol, idx in el_refs_data} - else: - # el_refs are already entry objects (from from_dict reconstruction) - self.el_refs = {Element(el_symbol): entry for el_symbol, entry in el_refs_data} + if el_refs_data and not isinstance(el_refs_data[0][1], int): + raise TypeError("computed_data['el_refs'] must contain integer indices") + + def _ensure_element(el: Element | str) -> Element: + return el if isinstance(el, Element) else Element(el) + + self.el_refs = {_ensure_element(el_symbol): self.all_entries[idx] for el_symbol, idx in el_refs_data} # Convert qhull_entries from indices to entry objects # When from _compute(), qhull_entries are indices; when from from_dict(), they're already entries qhull_entries_data = computed_data["qhull_entries"] - if qhull_entries_data and isinstance(qhull_entries_data[0], int): - # qhull_entries are stored as indices into all_entries - self.qhull_entries = tuple(self.all_entries[idx] for idx in qhull_entries_data) - else: - # qhull_entries are already entry objects (from from_dict reconstruction) - self.qhull_entries = tuple(qhull_entries_data) + if qhull_entries_data and not isinstance(qhull_entries_data[0], int): + raise 
TypeError("computed_data['qhull_entries'] must be indices into all_entries") + self.qhull_entries = tuple(self.all_entries[idx] for idx in qhull_entries_data) self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries) # Convert spaces from tuples (serialized) or frozensets (in-memory) to frozensets @@ -1871,59 +1844,35 @@ def __init__( # Handle both frozenset (in-memory) and tuple (serialized) keys if not isinstance(space_key, frozenset): space_key = frozenset(Element(el) for el in space_key) - stored_elements = pd_computed_data.get("elements") - if stored_elements: - subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in stored_elements] - else: - # Fallback to deterministic ordering if elements not stored (legacy data) - subspace_elements = sorted(space_key, key=lambda e: e.symbol) - # Reconstruct entries for this subspace from indices or entry objects + subspace_elements = [ + Element(el) if not isinstance(el, Element) else el for el in pd_computed_data["elements"] + ] subspace_all_entries_data = pd_computed_data["all_entries"] - if subspace_all_entries_data and isinstance(subspace_all_entries_data[0], int): - # all_entries are stored as indices - subspace_all_entries = [self.all_entries[idx] for idx in subspace_all_entries_data] - else: - # all_entries are already entry objects (from from_dict reconstruction) - subspace_all_entries = subspace_all_entries_data - - # Reconstruct PhaseDiagram with its computed_data + if subspace_all_entries_data and not isinstance(subspace_all_entries_data[0], int): + raise TypeError("subspace 'all_entries' must contain indices into PatchedPhaseDiagram.all_entries") + subspace_all_entries = [self.all_entries[idx] for idx in subspace_all_entries_data] pd_computed_data_with_entries = pd_computed_data.copy() pd_computed_data_with_entries["all_entries"] = subspace_all_entries pd_computed_data_with_entries["elements"] = subspace_elements - - # Convert qhull_entries indices back to entries - 
qhull_entries_data = pd_computed_data["qhull_entries"] - if qhull_entries_data and isinstance(qhull_entries_data[0], int): - # qhull_entries are stored as indices into subspace_all_entries - pd_computed_data_with_entries["qhull_entries"] = [ - subspace_all_entries[idx] for idx in qhull_entries_data - ] - else: - # qhull_entries are already entry objects - pd_computed_data_with_entries["qhull_entries"] = qhull_entries_data - - # Convert el_refs indices back to entries - el_refs_data = pd_computed_data["el_refs"] - if el_refs_data and isinstance(el_refs_data[0][1], int): - # el_refs are stored as indices into subspace_all_entries - pd_computed_data_with_entries["el_refs"] = [ - (Element(el_symbol), subspace_all_entries[idx]) for el_symbol, idx in el_refs_data - ] - else: - # el_refs are already entry objects - pd_computed_data_with_entries["el_refs"] = [ - (Element(el_symbol), entry) for el_symbol, entry in el_refs_data - ] - + pd_computed_data_with_entries["qhull_entries"] = [ + subspace_all_entries[idx] for idx in pd_computed_data["qhull_entries"] + ] + pd_computed_data_with_entries["el_refs"] = [ + ( + _ensure_element(el_symbol), + subspace_all_entries[idx], + ) + for el_symbol, idx in pd_computed_data["el_refs"] + ] self.pds[space_key] = PhaseDiagram( subspace_all_entries, elements=subspace_elements, computed_data=pd_computed_data_with_entries ) - # Add terminal elements as we may not have PD patches including them # NOTE add el_refs in case no multielement entries are present for el - _stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries} - self._stable_entries = tuple(_stable_entries | {*self.el_refs.values()}) + self._stable_entries = tuple( + {se for pd in self.pds.values() for se in pd._stable_entries} | {*self.el_refs.values()} + ) self._stable_spaces = tuple(frozenset(entry.elements) for entry in self._stable_entries) def __repr__(self): @@ -1948,128 +1897,66 @@ def __contains__(self, item: frozenset[Element]) -> bool: return item 
in self.pds def as_dict(self) -> dict[str, Any]: - """Write the entries and elements used to construct the PatchedPhaseDiagram - to a dictionary. - - Returns: - dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram. - """ + """Write the entries and elements used to construct the PatchedPhaseDiagram to a dictionary.""" unique_entry_dicts: list[dict[str, Any]] = [] - entry_dict_to_index = {} - all_entry_indices = [] + entry_dict_to_index: dict[str, int] = {} + all_entry_indices: list[int] = [] for entry in self.all_entries: entry_dict = entry.as_dict() entry_key = orjson.dumps(entry_dict, option=orjson.OPT_SORT_KEYS).decode() - if entry_key not in entry_dict_to_index: entry_dict_to_index[entry_key] = len(unique_entry_dicts) unique_entry_dicts.append(entry_dict) - all_entry_indices.append(entry_dict_to_index[entry_key]) - computed_data = self.computed_data.copy() - - computed_data["elements"] = [e.symbol for e in self.elements] - - qhull_entries_data = computed_data["qhull_entries"] - if qhull_entries_data and not isinstance(qhull_entries_data[0], int): - qhull_entry_indices = [self.all_entries.index(entry) for entry in qhull_entries_data] - else: - qhull_entry_indices = qhull_entries_data + entry_to_unique_index = dict(zip(self.all_entries, all_entry_indices, strict=True)) - computed_data["all_entries"] = all_entry_indices - - qhull_entry_indices_remapped = [all_entry_indices[idx] for idx in qhull_entry_indices] - computed_data["qhull_entries"] = qhull_entry_indices_remapped - - el_refs_data = computed_data["el_refs"] - if el_refs_data and not isinstance(el_refs_data[0][1], int): - el_refs_indices = [(el_symbol, self.all_entries.index(entry)) for el_symbol, entry in el_refs_data] - else: - el_refs_indices = el_refs_data - - computed_data["el_refs"] = [ - (el_symbol.symbol if isinstance(el_symbol, Element) else el_symbol, all_entry_indices[idx]) - for el_symbol, idx in el_refs_indices - ] - - pds_remapped = {} - for space_key, pd_data in 
computed_data["pds"].items(): - space_key_serialized = "-".join(sorted(el.symbol if isinstance(el, Element) else el for el in space_key)) - subspace_all_entries_data = pd_data["all_entries"] - if subspace_all_entries_data and not isinstance(subspace_all_entries_data[0], int): - subspace_all_entry_indices_orig = [self.all_entries.index(entry) for entry in subspace_all_entries_data] - else: - subspace_all_entry_indices_orig = subspace_all_entries_data - - subspace_all_entry_indices_remapped = [all_entry_indices[idx] for idx in subspace_all_entry_indices_orig] - - subspace_qhull_entries_data = pd_data["qhull_entries"] - if subspace_qhull_entries_data and not isinstance(subspace_qhull_entries_data[0], int): - subspace_qhull_indices_orig = [ - subspace_all_entries_data.index(entry) for entry in subspace_qhull_entries_data - ] - else: - subspace_qhull_indices_orig = subspace_qhull_entries_data - - subspace_qhull_indices_remapped = [ - subspace_all_entry_indices_remapped[idx] for idx in subspace_qhull_indices_orig - ] + computed_data: dict[str, Any] = { + "elements": [element.symbol for element in self.elements], + "all_entries": all_entry_indices, + "qhull_entries": [entry_to_unique_index[entry] for entry in self.qhull_entries], + "el_refs": [(element.symbol, entry_to_unique_index[entry]) for element, entry in self.el_refs.items()], + "dim": self.dim, + } - subspace_el_refs_data = pd_data["el_refs"] - if subspace_el_refs_data and not isinstance(subspace_el_refs_data[0][1], int): - subspace_el_refs_indices_orig = [ - (el_symbol, subspace_all_entries_data.index(entry)) for el_symbol, entry in subspace_el_refs_data - ] - else: - subspace_el_refs_indices_orig = subspace_el_refs_data + spaces_serialized = [tuple(sorted(element.symbol for element in space)) for space in self.spaces] - subspace_el_refs_remapped = [ - ( - el_symbol.symbol if isinstance(el_symbol, Element) else el_symbol, - subspace_all_entry_indices_remapped[idx], - ) - for el_symbol, idx in 
subspace_el_refs_indices_orig - ] + pds_remapped: dict[str, Any] = {} + for space, pd in self.pds.items(): + space_key_serialized = "-".join(sorted(element.symbol for element in space)) - qhull_data = pd_data["qhull_data"] - if isinstance(qhull_data, np.ndarray): - qhull_data = qhull_data.tolist() + subspace_all_entry_indices = [entry_to_unique_index[entry] for entry in pd.all_entries] + subspace_qhull_indices = [entry_to_unique_index[entry] for entry in pd.qhull_entries] + subspace_el_refs = [(element.symbol, entry_to_unique_index[entry]) for element, entry in pd.el_refs.items()] - facets = pd_data["facets"] - facets = [facet.tolist() for facet in facets] + qhull_data = np.asarray(pd.qhull_data).tolist() + facets = [facet.tolist() for facet in pd.facets] - simplexes = pd_data["simplexes"] - simplexes = [{**s.as_dict(), "coords": s.as_dict()["coords"].tolist()} for s in simplexes] - - elements_data = pd_data.get("elements") - if elements_data: - elements_serialized = [el.symbol if isinstance(el, Element) else el for el in elements_data] - else: - elements_serialized = None + simplexes = [] + for simplex in pd.simplexes: + simplex_dict = simplex.as_dict() + simplex_dict["coords"] = np.asarray(simplex_dict["coords"]).tolist() + simplexes.append(simplex_dict) pds_remapped[space_key_serialized] = { - "all_entries": subspace_all_entry_indices_remapped, - "qhull_entries": subspace_qhull_indices_remapped, + "all_entries": subspace_all_entry_indices, + "qhull_entries": subspace_qhull_indices, "facets": facets, "simplexes": simplexes, "qhull_data": qhull_data, - "dim": pd_data["dim"], - "el_refs": subspace_el_refs_remapped, - **({"elements": elements_serialized} if elements_serialized else {}), + "dim": pd.dim, + "el_refs": subspace_el_refs, + "elements": [element.symbol for element in pd.elements], } - computed_data["pds"] = pds_remapped - # Add spaces to computed_data as tuples of element symbols - computed_data["spaces"] = [ - tuple(sorted(el.symbol if isinstance(el, 
Element) else el for el in space)) for space in self.spaces - ] + computed_data["spaces"] = spaces_serialized + computed_data["pds"] = pds_remapped return { "@module": type(self).__module__, "@class": type(self).__name__, - "elements": [e.symbol for e in self.elements], + "elements": [element.symbol for element in self.elements], "computed_data": computed_data | {"unique_entries": unique_entry_dicts}, } @@ -2083,90 +1970,72 @@ def from_dict(cls, dct: dict) -> Self: Returns: PatchedPhaseDiagram """ - computed_data = dct.get("computed_data") + computed_data = dct["computed_data"] elements = [Element(elem) for elem in dct["elements"]] + decoder = MontyDecoder() + unique_entries = [decoder.process_decoded(entry) for entry in computed_data["unique_entries"]] + global_unique_indices = computed_data["all_entries"] + all_entries = [unique_entries[idx] for idx in global_unique_indices] - if computed_data and "unique_entries" in computed_data: - unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] - all_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] - computed_data_reconstructed = computed_data.copy() - computed_data_reconstructed["all_entries"] = all_entries - computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]] - computed_data_reconstructed["spaces"] = computed_data["spaces"] - computed_data_reconstructed["qhull_entries"] = [all_entries[idx] for idx in computed_data["qhull_entries"]] - computed_data_reconstructed["el_refs"] = [ - (Element(elem), all_entries[idx]) for elem, idx in computed_data["el_refs"] - ] + computed_data_reconstructed = computed_data.copy() + computed_data_reconstructed["all_entries"] = all_entries + computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]] + computed_data_reconstructed["spaces"] = computed_data["spaces"] - pds_reconstructed = {} - for space_key, pd_data in 
computed_data["pds"].items(): - space_key = frozenset(Element(el) for el in space_key.split("-")) - subspace_all_entries = [all_entries[idx] for idx in pd_data["all_entries"]] - # Create mapping from global index to subspace index for all_entries - global_to_subspace_idx = { - global_idx: sub_idx for sub_idx, global_idx in enumerate(pd_data["all_entries"]) - } + unique_idx_to_global_idx: dict[int, int] = {} + for global_idx, unique_idx in enumerate(global_unique_indices): + unique_idx_to_global_idx.setdefault(unique_idx, global_idx) - # Map qhull_entries from global indices to subspace indices - # pd_data["qhull_entries"] are indices into global all_entries (after unique_entries remapping) - # These should all be in pd_data["all_entries"], so we can map directly - # IMPORTANT: Preserve order as facets are indices into this list - subspace_qhull_indices = [] - for global_idx in pd_data["qhull_entries"]: - if global_idx in global_to_subspace_idx: - subspace_qhull_indices.append(global_to_subspace_idx[global_idx]) - else: - # This shouldn't happen, but if it does, we need to handle it - # Find the entry in subspace_all_entries by equality - entry = all_entries[global_idx] - for sub_idx, sub_entry in enumerate(subspace_all_entries): - if entry == sub_entry: - subspace_qhull_indices.append(sub_idx) - break - else: - raise ValueError( - f"qhull_entry at global index {global_idx} not found in subspace_all_entries" - ) - - # Map el_refs from global indices to subspace indices - subspace_el_refs_indices = [] - for el_symbol, global_idx in pd_data["el_refs"]: - if global_idx in global_to_subspace_idx: - subspace_el_refs_indices.append((el_symbol, global_to_subspace_idx[global_idx])) - - facets = [np.array(facet, dtype=int) for facet in pd_data["facets"]] - simplexes = pd_data["simplexes"] - if isinstance(simplexes, list) and len(simplexes) > 0: - simplexes = [MontyDecoder().process_decoded(s) for s in simplexes] - - subspace_elements = pd_data.get("elements") - if 
subspace_elements is not None: - subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in subspace_elements] - - pds_reconstructed[space_key] = { - "all_entries": subspace_all_entries, - "qhull_entries": subspace_qhull_indices, # Store as indices into subspace_all_entries - "facets": facets, - "simplexes": simplexes, - "qhull_data": np.array(pd_data["qhull_data"]), - "dim": pd_data["dim"], - "el_refs": subspace_el_refs_indices, # Store as indices into subspace_all_entries - **({"elements": subspace_elements} if subspace_elements is not None else {}), - } - computed_data_reconstructed["pds"] = pds_reconstructed + def _global_index(unique_idx: int) -> int: + try: + return unique_idx_to_global_idx[unique_idx] + except KeyError as exc: + msg = f"unique entry index {unique_idx} missing from computed_data['all_entries']" + raise KeyError(msg) from exc + + computed_data_reconstructed["qhull_entries"] = [ + _global_index(unique_idx) for unique_idx in computed_data["qhull_entries"] + ] + computed_data_reconstructed["el_refs"] = [(elem, _global_index(idx)) for elem, idx in computed_data["el_refs"]] - return cls(entries=all_entries, elements=elements, computed_data=computed_data_reconstructed) + pds_reconstructed = {} + for space_key, pd_data in computed_data["pds"].items(): + space_key_frozen = frozenset(Element(el) for el in space_key.split("-")) + subspace_unique_indices = pd_data["all_entries"] + subspace_global_indices = [_global_index(unique_idx) for unique_idx in subspace_unique_indices] - # Handle old format (backwards compatibility) - if "unique_entries" in dct: - unique_entries = [MontyDecoder().process_decoded(entry) for entry in dct["unique_entries"]] - entries = [unique_entries[idx] for idx in dct["all_entries"]] - elif "all_entries" in dct: - entries = [MontyDecoder().process_decoded(entry) for entry in dct["all_entries"]] - else: - raise ValueError("Invalid dictionary format: missing 'all_entries' or 'computed_data'") + 
unique_idx_to_sub_idx: dict[int, int] = {} + for sub_idx, unique_idx in enumerate(subspace_unique_indices): + unique_idx_to_sub_idx.setdefault(unique_idx, sub_idx) + + def _subspace_index(unique_idx: int) -> int: + try: + return unique_idx_to_sub_idx[unique_idx] + except KeyError as exc: + msg = f"unique entry index {unique_idx} missing from subspace entries" + raise KeyError(msg) from exc + + facets = [np.array(facet, dtype=int) for facet in pd_data["facets"]] + simplexes = [decoder.process_decoded(simplex) for simplex in pd_data["simplexes"]] + subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in pd_data["elements"]] + + pds_reconstructed[space_key_frozen] = { + "all_entries": subspace_global_indices, + "qhull_entries": [_subspace_index(unique_idx) for unique_idx in pd_data["qhull_entries"]], + "facets": facets, + "simplexes": simplexes, + "qhull_data": np.array(pd_data["qhull_data"]), + "dim": pd_data["dim"], + "el_refs": [ + (Element(el) if not isinstance(el, Element) else el, _subspace_index(idx)) + for el, idx in pd_data["el_refs"] + ], + "elements": subspace_elements, + } + + computed_data_reconstructed["pds"] = pds_reconstructed - return cls(entries, elements) + return cls(entries=all_entries, elements=elements, computed_data=computed_data_reconstructed) @staticmethod def remove_redundant_spaces( diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 1e2ccac4336..8f23a24e9f0 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -647,7 +647,7 @@ def test_as_from_dict(self): assert MontyDecoder().process_decoded(pd_dict).as_dict() == pd_dict for entry in self.pd.all_entries: - # NOTE: allow_negative=True is necessary due to fp errors we see in the decomposition + # NOTE: allow_negative=True is necessary due to fp errors we see after serialization decomp_rpd, e_above_hull_rppd = reconstructed_pd.get_decomp_and_e_above_hull(entry, allow_negative=True) 
decomp_pd, e_above_hull_ppd = self.pd.get_decomp_and_e_above_hull(entry) @@ -892,6 +892,7 @@ def test_as_from_dict(self): (self.ppd, "ppd"), (reconstructed_ppd, "rppd"), ]: + # NOTE: allow_negative=True is necessary due to fp errors we see after serialization decomp, e_above_hull = pd.get_decomp_and_e_above_hull(entry, check_stable=True, allow_negative=True) decomp_comp = Composition({}) for e, amount in decomp.items(): From 5e54fd0a6ce83798c8f497ae821a80eb59c254d9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 15 Nov 2025 14:14:29 -0500 Subject: [PATCH 6/8] use hashmaps for speed --- src/pymatgen/analysis/phase_diagram.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 677c92ba494..30f903061db 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -393,7 +393,8 @@ def __init__( def as_dict(self): """Get MSONable dict representation of PhaseDiagram.""" - qhull_entry_indices = [self.all_entries.index(e) for e in self.qhull_entries] + entry_to_index = {entry: idx for idx, entry in enumerate(self.all_entries)} + qhull_entry_indices = [entry_to_index[e] for e in self.qhull_entries] # Create a copy of computed_data to avoid modifying the original computed_data = self.computed_data.copy() @@ -1752,8 +1753,9 @@ def _compute( missing_entry = exc.args[0] raise ValueError(f"pd.all_entries entry {missing_entry} not found in all_entries") from exc - subspace_qhull_entry_indices = [pd.all_entries.index(entry) for entry in pd.qhull_entries] - subspace_el_refs = [(el, pd.all_entries.index(entry)) for el, entry in pd.computed_data["el_refs"]] + pd_entry_to_index = {entry: idx for idx, entry in enumerate(pd.all_entries)} + subspace_qhull_entry_indices = [pd_entry_to_index[entry] for entry in pd.qhull_entries] + subspace_el_refs = [(el, pd_entry_to_index[entry]) for el, entry in pd.computed_data["el_refs"]] 
pds_computed_data[space] = { "all_entries": subspace_all_entry_indices, @@ -1770,8 +1772,8 @@ def _compute( "all_entries": all_entries, "elements": elements, "dim": dim, - "el_refs": [(el.symbol, all_entries.index(entry)) for el, entry in el_refs.items()], - "qhull_entries": [all_entries.index(entry) for entry in qhull_entries], + "el_refs": [(el.symbol, entry_to_index[entry]) for el, entry in el_refs.items()], + "qhull_entries": [entry_to_index[entry] for entry in qhull_entries], "spaces": spaces_list, "pds": pds_computed_data, } From 3f4fb88286d21a855be277e1018ea0684d343395 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 26 Nov 2025 11:49:04 -0500 Subject: [PATCH 7/8] fea: simplify and use ids for el_Refs in normal phasediagram --- src/pymatgen/analysis/phase_diagram.py | 84 ++++++-------------------- tests/analysis/test_phase_diagram.py | 10 ++- 2 files changed, 24 insertions(+), 70 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 30f903061db..be66b95ff1d 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -399,7 +399,7 @@ def as_dict(self): # Create a copy of computed_data to avoid modifying the original computed_data = self.computed_data.copy() computed_data["elements"] = [el.symbol for el in self.elements] - computed_data["el_refs"] = [(el.symbol, entry.as_dict()) for el, entry in computed_data["el_refs"]] + computed_data["el_refs"] = [(el.symbol, entry_to_index[entry]) for el, entry in computed_data["el_refs"]] computed_data["all_entries"] = [e.as_dict() for e in computed_data["all_entries"]] computed_data["qhull_entries"] = qhull_entry_indices computed_data["qhull_data"] = ( @@ -441,10 +441,8 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: computed_data = computed_data.copy() computed_data["qhull_entries"] = [entries[i] for i in computed_data["qhull_entries"]] computed_data["elements"] = [Element(el) for el in computed_data["elements"]] 
- # Keep el_refs as (str, Entry) format to match _compute() output - computed_data["el_refs"] = [ - (el_str, MontyDecoder().process_decoded(entry)) for el_str, entry in computed_data["el_refs"] - ] + # el_refs stored as (str, index) - convert to (str, Entry) + computed_data["el_refs"] = [(el_str, entries[idx]) for el_str, idx in computed_data["el_refs"]] return cls(entries, elements, computed_data=computed_data) @@ -1900,25 +1898,13 @@ def __contains__(self, item: frozenset[Element]) -> bool: def as_dict(self) -> dict[str, Any]: """Write the entries and elements used to construct the PatchedPhaseDiagram to a dictionary.""" - unique_entry_dicts: list[dict[str, Any]] = [] - entry_dict_to_index: dict[str, int] = {} - all_entry_indices: list[int] = [] - - for entry in self.all_entries: - entry_dict = entry.as_dict() - entry_key = orjson.dumps(entry_dict, option=orjson.OPT_SORT_KEYS).decode() - if entry_key not in entry_dict_to_index: - entry_dict_to_index[entry_key] = len(unique_entry_dicts) - unique_entry_dicts.append(entry_dict) - all_entry_indices.append(entry_dict_to_index[entry_key]) - - entry_to_unique_index = dict(zip(self.all_entries, all_entry_indices, strict=True)) + entry_to_index = {entry: idx for idx, entry in enumerate(self.all_entries)} computed_data: dict[str, Any] = { "elements": [element.symbol for element in self.elements], - "all_entries": all_entry_indices, - "qhull_entries": [entry_to_unique_index[entry] for entry in self.qhull_entries], - "el_refs": [(element.symbol, entry_to_unique_index[entry]) for element, entry in self.el_refs.items()], + "all_entries": [e.as_dict() for e in self.all_entries], + "qhull_entries": [entry_to_index[entry] for entry in self.qhull_entries], + "el_refs": [(element.symbol, entry_to_index[entry]) for element, entry in self.el_refs.items()], "dim": self.dim, } @@ -1928,9 +1914,9 @@ def as_dict(self) -> dict[str, Any]: for space, pd in self.pds.items(): space_key_serialized = "-".join(sorted(element.symbol for element 
in space)) - subspace_all_entry_indices = [entry_to_unique_index[entry] for entry in pd.all_entries] - subspace_qhull_indices = [entry_to_unique_index[entry] for entry in pd.qhull_entries] - subspace_el_refs = [(element.symbol, entry_to_unique_index[entry]) for element, entry in pd.el_refs.items()] + subspace_all_entry_indices = [entry_to_index[entry] for entry in pd.all_entries] + subspace_qhull_indices = [entry_to_index[entry] for entry in pd.qhull_entries] + subspace_el_refs = [(element.symbol, entry_to_index[entry]) for element, entry in pd.el_refs.items()] qhull_data = np.asarray(pd.qhull_data).tolist() facets = [facet.tolist() for facet in pd.facets] @@ -1959,7 +1945,7 @@ def as_dict(self) -> dict[str, Any]: "@module": type(self).__module__, "@class": type(self).__name__, "elements": [element.symbol for element in self.elements], - "computed_data": computed_data | {"unique_entries": unique_entry_dicts}, + "computed_data": computed_data, } @classmethod @@ -1975,63 +1961,33 @@ def from_dict(cls, dct: dict) -> Self: computed_data = dct["computed_data"] elements = [Element(elem) for elem in dct["elements"]] decoder = MontyDecoder() - unique_entries = [decoder.process_decoded(entry) for entry in computed_data["unique_entries"]] - global_unique_indices = computed_data["all_entries"] - all_entries = [unique_entries[idx] for idx in global_unique_indices] + all_entries = [decoder.process_decoded(entry) for entry in computed_data["all_entries"]] computed_data_reconstructed = computed_data.copy() computed_data_reconstructed["all_entries"] = all_entries computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]] - computed_data_reconstructed["spaces"] = computed_data["spaces"] - - unique_idx_to_global_idx: dict[int, int] = {} - for global_idx, unique_idx in enumerate(global_unique_indices): - unique_idx_to_global_idx.setdefault(unique_idx, global_idx) - - def _global_index(unique_idx: int) -> int: - try: - return 
unique_idx_to_global_idx[unique_idx] - except KeyError as exc: - msg = f"unique entry index {unique_idx} missing from computed_data['all_entries']" - raise KeyError(msg) from exc - - computed_data_reconstructed["qhull_entries"] = [ - _global_index(unique_idx) for unique_idx in computed_data["qhull_entries"] - ] - computed_data_reconstructed["el_refs"] = [(elem, _global_index(idx)) for elem, idx in computed_data["el_refs"]] + # qhull_entries and el_refs are already indices into all_entries pds_reconstructed = {} for space_key, pd_data in computed_data["pds"].items(): space_key_frozen = frozenset(Element(el) for el in space_key.split("-")) - subspace_unique_indices = pd_data["all_entries"] - subspace_global_indices = [_global_index(unique_idx) for unique_idx in subspace_unique_indices] - - unique_idx_to_sub_idx: dict[int, int] = {} - for sub_idx, unique_idx in enumerate(subspace_unique_indices): - unique_idx_to_sub_idx.setdefault(unique_idx, sub_idx) - - def _subspace_index(unique_idx: int) -> int: - try: - return unique_idx_to_sub_idx[unique_idx] - except KeyError as exc: - msg = f"unique entry index {unique_idx} missing from subspace entries" - raise KeyError(msg) from exc facets = [np.array(facet, dtype=int) for facet in pd_data["facets"]] simplexes = [decoder.process_decoded(simplex) for simplex in pd_data["simplexes"]] subspace_elements = [Element(el) if not isinstance(el, Element) else el for el in pd_data["elements"]] + # Create mapping from global index to subspace index + subspace_indices = pd_data["all_entries"] + global_to_sub_idx = {global_idx: sub_idx for sub_idx, global_idx in enumerate(subspace_indices)} + pds_reconstructed[space_key_frozen] = { - "all_entries": subspace_global_indices, - "qhull_entries": [_subspace_index(unique_idx) for unique_idx in pd_data["qhull_entries"]], + "all_entries": subspace_indices, + "qhull_entries": [global_to_sub_idx[idx] for idx in pd_data["qhull_entries"]], "facets": facets, "simplexes": simplexes, "qhull_data": 
np.array(pd_data["qhull_data"]), "dim": pd_data["dim"], - "el_refs": [ - (Element(el) if not isinstance(el, Element) else el, _subspace_index(idx)) - for el, idx in pd_data["el_refs"] - ], + "el_refs": [(Element(el), global_to_sub_idx[idx]) for el, idx in pd_data["el_refs"]], "elements": subspace_elements, } diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 8f23a24e9f0..9bb1cd592d9 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -866,17 +866,15 @@ def test_as_from_dict(self): assert ppd_dict["@module"] == type(self.ppd).__module__ assert ppd_dict["@class"] == type(self.ppd).__name__ - # Check new format with computed_data and deduplicated entries + # Check format with computed_data assert "computed_data" in ppd_dict computed_data = ppd_dict["computed_data"] - assert "unique_entries" in computed_data assert "all_entries" in computed_data assert isinstance(computed_data["all_entries"], list) - assert all(isinstance(idx, int) for idx in computed_data["all_entries"]) + assert all(isinstance(entry, dict) for entry in computed_data["all_entries"]) - # Verify entries can be reconstructed from indices - unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]] - reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]] + # Verify entries can be reconstructed + reconstructed_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["all_entries"]] assert len(reconstructed_entries) == len(self.ppd.all_entries) assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements] From 082e4498e111cd838865899989c47697b1274f3c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 26 Nov 2025 15:51:52 -0500 Subject: [PATCH 8/8] fea: save less redundant stuff --- src/pymatgen/analysis/phase_diagram.py | 57 ++++++++------------------ 1 file changed, 16 insertions(+), 41 deletions(-) diff --git 
a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index be66b95ff1d..648c2019afe 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -344,7 +344,7 @@ class PhaseDiagram(MSONable): def __init__( self, entries: Collection[Entry], - elements: Collection[Element] = (), + elements: Collection[Element] | None = None, *, computed_data: dict[str, Any] | None = None, ) -> None: @@ -366,7 +366,7 @@ def __init__( if not entries: raise ValueError("Unable to build phase diagram without entries.") - self.elements = elements + self.elements = list(elements) if elements else sorted({els for e in entries for els in e.elements}) self.entries = entries if computed_data is None: computed_data = self._compute() @@ -380,10 +380,9 @@ def __init__( self.computed_data = computed_data self.facets = computed_data["facets"] - self.simplexes = computed_data["simplexes"] + self.qhull_data = np.asarray(computed_data["qhull_data"]) + self.simplexes = [Simplex(self.qhull_data[facet, :-1]) for facet in self.facets] self.all_entries = computed_data["all_entries"] - self.qhull_data = computed_data["qhull_data"] - self.dim = computed_data["dim"] self.el_refs = dict(computed_data["el_refs"]) self.qhull_entries = tuple(computed_data["qhull_entries"]) self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries) @@ -398,7 +397,6 @@ def as_dict(self): # Create a copy of computed_data to avoid modifying the original computed_data = self.computed_data.copy() - computed_data["elements"] = [el.symbol for el in self.elements] computed_data["el_refs"] = [(el.symbol, entry_to_index[entry]) for el, entry in computed_data["el_refs"]] computed_data["all_entries"] = [e.as_dict() for e in computed_data["all_entries"]] computed_data["qhull_entries"] = qhull_entry_indices @@ -407,10 +405,8 @@ def as_dict(self): if isinstance(computed_data["qhull_data"], np.ndarray) else computed_data["qhull_data"] ) - computed_data["facets"] 
= [list(facet) for facet in computed_data["facets"]] - computed_data["simplexes"] = [ - {**s.as_dict(), "coords": s.as_dict()["coords"].tolist()} for s in computed_data["simplexes"] - ] + computed_data["facets"] = [list(facet) for facet in self.facets] + computed_data.pop("simplexes", None) # Reconstructed from qhull_data and facets return { "@module": type(self).__module__, @@ -440,16 +436,12 @@ def from_dict(cls, dct: dict[str, Any]) -> Self: # Reconstruct computed_data to match _compute() format: (str, Entry) tuples for el_refs computed_data = computed_data.copy() computed_data["qhull_entries"] = [entries[i] for i in computed_data["qhull_entries"]] - computed_data["elements"] = [Element(el) for el in computed_data["elements"]] # el_refs stored as (str, index) - convert to (str, Entry) computed_data["el_refs"] = [(el_str, entries[idx]) for el_str, idx in computed_data["el_refs"]] return cls(entries, elements, computed_data=computed_data) def _compute(self) -> dict[str, Any]: - if self.elements == (): - self.elements = sorted({els for e in self.entries for els in e.elements}) - elements = list(self.elements) dim = len(elements) @@ -507,19 +499,19 @@ def _compute(self) -> dict[str, Any]: final_facets.append(facet) facets = final_facets - simplexes = [Simplex(qhull_data[facet, :-1]) for facet in facets] - self.elements = elements return { "facets": facets, - "simplexes": simplexes, "all_entries": all_entries, "qhull_data": qhull_data, - "dim": dim, - # Dictionary with Element keys is not JSON-serializable - "el_refs": list(el_refs.items()), + "el_refs": list(el_refs.items()), # Dictionary with Element keys is not JSON-serializable "qhull_entries": qhull_entries, } + @property + def dim(self) -> int: + """The dimensionality of the phase diagram.""" + return len(self.elements) + def pd_coords(self, comp: Composition) -> np.ndarray: """ The phase diagram is generated in a reduced dimensional space @@ -1759,17 +1751,13 @@ def _compute( "all_entries": 
subspace_all_entry_indices, "qhull_entries": subspace_qhull_entry_indices, "facets": pd.computed_data["facets"], - "simplexes": pd.computed_data["simplexes"], - "qhull_data": pd.computed_data["qhull_data"].tolist(), - "dim": pd.computed_data["dim"], + "qhull_data": pd.qhull_data.tolist(), "el_refs": subspace_el_refs, "elements": tuple(pd.elements), } return { "all_entries": all_entries, - "elements": elements, - "dim": dim, "el_refs": [(el.symbol, entry_to_index[entry]) for el, entry in el_refs.items()], "qhull_entries": [entry_to_index[entry] for entry in qhull_entries], "spaces": spaces_list, @@ -1802,6 +1790,8 @@ def __init__( expensive convex hull computation. The dict is the output from the PatchedPhaseDiagram._compute() method. """ + self.elements = list(elements) if elements else sorted({el for e in entries for el in e.elements}) + if computed_data is None: computed_data = self._compute(entries, elements, keep_all_spaces, verbose) else: @@ -1811,8 +1801,6 @@ def __init__( self.computed_data = computed_data self.all_entries = computed_data["all_entries"] - self.elements = computed_data["elements"] - self.dim = computed_data["dim"] # Convert el_refs from [(el_symbol, index), ...] or [(el_symbol, Entry), ...] 
to {Element: Entry} el_refs_data = computed_data["el_refs"] @@ -1901,11 +1889,9 @@ def as_dict(self) -> dict[str, Any]: entry_to_index = {entry: idx for idx, entry in enumerate(self.all_entries)} computed_data: dict[str, Any] = { - "elements": [element.symbol for element in self.elements], "all_entries": [e.as_dict() for e in self.all_entries], "qhull_entries": [entry_to_index[entry] for entry in self.qhull_entries], "el_refs": [(element.symbol, entry_to_index[entry]) for element, entry in self.el_refs.items()], - "dim": self.dim, } spaces_serialized = [tuple(sorted(element.symbol for element in space)) for space in self.spaces] @@ -1921,19 +1907,11 @@ def as_dict(self) -> dict[str, Any]: qhull_data = np.asarray(pd.qhull_data).tolist() facets = [facet.tolist() for facet in pd.facets] - simplexes = [] - for simplex in pd.simplexes: - simplex_dict = simplex.as_dict() - simplex_dict["coords"] = np.asarray(simplex_dict["coords"]).tolist() - simplexes.append(simplex_dict) - pds_remapped[space_key_serialized] = { "all_entries": subspace_all_entry_indices, "qhull_entries": subspace_qhull_indices, "facets": facets, - "simplexes": simplexes, "qhull_data": qhull_data, - "dim": pd.dim, "el_refs": subspace_el_refs, "elements": [element.symbol for element in pd.elements], } @@ -1965,7 +1943,6 @@ def from_dict(cls, dct: dict) -> Self: computed_data_reconstructed = computed_data.copy() computed_data_reconstructed["all_entries"] = all_entries - computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]] # qhull_entries and el_refs are already indices into all_entries pds_reconstructed = {} @@ -1973,20 +1950,18 @@ def from_dict(cls, dct: dict) -> Self: space_key_frozen = frozenset(Element(el) for el in space_key.split("-")) facets = [np.array(facet, dtype=int) for facet in pd_data["facets"]] - simplexes = [decoder.process_decoded(simplex) for simplex in pd_data["simplexes"]] subspace_elements = [Element(el) if not isinstance(el, Element) else el 
for el in pd_data["elements"]] # Create mapping from global index to subspace index subspace_indices = pd_data["all_entries"] global_to_sub_idx = {global_idx: sub_idx for sub_idx, global_idx in enumerate(subspace_indices)} + # simplexes reconstructed in PhaseDiagram.__init__ from qhull_data and facets pds_reconstructed[space_key_frozen] = { "all_entries": subspace_indices, "qhull_entries": [global_to_sub_idx[idx] for idx in pd_data["qhull_entries"]], "facets": facets, - "simplexes": simplexes, "qhull_data": np.array(pd_data["qhull_data"]), - "dim": pd_data["dim"], "el_refs": [(Element(el), global_to_sub_idx[idx]) for el, idx in pd_data["el_refs"]], "elements": subspace_elements, }