Skip to content

Commit 9f28cd8

Browse files
precommit
1 parent 25d4002 commit 9f28cd8

File tree

2 files changed

+98
-64
lines changed

2 files changed

+98
-64
lines changed

emmet-core/emmet/core/math.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
Vector3D = TypeVar("Vector3D", bound=tuple[float, float, float])
55
Vector3D.__doc__ = "Real space vector" # type: ignore
66

7-
Matrix3D = TypeVar("Matrix3D", bound = tuple[Vector3D, Vector3D, Vector3D])
7+
Matrix3D = TypeVar("Matrix3D", bound=tuple[Vector3D, Vector3D, Vector3D])
88
Matrix3D.__doc__ = "Real space Matrix" # type: ignore
99

1010
Vector6D = TypeVar("Vector6D", bound=tuple[float, float, float, float, float, float])
1111
Vector6D.__doc__ = "6D Voigt matrix component" # type: ignore
1212

13-
MatrixVoigt = TypeVar("MatrixVoigt",bound=tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D])
13+
MatrixVoigt = TypeVar(
14+
"MatrixVoigt",
15+
bound=tuple[Vector6D, Vector6D, Vector6D, Vector6D, Vector6D, Vector6D],
16+
)
1417
MatrixVoigt.__doc__ = "Voigt representation of a 3x3x3x3 tensor" # type: ignore
1518

16-
Tensor3R = TypeVar("Tensor3R",bound=list[list[list[float]]])
19+
Tensor3R = TypeVar("Tensor3R", bound=list[list[list[float]]])
1720
Tensor3R.__doc__ = "Generic tensor of rank 3" # type: ignore
1821

19-
Tensor4R = TypeVar("Tensor4R",bound=list[list[list[list[float]]]])
22+
Tensor4R = TypeVar("Tensor4R", bound=list[list[list[list[float]]]])
2023
Tensor4R.__doc__ = "Generic tensor of rank 4" # type: ignore
2124

22-
ListVector3D = TypeVar("ListVector3D",bound=list[float])
25+
ListVector3D = TypeVar("ListVector3D", bound=list[float])
2326
ListVector3D.__doc__ = "Real space vector as list" # type: ignore
2427

25-
ListMatrix3D = TypeVar("ListMatrix3D",bound=list[ListVector3D])
28+
ListMatrix3D = TypeVar("ListMatrix3D", bound=list[ListVector3D])
2629
ListMatrix3D.__doc__ = "Real space Matrix as list" # type: ignore

emmet-core/emmet/core/structure_replicas.py

+89-58
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,25 @@
1919
from typing import Any
2020
from typing_extensions import Self
2121

22+
2223
class EmmetReplica(BaseModel):
2324
"""Define strongly typed, fixed schema versions of generic pymatgen objects."""
2425

2526
@classmethod
26-
def from_pymatgen(cls, pmg_obj : Any) -> Self:
27+
def from_pymatgen(cls, pmg_obj: Any) -> Self:
2728
"""Convert pymatgen objects to an EmmetReplica representation."""
2829
raise NotImplementedError
2930

3031
def to_pymatgen(self) -> Any:
3132
"""Convert EmmetReplica object to pymatgen equivalent."""
3233
raise NotImplementedError
33-
34+
3435
@classmethod
35-
def from_dict(cls, dct : dict[str,Any]) -> Self:
36+
def from_dict(cls, dct: dict[str, Any]) -> Self:
3637
"""MSONable-like function to create this object from a dict."""
3738
raise NotImplementedError
3839

39-
def as_dict(self) -> dict[str,Any]:
40+
def as_dict(self) -> dict[str, Any]:
4041
"""MSONable-like function to create dict representation of this object."""
4142
raise NotImplementedError
4243

@@ -49,6 +50,7 @@ class SiteProperties(Enum):
4950
velocities = "velocities"
5051
selective_dynamics = "selective_dynamics"
5152

53+
5254
class ElementSymbol(Enum):
5355
"""Lightweight representation of a chemical element."""
5456

@@ -180,6 +182,7 @@ def __str__(self):
180182
"""Get element name."""
181183
return self.name
182184

