Skip to content

Commit 39bb8fb

Browse files
committed
will this fix?
1 parent 8ab807b commit 39bb8fb

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

src/pymatgen/analysis/phase_diagram.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,13 @@ def __init__(
375375
if not isinstance(computed_data, dict):
376376
raise TypeError(f"computed_data should be dict, got {type(computed_data).__name__}")
377377

378-
# Update keys to be Element objects in case they are strings in pre-computed data
379-
computed_data["el_refs"] = [(Element(el_str), entry) for el_str, entry in computed_data["el_refs"]]
380378
self.computed_data = computed_data
381379
self.facets = computed_data["facets"]
382380
self.simplexes = computed_data["simplexes"]
383381
self.all_entries = computed_data["all_entries"]
384382
self.qhull_data = computed_data["qhull_data"]
385383
self.dim = computed_data["dim"]
386-
self.el_refs = dict(computed_data["el_refs"])
384+
self.el_refs = {Element(el): entry for el, entry in computed_data["el_refs"]}
387385
self.qhull_entries = tuple(computed_data["qhull_entries"])
388386
self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries)
389387
self._stable_entries = tuple({self.qhull_entries[idx] for idx in set(itertools.chain(*self.facets))})
@@ -394,14 +392,29 @@ def as_dict(self):
394392

395393
qhull_entry_indices = [self.all_entries.index(e) for e in self.qhull_entries]
396394

395+
# Create a copy of computed_data to avoid modifying the original
396+
computed_data = self.computed_data.copy()
397+
computed_data["elements"] = [el.symbol for el in self.elements]
398+
computed_data["el_refs"] = [
399+
(el.symbol if isinstance(el, Element) else el, entry.as_dict()) for el, entry in computed_data["el_refs"]
400+
]
401+
computed_data["all_entries"] = [e.as_dict() for e in computed_data["all_entries"]]
402+
computed_data["qhull_entries"] = qhull_entry_indices
403+
computed_data["qhull_data"] = (
404+
computed_data["qhull_data"].tolist()
405+
if isinstance(computed_data["qhull_data"], np.ndarray)
406+
else computed_data["qhull_data"]
407+
)
408+
computed_data["facets"] = [list(facet) for facet in computed_data["facets"]]
409+
computed_data["simplexes"] = [
410+
{**s.as_dict(), "coords": s.as_dict()["coords"].tolist()} for s in computed_data["simplexes"]
411+
]
412+
397413
return {
398414
"@module": type(self).__module__,
399415
"@class": type(self).__name__,
400-
"elements": [e.as_dict() for e in self.elements],
401-
"computed_data": self.computed_data
402-
| {
403-
"qhull_entries": qhull_entry_indices,
404-
},
416+
"elements": [el.symbol for el in self.elements],
417+
"computed_data": computed_data,
405418
}
406419

