@@ -59,13 +59,10 @@ def structure_info() -> dict[str, dict[str, float] | list | NDArray]:
5959 "systems" : [],
6060 "excluded" : [],
6161 "weights" : {},
62- "counts" : {},
6362 }
6463 for model_name in MODELS :
6564 for subset in [dir .name for dir in sorted ((CALC_PATH / model_name ).glob ("*" ))]:
66- count = 0
6765 for system_path in sorted ((CALC_PATH / model_name / subset ).glob ("*.xyz" )):
68- count += 1
6966 structs = read (system_path , index = ":" )
7067 info ["subsets" ].append (subset )
7168
@@ -80,7 +77,6 @@ def structure_info() -> dict[str, dict[str, float] | list | NDArray]:
8077 )
8178 )
8279 info ["weights" ][subset ] = structs [0 ].info ["weight" ]
83- info ["counts" ][subset ] = count
8480
8581 # Convert to numpy arrays for filtering
8682 info ["categories" ] = np .array (info ["categories" ])
@@ -219,25 +215,20 @@ def category_errors(
219215 all_categories = INFO ["categories" ]
220216 all_subsets = INFO ["subsets" ]
221217 all_weights = INFO ["weights" ]
222- all_counts = INFO ["counts" ]
223218 excluded = INFO ["excluded" ]
224219
225220 # Filter excluded systems
226221 categories = all_categories [np .logical_not (excluded )]
227222
228223 for category in set (categories ):
229- # Filter non-excluded subsets in current category
230- filtered_subsets = np .unique (
231- all_subsets [np .logical_not (excluded )][categories == category ]
224+ # Filter non-excluded subsets in current category and count systems
225+ filtered_subsets , counts = np .unique (
226+ all_subsets [np .logical_not (excluded )][categories == category ],
227+ return_counts = True ,
232228 )
233229
234- # Get number of systems in each subset
235- counts = np .array ([all_counts [subset ] for subset in filtered_subsets ])
236-
237- # Get error for each subset
230+ # Get error and weight for each subset
238231 errors = [subset_errors [model_name ][subset ] for subset in filtered_subsets ]
239-
240- # Get weight and count for each subset
241232 weights = np .array ([all_weights [subset ] for subset in filtered_subsets ])
242233
243234 results [model_name ][category ] = np .sum (errors * weights * counts ) / np .sum (
@@ -269,18 +260,16 @@ def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, floa
269260
270261 all_subsets = INFO ["subsets" ]
271262 all_weights = INFO ["weights" ]
272- all_counts = INFO ["counts" ]
273263 excluded = INFO ["excluded" ]
274264
275- # Filter all non-excluded subsets
276- filtered_subsets = np .unique (all_subsets [np .logical_not (excluded )])
265+ # Filter all non-excluded subsets and count systems
266+ filtered_subsets , counts = np .unique (
267+ all_subsets [np .logical_not (excluded )], return_counts = True
268+ )
277269
278- # Get error for each subset
270+ # Get error and weight for each subset
279271 errors = [subset_errors [model_name ][subset ] for subset in filtered_subsets ]
280-
281- # Get weight and count for each subset
282272 weights = np .array ([all_weights [subset ] for subset in filtered_subsets ])
283- counts = np .array ([all_counts [subset ] for subset in filtered_subsets ])
284273
285274 results [model_name ] = np .sum (errors * weights * counts ) / np .sum (counts )
286275
0 commit comments