185+
183186
class LightLattice(tuple):
184187
"""Low memory representation of a Lattice as a tuple of a 3x3 matrix."""
185188

@@ -188,7 +191,9 @@ def __new__(cls, matrix):
188191
lattice_matrix = np.array(matrix)
189192
if lattice_matrix.shape != (3, 3):
190193
raise ValueError("Lattice matrix must be 3x3.")
191-
return super(LightLattice,cls).__new__(cls,tuple([tuple(v) for v in lattice_matrix.tolist()]))
194+
return super(LightLattice, cls).__new__(
195+
cls, tuple([tuple(v) for v in lattice_matrix.tolist()])
196+
)
192197

193198
def as_dict(self) -> dict[str, list | str]:
194199
"""Define MSONable-like as_dict."""
@@ -211,7 +216,7 @@ def volume(self) -> float:
211216

212217
class ElementReplica(EmmetReplica):
213218
"""Define a flexible schema for elements and periodic sites.
214-
219+
215220
The only required field in this model is `element`.
216221
This is intended to mimic a `pymatgen` `.Element` object.
217222
Additionally, the `lattice` and coordinates of the site can be specified
@@ -239,43 +244,59 @@ class ElementReplica(EmmetReplica):
239244
was allowed to relax on.
240245
"""
241246

242-
element : ElementSymbol = Field(description="The element.")
243-
lattice : Matrix3D | None = Field(default = None, description="The lattice in 3x3 matrix form.")
244-
cart_coords : Vector3D | None = Field(default = None, description="The postion of the site in Cartesian coordinates.")
245-
frac_coords : Vector3D | None = Field(default = None, description="The postion of the site in direct lattice vector coordinates.")
246-
charge : float | None = Field(default = None, description="The on-site charge.")
247-
magmom : float | None = Field(default = None, description="The on-site magnetic moment.")
248-
velocities : Vector3D | None = Field(default = None, description="The Cartesian components of the site velocity.")
249-
selective_dynamics : tuple[bool, bool, bool] | None = Field(default = None, description="The degrees of freedom which are allowed to relax on the site.")
250-
251-
def model_post_init(self, __context : Any) -> None:
247+
element: ElementSymbol = Field(description="The element.")
248+
lattice: Matrix3D | None = Field(
249+
default=None, description="The lattice in 3x3 matrix form."
250+
)
251+
cart_coords: Vector3D | None = Field(
252+
default=None, description="The postion of the site in Cartesian coordinates."
253+
)
254+
frac_coords: Vector3D | None = Field(
255+
default=None,
256+
description="The postion of the site in direct lattice vector coordinates.",
257+
)
258+
charge: float | None = Field(default=None, description="The on-site charge.")
259+
magmom: float | None = Field(
260+
default=None, description="The on-site magnetic moment."
261+
)
262+
velocities: Vector3D | None = Field(
263+
default=None, description="The Cartesian components of the site velocity."
264+
)
265+
selective_dynamics: tuple[bool, bool, bool] | None = Field(
266+
default=None,
267+
description="The degrees of freedom which are allowed to relax on the site.",
268+
)
269+
270+
def model_post_init(self, __context: Any) -> None:
252271
"""Ensure both Cartesian and direct coordinates are set, if necessary."""
253272
if self.lattice:
254273
if self.cart_coords is not None:
255274
self.frac_coords = self.frac_coords or np.linalg.solve(
256-
np.array(self.lattice).T, np.array(self.cart_coords)
257-
)
275+
np.array(self.lattice).T, np.array(self.cart_coords)
276+
)
258277
elif self.frac_coords is not None:
259278
self.cart_coords = self.cart_coords or tuple(
260279
np.matmul(np.array(self.lattice).T, np.array(self.frac_coords))
261280
)
262-
281+
263282
@classmethod
264-
def from_pymatgen(cls, pmg_obj : Element | PeriodicSite) -> Self:
283+
def from_pymatgen(cls, pmg_obj: Element | PeriodicSite) -> Self:
265284
"""Convert a pymatgen .PeriodicSite or .Element to .ElementReplica.
266-
285+
267286
Parameters
268287
-----------
269288
site : pymatgen .Element or .PeriodicSite
270289
"""
271290
if isinstance(pmg_obj, Element):
272-
return cls(element = ElementSymbol(pmg_obj.name))
291+
return cls(element=ElementSymbol(pmg_obj.name))
273292

