Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions libensemble/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,18 @@ def _prep_fields(self, results: npt.NDArray) -> npt.NDArray:
(name, results.dtype[name]) for name in results.dtype.names if name in self.gen_specs["persis_in"]
]

new_dtype = filtered_dtype + [("sim_ended", bool)]
# SH TODO - only add if not there and also is this not APOSMM specific?
# SH TODO - can remove as its added in aposmm persis_in (mirrored in gen_specs due to the _sync_gen_specs function)
# SH TODO - however as set to True below should only be APOSMM subclass.
# SH TODO - works in libE - but will need to check in Optimas - till then add if not there. But may not be needed.
# SH TODO - also test here if need the new dtype when using the merged gen_specs. *wrs now rply||

print(f'filtered_dtype: {filtered_dtype}')
if "sim_ended" not in [name for name, _ in filtered_dtype]:
new_dtype = filtered_dtype + [("sim_ended", bool)]
else:
new_dtype = filtered_dtype

new_results = np.zeros(len(results), dtype=new_dtype)

for field in new_results.dtype.names:
Expand All @@ -176,7 +187,7 @@ def _prep_fields(self, results: npt.NDArray) -> npt.NDArray:
except ValueError:
continue

new_results["sim_ended"] = True
new_results["sim_ended"] = True #SH TODO - APOSMM specific and only needed if was added here.
return new_results

def ingest(self, results: List[dict], tag: int = EVAL_GEN_TAG) -> None:
Expand All @@ -201,6 +212,8 @@ def ingest_numpy(self, results: npt.NDArray, tag: int = EVAL_GEN_TAG) -> None:
tag, np.copy(results)
) # SH for threads check - might need deepcopy due to dtype=object
else:
print(f'\n=======self.running_gen_f: {self.running_gen_f} type ({type(self.running_gen_f)})\n')
import pdb; pdb.set_trace()
self.running_gen_f.send(tag, None)

def finalize(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion libensemble/libE.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
from libensemble.tools.alloc_support import AllocSupport
from libensemble.tools.tools import _USER_SIM_ID_WARNING
from libensemble.utils import launcher
from libensemble.utils.misc import specs_dump
from libensemble.utils.misc import specs_dump, sync_gen_specs
from libensemble.utils.pydantic_bindings import libE_wrapper
from libensemble.utils.timer import Timer
from libensemble.version import __version__
Expand Down Expand Up @@ -241,6 +241,7 @@ def libE(
for spec in [ensemble.sim_specs, ensemble.gen_specs, ensemble.alloc_specs, ensemble.libE_specs]
]
exit_criteria = specs_dump(ensemble.exit_criteria, by_alias=True, exclude_none=True)
sync_gen_specs(gen_specs)

# Extract platform info from settings or environment
platform_info = get_platform(libE_specs)
Expand Down
14 changes: 7 additions & 7 deletions libensemble/tests/regression_tests/test_asktell_aposmm_nlopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,21 @@
max_active_runs=workflow.nworkers, # should this match nworkers always? practically?
)

# SH TODO - dont want this stuff duplicated
# SH initial_batch_size conflicts with initial_sample_size
# with server type gen, deciding outside seems correct way.
workflow.gen_specs = GenSpecs(
persis_in=["x", "x_on_cube", "sim_id", "local_min", "local_pt", "f"],
generator=aposmm,
batch_size=5,
initial_batch_size=10,
user={"initial_sample_size": 100},
batch_size=5, # SH what happens if not set - test this
initial_batch_size=10, # SH what happens if not set - does it get from gen? - test this
# persis_in=["x", "x_on_cube", "sim_id", "local_min", "local_pt", "f"],
# user={"initial_sample_size": 100},
)

workflow.libE_specs.gen_on_manager = True
workflow.add_random_streams()

H, _, _ = workflow.run()

# Perform the run
H, _, _ = workflow.run()

if workflow.is_manager:
print("[Manager]:", H[np.where(H["local_min"])]["x"])
Expand Down
42 changes: 42 additions & 0 deletions libensemble/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy.typing as npt


# SH TODO - some of these need docstrings.

def extract_H_ranges(Work: dict) -> str:
"""Convert received H_rows into ranges for labeling"""
work_H_rows = Work["libE_info"]["H_rows"]
Expand Down Expand Up @@ -272,3 +274,43 @@ def np_to_list_dicts(array: npt.NDArray, mapping: dict = {}, allow_arrays: bool
entry["_id"] = entry.pop("sim_id")

return out


def _merge_fields(a, b):
"""Merge dict or list fields from b into a."""
if isinstance(a, dict):
for k, v in b.items():
if k not in a:
a[k] = v
elif isinstance(a, list):
a.extend(x for x in b if x not in a)


def _find_generator(gen_specs):
"""Find the generator object in the gen_specs"""
from libensemble.generators import LibensembleGenerator
#SH TODO - import in here to avoid circular import. but i will move these to diff file in utils or
#SH TODO - do within the gens as a generator function.
#SH TMP TEST - dont think optimas sould need but lets see.
generator = gen_specs.get("generator") or gen_specs.get("user").get('generator')
if generator and not isinstance(generator, LibensembleGenerator):
if hasattr(generator, 'gen') and isinstance(generator.gen, LibensembleGenerator):
generator = generator.gen

return generator


def sync_gen_specs(gen_specs):
"""Automatically populate gen_specs with values from generator if available."""
generator = _find_generator(gen_specs)
if not generator or not hasattr(generator, "gen_specs"):
return

for field_name, field_value in generator.gen_specs.items():
if isinstance(field_value, (dict, list)) and field_value:
if field_name not in gen_specs:
gen_specs[field_name] = field_value.copy()
else:
_merge_fields(gen_specs[field_name], field_value)
elif field_name not in gen_specs:
gen_specs[field_name] = field_value