Skip to content

Commit fb9835c

Browse files
committed
Add missing docstrings.
1 parent 331025c commit fb9835c

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

smarts/ray/sensors/ray_sensor_resolver.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_ray_worker_actors(self, count: int):
7171
if len(self._current_workers) != count:
7272
# we need to cache because using options(name) is extremely slow
7373
self._current_workers = [
74-
ProcessWorker.options(
74+
RayProcessWorker.options(
7575
name=f"sensor_worker_{i}", get_if_exists=True
7676
).remote()
7777
for i in range(count)
@@ -109,7 +109,7 @@ def observe(
109109
# Start remote tasks
110110
agent_ids_for_grouping = list(agent_ids)
111111
agent_groups = [
112-
agent_ids_for_grouping[i::len_workers] for i in range(len_workers)
112+
frozenset(agent_ids_for_grouping[i::len_workers]) for i in range(len_workers)
113113
]
114114
for i, agent_group in enumerate(agent_groups):
115115
if not agent_group:
@@ -162,14 +162,29 @@ def step(self, sim_frame, sensor_states):
162162

163163

164164
@ray.remote
165-
class ProcessWorker:
165+
class RayProcessWorker:
166+
"""A `ray` based process worker for parallel operation on sensors."""
166167
def __init__(self) -> None:
167168
self._simulation_local_constants: Optional[SimulationLocalConstants] = None
168169

169170
def update_local_constants(self, sim_local_constants):
171+
"""Updates the process worker.
172+
173+
Args:
174+
sim_local_constants (SimulationLocalConstants | None): The current simulation reset state.
175+
"""
170176
self._simulation_local_constants = loads(sim_local_constants)
171177

172178
def do_work(self, remote_sim_frame, agent_ids):
179+
"""Run the sensors against the current simulation state.
180+
181+
Args:
182+
remote_sim_frame (SimulationFrame): The current simulation state.
183+
agent_ids (set[str]): The agent ids to operate on.
184+
185+
Returns:
186+
tuple[dict, dict, dict]: The updated sensor states: (observations, dones, updated_sensors)
187+
"""
173188
sim_frame = loads(remote_sim_frame)
174189
return Sensors.observe_serializable_sensor_batch(
175190
sim_frame, self._simulation_local_constants, agent_ids

0 commit comments

Comments
 (0)