@@ -181,10 +181,17 @@ class MeltingTemperature(BaseModel, title="Input options for melting temperature
181181 step : Annotated [int , Field (default = 200 , ge = 20 )]
182182 attempts : Annotated [int , Field (default = 5 , ge = 1 )]
183183
184- class MaterialsProject (BaseModel , title = 'Input options for materials project' ):
184+
185+ class MaterialsProject (BaseModel , title = "Input options for materials project" ):
185186 api_key : Annotated [str , Field (default = "" , exclude = True )]
186187 conventional : Annotated [bool , Field (default = True )]
187- target_natoms : Annotated [int , Field (default = 1500 , description = 'The structure parsed from materials project would be repeated to approximately this value' )]
188+ target_natoms : Annotated [
189+ int ,
190+ Field (
191+ default = 1500 ,
192+ description = "The structure parsed from materials project would be repeated to approximately this value" ,
193+ ),
194+ ]
188195
189196 @field_validator ("api_key" , mode = "after" )
190197 def resolve_api_key (cls , v : str ) -> str :
@@ -198,6 +205,7 @@ def resolve_api_key(cls, v: str) -> str:
198205 )
199206 return value
200207
208+
201209class Calculation (BaseModel , title = "Main input class" ):
202210 monte_carlo : Optional [MonteCarlo ] = MonteCarlo ()
203211 composition_scaling : Optional [CompositionScaling ] = CompositionScaling ()
@@ -515,40 +523,58 @@ def _validate_all(self) -> "Input":
515523 self ._original_lattice = self .lattice .lower ()
516524 write_structure_file = True
517525
518- elif self .lattice .split ('-' )[0 ] == 'mp' :
519- #confirm here that API key exists
526+ elif self .lattice .split ("-" )[0 ] == "mp" :
527+ # confirm here that API key exists
520528 if not self .materials_project .api_key :
521- raise ValueError (' could not find API KEY, pls set it.' )
522- #now we need to fetch the structure
529+ raise ValueError (" could not find API KEY, pls set it." )
530+ # now we need to fetch the structure
523531 try :
524532 from mp_api .client import MPRester
525533 except ImportError :
526- raise ImportError ('Could not import mp_api, make sure you install mp_api package!' )
527- #now all good
534+ raise ImportError (
535+ "Could not import mp_api, make sure you install mp_api package!"
536+ )
537+ # now all good
528538 rest = {
529- "use_document_model" : False ,
530- "include_user_agent" : True ,
531- "api_key" : self .materials_project .api_key ,
532- }
539+ "use_document_model" : False ,
540+ "include_user_agent" : True ,
541+ "api_key" : self .materials_project .api_key ,
542+ }
533543 with MPRester (** rest ) as mpr :
534544 docs = mpr .materials .summary .search (material_ids = [self .lattice ])
535545
536546 structures = []
537547 for doc in docs :
538- struct = doc [' structure' ]
548+ struct = doc [" structure" ]
539549 if self .materials_project .conventional :
540550 aseatoms = struct .to_conventional ().to_ase_atoms ()
541551 else :
542552 aseatoms = struct .to_primitive ().to_ase_atoms ()
543553 structures .append (aseatoms )
544554 structure = structures [0 ]
545-
555+
546556 if np .prod (self .repeat ) == 1 :
547- x = int (np .ceil ((self .materials_project .target_natoms / len (structure ))** (1 / 3 )))
557+ x = int (
558+ np .ceil (
559+ (self .materials_project .target_natoms / len (structure ))
560+ ** (1 / 3 )
561+ )
562+ )
548563 structure = structure .repeat (x )
549564 else :
550565 structure = structure .repeat (self .repeat )
551566
567+ # extract composition
568+ types , typecounts = np .unique (
569+ structure .get_chemical_symbols (), return_counts = True
570+ )
571+
572+ for c , t in enumerate (types ):
573+ self ._element_dict [t ]["count" ] = typecounts [c ]
574+ self ._element_dict [t ]["composition" ] = typecounts [c ] / np .sum (
575+ typecounts
576+ )
577+
552578 self ._natoms = len (structure )
553579 self ._original_lattice = self .lattice .lower ()
554580 write_structure_file = True
0 commit comments