diff --git a/examples/5_agent_zoo.py b/examples/5_agent_zoo.py index 786e5e694a..f47faf516a 100644 --- a/examples/5_agent_zoo.py +++ b/examples/5_agent_zoo.py @@ -5,10 +5,11 @@ import sys from pathlib import Path -from examples.tools.argument_parser import empty_parser - +# This may be necessary to get the repository root into path SMARTS_REPO_PATH = Path(__file__).parents[1].absolute() sys.path.insert(0, str(SMARTS_REPO_PATH)) + +from examples.tools.argument_parser import empty_parser from smarts.core.agent import Agent from smarts.core.agent_interface import AgentInterface, AgentType from smarts.zoo import registry @@ -30,19 +31,51 @@ def rla_entrypoint(max_episode_steps=1000): def main(): + name = "random_lane_control-v0" + print(f"=== Before registering `{name}` ===") + print(registry.agent_registry) registry.register( - "random_lane_control-v0", rla_entrypoint + name, rla_entrypoint ) # This registers "__main__:random_lane_control-v0" + print(f"=== After registering `{name}` ===") print(registry.agent_registry) - agent_spec = registry.make(locator="__main__:random_lane_control-v0") + agent_spec = registry.make(locator=f"__main__:{name}") agent_interface = agent_spec.interface agent = agent_spec.build_agent() # alternatively this will build the agent agent, agent_interface = registry.make_agent( - locator="__main__:random_lane_control-v0" + locator=f"__main__:{name}" + ) + # just "random_lane_control-v0" also works because the agent has already been registered in this file. + agent, agent_interface = registry.make_agent( + locator=name ) + locator = "zoo.policies:chase-via-points-agent-v0" + # Here is an example of using the module component of the locator to dynamically load agents: + agent, agent_interface = registry.make_agent( + locator=locator + ) + print(f"=== After loading `{locator}` ===") + print(registry.agent_registry) + + + ## This agent requires installation + # agent, agent_interface = registry.make_agent( + # locator="zoo.policies:discrete-soft-actor-critic-agent-v0" + # ) + + locator = "non_existing.module:md-v44" + try: + agent, agent_interface = registry.make_agent( + locator="non_existing.module:md-v44" + ) + except (ModuleNotFoundError, ImportError): + print(f"Such as with '{locator}'. Module resolution can fail if the module cannot be found " + "from the PYTHONPATH environment variable apparent as `sys.path` in python.") + + if __name__ == "__main__": parser = empty_parser(Path(__file__).stem) diff --git a/examples/6_experiment_base.py b/examples/6_experiment_base.py index b1c66bc9c4..61e1c016ea 100644 --- a/examples/6_experiment_base.py +++ b/examples/6_experiment_base.py @@ -84,7 +84,7 @@ class ExperimentCfg: config_name="experiment_default", version_base=None, ) -def main(experiment_config: ExperimentCfg) -> None: +def experiment_main(experiment_config: ExperimentCfg) -> None: typed_experiment_config: ExperimentCfg = OmegaConf.to_object(cfg=experiment_config) print(f"Loading configuration from `{CONFIG_LOCATION}`") if typed_experiment_config.show_config: @@ -147,4 +147,4 @@ def main(experiment_config: ExperimentCfg) -> None: if __name__ == "__main__": - main() + experiment_main() diff --git a/smarts/core/utils/class_factory.py b/smarts/core/utils/class_factory.py index 66d10fd48c..13bbab7a84 100644 --- a/smarts/core/utils/class_factory.py +++ b/smarts/core/utils/class_factory.py @@ -109,8 +109,7 @@ def find_factory(self, locator): self._raise_on_invalid_locator(locator) mod_name, _, name = locator.partition(":") - - if name is not None: + if name != "": # There is a module component. try: # Import the module so that the agent may register itself in the index @@ -157,7 +156,7 @@ def __repr__(self) -> str: out = "" for i, name in enumerate(self.index.keys()): out = f"{out}{name.ljust(max_justify)} " - if i % columns == 0: + if i % columns == 0 and len(self.index) != i + 1: out += "\n" out += "\n"