19
19
from typing import Any
20
20
from typing_extensions import Self
21
21
22
+
22
23
class EmmetReplica (BaseModel ):
23
24
"""Define strongly typed, fixed schema versions of generic pymatgen objects."""
24
25
25
26
@classmethod
26
- def from_pymatgen (cls , pmg_obj : Any ) -> Self :
27
+ def from_pymatgen (cls , pmg_obj : Any ) -> Self :
27
28
"""Convert pymatgen objects to an EmmetReplica representation."""
28
29
raise NotImplementedError
29
30
30
31
def to_pymatgen (self ) -> Any :
31
32
"""Convert EmmetReplica object to pymatgen equivalent."""
32
33
raise NotImplementedError
33
-
34
+
34
35
@classmethod
35
- def from_dict (cls , dct : dict [str ,Any ]) -> Self :
36
+ def from_dict (cls , dct : dict [str , Any ]) -> Self :
36
37
"""MSONable-like function to create this object from a dict."""
37
38
raise NotImplementedError
38
39
39
- def as_dict (self ) -> dict [str ,Any ]:
40
+ def as_dict (self ) -> dict [str , Any ]:
40
41
"""MSONable-like function to create dict representation of this object."""
41
42
raise NotImplementedError
42
43
@@ -49,6 +50,7 @@ class SiteProperties(Enum):
49
50
velocities = "velocities"
50
51
selective_dynamics = "selective_dynamics"
51
52
53
+
52
54
class ElementSymbol (Enum ):
53
55
"""Lightweight representation of a chemical element."""
54
56
@@ -180,6 +182,7 @@ def __str__(self):
180
182
"""Get element name."""
181
183
return self .name
182
184
185
+
183
186
class LightLattice (tuple ):
184
187
"""Low memory representation of a Lattice as a tuple of a 3x3 matrix."""
185
188
@@ -188,7 +191,9 @@ def __new__(cls, matrix):
188
191
lattice_matrix = np .array (matrix )
189
192
if lattice_matrix .shape != (3 , 3 ):
190
193
raise ValueError ("Lattice matrix must be 3x3." )
191
- return super (LightLattice ,cls ).__new__ (cls ,tuple ([tuple (v ) for v in lattice_matrix .tolist ()]))
194
+ return super (LightLattice , cls ).__new__ (
195
+ cls , tuple ([tuple (v ) for v in lattice_matrix .tolist ()])
196
+ )
192
197
193
198
def as_dict (self ) -> dict [str , list | str ]:
194
199
"""Define MSONable-like as_dict."""
@@ -211,7 +216,7 @@ def volume(self) -> float:
211
216
212
217
class ElementReplica (EmmetReplica ):
213
218
"""Define a flexible schema for elements and periodic sites.
214
-
219
+
215
220
The only required field in this model is `element`.
216
221
This is intended to mimic a `pymatgen` `.Element` object.
217
222
Additionally, the `lattice` and coordinates of the site can be specified
@@ -239,43 +244,59 @@ class ElementReplica(EmmetReplica):
239
244
was allowed to relax on.
240
245
"""
241
246
242
- element : ElementSymbol = Field (description = "The element." )
243
- lattice : Matrix3D | None = Field (default = None , description = "The lattice in 3x3 matrix form." )
244
- cart_coords : Vector3D | None = Field (default = None , description = "The postion of the site in Cartesian coordinates." )
245
- frac_coords : Vector3D | None = Field (default = None , description = "The postion of the site in direct lattice vector coordinates." )
246
- charge : float | None = Field (default = None , description = "The on-site charge." )
247
- magmom : float | None = Field (default = None , description = "The on-site magnetic moment." )
248
- velocities : Vector3D | None = Field (default = None , description = "The Cartesian components of the site velocity." )
249
- selective_dynamics : tuple [bool , bool , bool ] | None = Field (default = None , description = "The degrees of freedom which are allowed to relax on the site." )
250
-
251
- def model_post_init (self , __context : Any ) -> None :
247
+ element : ElementSymbol = Field (description = "The element." )
248
+ lattice : Matrix3D | None = Field (
249
+ default = None , description = "The lattice in 3x3 matrix form."
250
+ )
251
+ cart_coords : Vector3D | None = Field (
252
+ default = None , description = "The postion of the site in Cartesian coordinates."
253
+ )
254
+ frac_coords : Vector3D | None = Field (
255
+ default = None ,
256
+ description = "The postion of the site in direct lattice vector coordinates." ,
257
+ )
258
+ charge : float | None = Field (default = None , description = "The on-site charge." )
259
+ magmom : float | None = Field (
260
+ default = None , description = "The on-site magnetic moment."
261
+ )
262
+ velocities : Vector3D | None = Field (
263
+ default = None , description = "The Cartesian components of the site velocity."
264
+ )
265
+ selective_dynamics : tuple [bool , bool , bool ] | None = Field (
266
+ default = None ,
267
+ description = "The degrees of freedom which are allowed to relax on the site." ,
268
+ )
269
+
270
+ def model_post_init (self , __context : Any ) -> None :
252
271
"""Ensure both Cartesian and direct coordinates are set, if necessary."""
253
272
if self .lattice :
254
273
if self .cart_coords is not None :
255
274
self .frac_coords = self .frac_coords or np .linalg .solve (
256
- np .array (self .lattice ).T , np .array (self .cart_coords )
257
- )
275
+ np .array (self .lattice ).T , np .array (self .cart_coords )
276
+ )
258
277
elif self .frac_coords is not None :
259
278
self .cart_coords = self .cart_coords or tuple (
260
279
np .matmul (np .array (self .lattice ).T , np .array (self .frac_coords ))
261
280
)
262
-
281
+
263
282
@classmethod
264
- def from_pymatgen (cls , pmg_obj : Element | PeriodicSite ) -> Self :
283
+ def from_pymatgen (cls , pmg_obj : Element | PeriodicSite ) -> Self :
265
284
"""Convert a pymatgen .PeriodicSite or .Element to .ElementReplica.
266
-
285
+
267
286
Parameters
268
287
-----------
269
288
site : pymatgen .Element or .PeriodicSite
270
289
"""
271
290
if isinstance (pmg_obj , Element ):
272
- return cls (element = ElementSymbol (pmg_obj .name ))
291
+ return cls (element = ElementSymbol (pmg_obj .name ))
273
292
274
293
return cls (
275
- element = ElementSymbol (next (iter (pmg_obj .species .remove_charges ().as_dict ()))),
276
- lattice = LightLattice (pmg_obj .lattice .matrix ),
277
- frac_coords = pmg_obj .frac_coords ,
278
- cart_coords = pmg_obj .coords ,
294
+ element = ElementSymbol (
295
+ next (iter (pmg_obj .species .remove_charges ().as_dict ()))
296
+ ),
297
+ lattice = LightLattice (pmg_obj .lattice .matrix ),
298
+ frac_coords = pmg_obj .frac_coords ,
299
+ cart_coords = pmg_obj .coords ,
279
300
)
280
301
281
302
def to_pymatgen (self ) -> PeriodicSite :
@@ -285,20 +306,20 @@ def to_pymatgen(self) -> PeriodicSite:
285
306
self .frac_coords ,
286
307
Lattice (self .lattice ),
287
308
coords_are_cartesian = False ,
288
- properties = self .properties
309
+ properties = self .properties ,
289
310
)
290
311
291
312
@property
292
- def species (self ) -> dict [str ,int ]:
313
+ def species (self ) -> dict [str , int ]:
293
314
"""Composition-like representation of site."""
294
- return {self .element .name : 1 }
315
+ return {self .element .name : 1 }
295
316
296
317
@property
297
- def properties (self ) -> dict [str ,float ]:
318
+ def properties (self ) -> dict [str , float ]:
298
319
"""Aggregate optional properties defined on the site."""
299
320
props = {}
300
321
for k in SiteProperties .__members__ :
301
- if (prop := getattr (self ,k , None )) is not None :
322
+ if (prop := getattr (self , k , None )) is not None :
302
323
props [k ] = prop
303
324
return props
304
325
@@ -324,7 +345,7 @@ def Z(self) -> int:
324
345
def name (self ) -> str :
325
346
"""Ensure compatibility with PeriodicSite."""
326
347
return self .element .name
327
-
348
+
328
349
@property
329
350
def species_string (self ) -> str :
330
351
"""Ensure compatibility with PeriodicSite."""
@@ -337,18 +358,18 @@ def label(self) -> str:
337
358
338
359
def __str__ (self ):
339
360
return self .label
340
-
361
+
341
362
def add_attrs (self , ** kwargs ) -> ElementReplica :
342
363
"""Rapidly create a copy of this instance with additional fields set.
343
-
364
+
344
365
Parameters
345
366
-----------
346
367
**kwargs
347
368
Any of the fields defined in the model. This function is used to
348
369
add lattice and coordinate information to each site, and thereby
349
370
not store it in the StructureReplica object itself in addition to
350
371
each site.
351
-
372
+
352
373
Returns
353
374
-----------
354
375
ElementReplica
@@ -357,6 +378,7 @@ def add_attrs(self, **kwargs) -> ElementReplica:
357
378
config .update (** kwargs )
358
379
return ElementReplica (** config )
359
380
381
+
360
382
class StructureReplica (BaseModel ):
361
383
"""Define a fixed schema structure.
362
384
@@ -367,10 +389,10 @@ class StructureReplica(BaseModel):
367
389
When the `.sites` attr of `StructureReplica` is accessed, all prior attributes
368
390
(respective aliases: `lattice`, `frac_coords`, and `coords`) are assigned to the
369
391
retrieved sites.
370
- Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
392
+ Compare this to pymatgen's .Structure, which stores the `lattice`, `frac_coords`,
371
393
and `cart_coords` both in the .Structure object and each .PeriodicSite within it.
372
394
373
-
395
+
374
396
Parameters
375
397
-----------
376
398
lattice : LightLattice
@@ -385,21 +407,25 @@ class StructureReplica(BaseModel):
385
407
charge (optional) : float
386
408
The total charge on the structure.
387
409
"""
388
-
389
- lattice : LightLattice = Field (description = "The lattice in 3x3 matrix form." )
390
- species : list [ElementReplica ] = Field (description = "The elements in the structure." )
391
- frac_coords : ListMatrix3D = Field (description = "The direct coordinates of the sites in the structure." )
392
- cart_coords : ListMatrix3D = Field (description = "The Cartesian coordinates of the sites in the structure." )
393
- charge : float | None = Field (None , description = "The net charge on the structure." )
410
+
411
+ lattice : LightLattice = Field (description = "The lattice in 3x3 matrix form." )
412
+ species : list [ElementReplica ] = Field (description = "The elements in the structure." )
413
+ frac_coords : ListMatrix3D = Field (
414
+ description = "The direct coordinates of the sites in the structure."
415
+ )
416
+ cart_coords : ListMatrix3D = Field (
417
+ description = "The Cartesian coordinates of the sites in the structure."
418
+ )
419
+ charge : float | None = Field (None , description = "The net charge on the structure." )
394
420
395
421
@property
396
422
def sites (self ) -> list [ElementReplica ]:
397
423
"""Return a list of sites in the structure with lattice and coordinate info."""
398
424
return [
399
425
species .add_attrs (
400
- lattice = self .lattice ,
401
- cart_coords = self .cart_coords [idx ],
402
- frac_coords = self .frac_coords [idx ],
426
+ lattice = self .lattice ,
427
+ cart_coords = self .cart_coords [idx ],
428
+ frac_coords = self .frac_coords [idx ],
403
429
)
404
430
for idx , species in enumerate (self .species )
405
431
]
@@ -431,7 +457,7 @@ def num_sites(self) -> int:
431
457
@classmethod
432
458
def from_pymatgen (cls , pmg_obj : Structure ) -> Self :
433
459
"""Create a StructureReplica from a pymatgen .Structure.
434
-
460
+
435
461
Parameters
436
462
-----------
437
463
pmg_obj : pymatgen .Structure
@@ -444,41 +470,46 @@ def from_pymatgen(cls, pmg_obj: Structure) -> Self:
444
470
raise ValueError (
445
471
"Currently, `StructureReplica` is intended to represent only ordered materials."
446
472
)
447
-
473
+
448
474
lattice = LightLattice (pmg_obj .lattice .matrix )
449
475
properties = [{} for _ in range (len (pmg_obj ))]
450
476
for idx , site in enumerate (pmg_obj ):
451
- for k in ("charge" ,"magmom" ,"velocities" ,"selective_dynamics" ):
477
+ for k in ("charge" , "magmom" , "velocities" , "selective_dynamics" ):
452
478
if (prop := site .properties .get (k )) is not None :
453
479
properties [idx ][k ] = prop
454
480
455
481
species = [
456
482
ElementReplica (
457
- element = ElementSymbol [next (iter (site .species .remove_charges ().as_dict ()))],
458
- ** properties [idx ]
483
+ element = ElementSymbol [
484
+ next (iter (site .species .remove_charges ().as_dict ()))
485
+ ],
486
+ ** properties [idx ],
459
487
)
460
488
for idx , site in enumerate (pmg_obj )
461
489
]
462
490
463
491
return cls (
464
492
lattice = lattice ,
465
- species = species ,
466
- frac_coords = [site .frac_coords for site in pmg_obj ],
467
- cart_coords = [site .coords for site in pmg_obj ],
468
- charge = pmg_obj .charge ,
493
+ species = species ,
494
+ frac_coords = [site .frac_coords for site in pmg_obj ],
495
+ cart_coords = [site .coords for site in pmg_obj ],
496
+ charge = pmg_obj .charge ,
469
497
)
470
-
498
+
471
499
def to_pymatgen (self ) -> Structure :
472
500
"""Convert to a pymatgen .Structure."""
473
- return Structure .from_sites ([site .to_periodic_site () for site in self ], charge = self .charge )
474
-
501
+ return Structure .from_sites (
502
+ [site .to_periodic_site () for site in self ], charge = self .charge
503
+ )
504
+
475
505
@classmethod
476
506
def from_poscar (cls , poscar_path : str | Path ) -> Self :
477
507
"""Define convenience method to create a StructureReplica from a VASP POSCAR."""
478
508
return cls .from_structure (Poscar .from_file (poscar_path ).structure )
479
509
480
510
def __str__ (self ):
481
511
"""Define format for printing a Structure."""
512
+
482
513
def _format_float (val : float | int ) -> str :
483
514
nspace = 2 if val >= 0.0 else 1
484
515
return " " * nspace + f"{ val :.8f} "
0 commit comments