274293
return cls(
275-
element = ElementSymbol(next(iter(pmg_obj.species.remove_charges().as_dict()))),
276-
lattice = LightLattice(pmg_obj.lattice.matrix),
277-
frac_coords = pmg_obj.frac_coords,
278-
cart_coords = pmg_obj.coords,
294+
element=ElementSymbol(
295+
next(iter(pmg_obj.species.remove_charges().as_dict()))
296+
),
297+
lattice=LightLattice(pmg_obj.lattice.matrix),
298+
frac_coords=pmg_obj.frac_coords,
299+
cart_coords=pmg_obj.coords,
279300
)
280301

281302
def to_pymatgen(self) -> PeriodicSite:
@@ -285,20 +306,20 @@ def to_pymatgen(self) -> PeriodicSite:
285306
self.frac_coords,
286307
Lattice(self.lattice),
287308
coords_are_cartesian=False,
288-
properties = self.properties
309+
properties=self.properties,
289310
)
290311

291312
@property
292-
def species(self) -> dict[str,int]:
313+
def species(self) -> dict[str, int]:
293314
"""Composition-like representation of site."""
294-
return {self.element.name : 1}
315+
return {self.element.name: 1}
295316

296317
@property
297-
def properties(self) -> dict[str,float]:
318+
def properties(self) -> dict[str, float]:
298319
"""Aggregate optional properties defined on the site."""
299320
props = {}
300321
for k in SiteProperties.__members__:
301-
if (prop := getattr(self,k,None)) is not None:
322+
if (prop := getattr(self, k, None)) is not None:
302323
props[k] = prop
303324
return props
304325

@@ -324,7 +345,7 @@ def Z(self) -> int:
324345
def name(self) -> str:
325346
"""Ensure compatibility with PeriodicSite."""
326347
return self.element.name
327-
348+
328349
@property
329350
def species_string(self) -> str:
330351
"""Ensure compatibility with PeriodicSite."""
@@ -337,18 +358,18 @@ def label(self) -> str:
337358

