Skip to content

Commit 12fcd46

Browse files
committed
Fix systems count for GMTKN55
1 parent 460d59c commit 12fcd46

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

ml_peg/analysis/molecular/GMTKN55/analyse_GMTKN55.py

Lines changed: 10 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -59,13 +59,10 @@ def structure_info() -> dict[str, dict[str, float] | list | NDArray]:
5959
"systems": [],
6060
"excluded": [],
6161
"weights": {},
62-
"counts": {},
6362
}
6463
for model_name in MODELS:
6564
for subset in [dir.name for dir in sorted((CALC_PATH / model_name).glob("*"))]:
66-
count = 0
6765
for system_path in sorted((CALC_PATH / model_name / subset).glob("*.xyz")):
68-
count += 1
6966
structs = read(system_path, index=":")
7067
info["subsets"].append(subset)
7168

@@ -80,7 +77,6 @@ def structure_info() -> dict[str, dict[str, float] | list | NDArray]:
8077
)
8178
)
8279
info["weights"][subset] = structs[0].info["weight"]
83-
info["counts"][subset] = count
8480

8581
# Convert to numpy arrays for filtering
8682
info["categories"] = np.array(info["categories"])
@@ -219,25 +215,20 @@ def category_errors(
219215
all_categories = INFO["categories"]
220216
all_subsets = INFO["subsets"]
221217
all_weights = INFO["weights"]
222-
all_counts = INFO["counts"]
223218
excluded = INFO["excluded"]
224219

225220
# Filter excluded systems
226221
categories = all_categories[np.logical_not(excluded)]
227222

228223
for category in set(categories):
229-
# Filter non-excluded subsets in current category
230-
filtered_subsets = np.unique(
231-
all_subsets[np.logical_not(excluded)][categories == category]
224+
# Filter non-excluded subsets in current category and count systems
225+
filtered_subsets, counts = np.unique(
226+
all_subsets[np.logical_not(excluded)][categories == category],
227+
return_counts=True,
232228
)
233229

234-
# Get number of systems in each subset
235-
counts = np.array([all_counts[subset] for subset in filtered_subsets])
236-
237-
# Get error for each subset
230+
# Get error and weight for each subset
238231
errors = [subset_errors[model_name][subset] for subset in filtered_subsets]
239-
240-
# Get weight and count for each subset
241232
weights = np.array([all_weights[subset] for subset in filtered_subsets])
242233

243234
results[model_name][category] = np.sum(errors * weights * counts) / np.sum(
@@ -269,18 +260,16 @@ def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, floa
269260

270261
all_subsets = INFO["subsets"]
271262
all_weights = INFO["weights"]
272-
all_counts = INFO["counts"]
273263
excluded = INFO["excluded"]
274264

275-
# Filter all non-excluded subsets
276-
filtered_subsets = np.unique(all_subsets[np.logical_not(excluded)])
265+
# Filter all non-excluded subsets and count systems
266+
filtered_subsets, counts = np.unique(
267+
all_subsets[np.logical_not(excluded)], return_counts=True
268+
)
277269

278-
# Get error for each subset
270+
# Get error and weight for each subset
279271
errors = [subset_errors[model_name][subset] for subset in filtered_subsets]
280-
281-
# Get weight and count for each subset
282272
weights = np.array([all_weights[subset] for subset in filtered_subsets])
283-
counts = np.array([all_counts[subset] for subset in filtered_subsets])
284273

285274
results[model_name] = np.sum(errors * weights * counts) / np.sum(counts)
286275

0 commit comments

Comments (0)