File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -43,6 +43,8 @@ def modify_timestep(
43
43
self , timestep : TimeStep
44
44
) -> TimeStep [Union [Observation , ObservationGlobalState ]]:
45
45
"""Modify the timestep for `step` and `reset`."""
46
+ metrics : Dict [str , Any ] = {"env_metrics" : {}}
47
+
46
48
obs_data = {
47
49
"agents_view" : timestep .observation .agent_obs ,
48
50
"action_mask" : self .action_mask ,
@@ -52,9 +54,8 @@ def modify_timestep(
52
54
global_state = jnp .concatenate (timestep .observation .agent_obs , axis = 0 )
53
55
global_state = jnp .tile (global_state , (self .num_agents , 1 ))
54
56
obs_data ["global_state" ] = global_state
55
- return timestep .replace (observation = ObservationGlobalState (** obs_data ))
57
+ return timestep .replace (observation = ObservationGlobalState (** obs_data ), extras = metrics )
56
58
57
- metrics : Dict [str , Any ] = {"env_metrics" : {}}
58
59
return timestep .replace (observation = Observation (** obs_data ), extras = metrics )
59
60
60
61
def reset (self , key : chex .PRNGKey ) -> Tuple [State , TimeStep ]:
You can’t perform that action at this time.
0 commit comments