338359
def __str__(self):
339360
return self.label
340-
361+
341362
def add_attrs(self, **kwargs) -> ElementReplica:
342363
"""Rapidly create a copy of this instance with additional fields set.
343-
364+
344365
Parameters
345366
-----------
346367
**kwargs
347368
Any of the fields defined in the model. This function is used to
348369
add lattice and coordinate information to each site, and thereby
349370
not store it in the StructureReplica object itself in addition to
350371
each site.
351-
372+
352373
Returns
353374
-----------
354375
ElementReplica
@@ -357,6 +378,7 @@ def add_attrs(self, **kwargs) -> ElementReplica:
357378
config.update(**kwargs)
358379
return ElementReplica(**config)
359380

381+
360382
class StructureReplica(BaseModel):
361383
"""Define a fixed schema structure.
362384
@@ -367,10 +389,10 @@ class StructureReplica(BaseModel):
367389
When the `.sites` attr of `StructureReplica` is accessed, all prior attributes
368390
(respective aliases: `lattice`, `frac_coords`, and `coords`) are assigned to the
369391
retrieved sites.
370-
Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
392+
Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
371393
and `cart_coords` both in the .Structure object and each .PeriodicSite within it.
372394
373-
395+
374396
Parameters
375397
-----------
376398
lattice : LightLattice
@@ -385,21 +407,25 @@ class StructureReplica(BaseModel):
385407
charge (optional) : float
386408
The total charge on the structure.
387409
"""
388-
389-
lattice : LightLattice = Field(description="The lattice in 3x3 matrix form.")
390-
species : list[ElementReplica] = Field(description="The elements in the structure.")
391-
frac_coords : ListMatrix3D = Field(description="The direct coordinates of the sites in the structure.")
392-
cart_coords : ListMatrix3D = Field(description="The Cartesian coordinates of the sites in the structure.")
393-
charge : float | None = Field(None, description="The net charge on the structure.")
410+
411+
lattice: LightLattice = Field(description="The lattice in 3x3 matrix form.")
412+
species: list[ElementReplica] = Field(description="The elements in the structure.")
413+
frac_coords: ListMatrix3D = Field(
414+
description="The direct coordinates of the sites in the structure."
415+
)
416+
cart_coords: ListMatrix3D = Field(
417+
description="The Cartesian coordinates of the sites in the structure."
418+
)
419+
charge: float | None = Field(None, description="The net charge on the structure.")
394420

395421
@property
396422
def sites(self) -> list[ElementReplica]:
397423
"""Return a list of sites in the structure with lattice and coordinate info."""
398424
return [
399425
species.add_attrs(
400-
lattice = self.lattice,
401-
cart_coords = self.cart_coords[idx],
402-
frac_coords = self.frac_coords[idx],
426+
lattice=self.lattice,
427+
cart_coords=self.cart_coords[idx],
428+
frac_coords=self.frac_coords[idx],
403429
)
404430
for idx, species in enumerate(self.species)
405431
]
@@ -431,7 +457,7 @@ def num_sites(self) -> int:
431457
@classmethod
432458
def from_pymatgen(cls, pmg_obj: Structure) -> Self:
433459
"""Create a StructureReplica from a pymatgen .Structure.
434-
460+
435461
Parameters
436462
-----------
437463
pmg_obj : pymatgen .Structure
@@ -444,41 +470,46 @@ def from_pymatgen(cls, pmg_obj: Structure) -> Self:
444470
raise ValueError(
445471
"Currently, `StructureReplica` is intended to represent only ordered materials."
446472
)
447-
473+
448474
lattice = LightLattice(pmg_obj.lattice.matrix)
449475
properties = [{} for _ in range(len(pmg_obj))]
450476
for idx, site in enumerate(pmg_obj):
451-
for k in ("charge","magmom","velocities","selective_dynamics"):
477+
for k in ("charge", "magmom", "velocities", "selective_dynamics"):
452478
if (prop := site.properties.get(k)) is not None:
453479
properties[idx][k] = prop
454480

455481
species = [
456482
ElementReplica(
457-
element = ElementSymbol[next(iter(site.species.remove_charges().as_dict()))],
458-
**properties[idx]
483+
element=ElementSymbol[
484+
next(iter(site.species.remove_charges().as_dict()))
485+
],
486+
**properties[idx],
459487
)
460488
for idx, site in enumerate(pmg_obj)
461489
]
462490

463491
return cls(
464492
lattice=lattice,
465-
species = species,
466-
frac_coords = [site.frac_coords for site in pmg_obj],
467-
cart_coords = [site.coords for site in pmg_obj],
468-
charge = pmg_obj.charge,
493+
species=species,
494+
frac_coords=[site.frac_coords for site in pmg_obj],
495+
cart_coords=[site.coords for site in pmg_obj],
496+
charge=pmg_obj.charge,
469497
)
470-
498+
471499
def to_pymatgen(self) -> Structure:
472500
"""Convert to a pymatgen .Structure."""
473-
return Structure.from_sites([site.to_periodic_site() for site in self], charge = self.charge)
474-
501+
return Structure.from_sites(
502+
[site.to_periodic_site() for site in self], charge=self.charge
503+
)
504+
475505
@classmethod
476506
def from_poscar(cls, poscar_path: str | Path) -> Self:
477507
"""Define convenience method to create a StructureReplica from a VASP POSCAR."""
478508
return cls.from_structure(Poscar.from_file(poscar_path).structure)
479509

480510
def __str__(self):
481511
"""Define format for printing a Structure."""
512+
482513
def _format_float(val: float | int) -> str:
483514
nspace = 2 if val >= 0.0 else 1
484515
return " " * nspace + f"{val:.8f}"

0 commit comments

Comments
 (0)