Skip to content

Commit b999e14

Browse files
authored
fix: matrax env metrics when using global state (#1174)
1 parent 003a0e0 commit b999e14

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

mava/wrappers/matrax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def modify_timestep(
4343
self, timestep: TimeStep
4444
) -> TimeStep[Union[Observation, ObservationGlobalState]]:
4545
"""Modify the timestep for `step` and `reset`."""
46+
metrics: Dict[str, Any] = {"env_metrics": {}}
47+
4648
obs_data = {
4749
"agents_view": timestep.observation.agent_obs,
4850
"action_mask": self.action_mask,
@@ -52,9 +54,8 @@ def modify_timestep(
5254
global_state = jnp.concatenate(timestep.observation.agent_obs, axis=0)
5355
global_state = jnp.tile(global_state, (self.num_agents, 1))
5456
obs_data["global_state"] = global_state
55-
return timestep.replace(observation=ObservationGlobalState(**obs_data))
57+
return timestep.replace(observation=ObservationGlobalState(**obs_data), extras=metrics)
5658

57-
metrics: Dict[str, Any] = {"env_metrics": {}}
5859
return timestep.replace(observation=Observation(**obs_data), extras=metrics)
5960

6061
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:

0 commit comments

Comments
 (0)