407420
@classmethod
@@ -414,17 +427,22 @@ def from_dict(cls, dct: dict[str, Any]) -> Self:
414427
PhaseDiagram
415428
"""
416429
computed_data = dct.get("computed_data")
417-
elements = [Element.from_dict(elem) for elem in dct["elements"]]
430+
elements = [Element(elem) for elem in dct["elements"]]
418431

419432
# for backwards compatibility, check for old format
420433
if "all_entries" in dct:
421434
entries = [MontyDecoder().process_decoded(entry) for entry in dct["all_entries"]]
422435
else:
423436
entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["all_entries"]]
424437

425-
complete_qhull_entries = [computed_data["all_entries"][i] for i in computed_data["qhull_entries"]]
426-
427-
computed_data = computed_data | {"qhull_entries": complete_qhull_entries}
438+
# Reconstruct computed_data to match _compute() format: (str, Entry) tuples for el_refs
439+
computed_data = computed_data.copy()
440+
computed_data["qhull_entries"] = [entries[i] for i in computed_data["qhull_entries"]]
441+
computed_data["elements"] = [Element(el) for el in computed_data["elements"]]
442+
# Keep el_refs as (str, Entry) format to match _compute() output
443+
computed_data["el_refs"] = [
444+
(el_str, MontyDecoder().process_decoded(entry)) for el_str, entry in computed_data["el_refs"]
445+
]
428446

429447
return cls(entries, elements, computed_data=computed_data)
430448

@@ -497,8 +515,7 @@ def _compute(self) -> dict[str, Any]:
497515
"all_entries": all_entries,
498516
"qhull_data": qhull_data,
499517
"dim": dim,
500-
# Dictionary with Element keys is not JSON-serializable
501-
"el_refs": list(el_refs.items()),
518+
"el_refs": [(el.symbol, entry) for el, entry in el_refs.items()],
502519
"qhull_entries": qhull_entries,
503520
}
504521

@@ -1427,6 +1444,8 @@ def __init__(
14271444
):
14281445
"""Standard constructor for grand potential phase diagram.
14291446
1447+
TODO: update serialization here.
1448+
14301449
Args:
14311450
entries (Sequence[EntryLike]): A list of EntryLike objects having an
14321451
energy, energy_per_atom and composition.
@@ -1563,6 +1582,8 @@ def transform_entries(
15631582
Li3PO4 is mapped into a Li2O:1.5, P2O5:0.5 composition. The terminal
15641583
compositions are represented by DummySpecies.
15651584
1585+
TODO: update serialization here.
1586+
15661587
Args:
15671588
entries: Sequence of all input entries
15681589
terminal_compositions: Terminal compositions of phase space.
@@ -1732,7 +1753,7 @@ def _compute(
17321753
subspace_qhull_entry_indices = [pd.all_entries.index(entry) for entry in pd.qhull_entries]
17331754

17341755
# Convert el_refs to indices into subspace's all_entries
1735-
subspace_el_refs = [(el.symbol, pd.all_entries.index(entry)) for el, entry in pd.computed_data["el_refs"]]
1756+
subspace_el_refs = [(el, pd.all_entries.index(entry)) for el, entry in pd.computed_data["el_refs"]]
17361757

17371758
pds_computed_data[space] = {
17381759
"all_entries": subspace_all_entry_indices,
@@ -1921,7 +1942,7 @@ def as_dict(self) -> dict[str, Any]:
19211942
# Update computed_data with deduplicated entries and remapped indices
19221943
computed_data = self.computed_data.copy()
19231944

1924-
computed_data["elements"] = [e.as_dict() for e in self.elements]
1945+
computed_data["elements"] = [e.symbol for e in self.elements]
19251946

19261947
# Convert qhull_entries to indices if they're entry objects
19271948
qhull_entries_data = computed_data["qhull_entries"]
@@ -2043,7 +2064,7 @@ def as_dict(self) -> dict[str, Any]:
20432064
return {
20442065
"@module": type(self).__module__,
20452066
"@class": type(self).__name__,
2046-
"elements": [e.as_dict() for e in self.elements],
2067+
"elements": [e.symbol for e in self.elements],
20472068
"computed_data": computed_data | {"unique_entries": unique_entry_dicts},
20482069
}
20492070

@@ -2058,7 +2079,7 @@ def from_dict(cls, dct: dict) -> Self:
20582079
PatchedPhaseDiagram
20592080
"""
20602081
computed_data = dct.get("computed_data")
2061-
elements = [Element.from_dict(elem) for elem in dct["elements"]]
2082+
elements = [Element(elem) for elem in dct["elements"]]
20622083

20632084
# Handle new format with computed_data and unique_entries
20642085
if computed_data and "unique_entries" in computed_data:
@@ -2069,15 +2090,15 @@ def from_dict(cls, dct: dict) -> Self:
20692090
# Reconstruct computed_data with entry objects
20702091
computed_data_reconstructed = computed_data.copy()
20712092
computed_data_reconstructed["all_entries"] = all_entries
2072-
computed_data_reconstructed["elements"] = [Element.from_dict(e) for e in computed_data["elements"]]
2093+
computed_data_reconstructed["elements"] = [Element(elem) for elem in computed_data["elements"]]
20732094
computed_data_reconstructed["spaces"] = computed_data["spaces"]
20742095

20752096
# Reconstruct qhull_entries from indices
20762097
computed_data_reconstructed["qhull_entries"] = [all_entries[idx] for idx in computed_data["qhull_entries"]]
20772098

20782099
# Reconstruct el_refs from indices
20792100
computed_data_reconstructed["el_refs"] = [
2080-
(Element(el_symbol), all_entries[idx]) for el_symbol, idx in computed_data["el_refs"]
2101+
(Element(elem), all_entries[idx]) for elem, idx in computed_data["el_refs"]
20812102
]
20822103

20832104
# Reconstruct subspace PhaseDiagram computed_data

tests/analysis/test_phase_diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def test_as_from_dict(self):
836836
unique_entries = [MontyDecoder().process_decoded(entry) for entry in computed_data["unique_entries"]]
837837
reconstructed_entries = [unique_entries[idx] for idx in computed_data["all_entries"]]
838838
assert len(reconstructed_entries) == len(self.ppd.all_entries)
839-
assert ppd_dict["elements"] == [elem.as_dict() for elem in self.ppd.elements]
839+
assert ppd_dict["elements"] == [elem.symbol for elem in self.ppd.elements]
840840

841841
reconstructed_ppd = PatchedPhaseDiagram.from_dict(ppd_dict)
842842
reconstructed_dict = reconstructed_ppd.as_dict()

0 commit comments

Comments
 (0)