@@ -108,7 +108,7 @@ def act(
         obs_dict: dict[AgentName, Observation],
         deterministic: bool = False,
         get_value: bool = False,
-        state_dict: dict[AgentName, tuple] = None,
+        state_dict: Optional[dict[AgentName, tuple]] = None,
     ):
         if len(obs_dict) == 0:
             return {}, {}, {}
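A note on this hunk: `param: dict[...] = None` is rejected by strict type checkers because `None` is not a valid `dict` value, and widening the annotation to `Optional[dict[...]]` is the standard fix. A minimal sketch of the pattern, with `AgentName` stubbed as a plain alias (the real project defines its own types):

```python
from typing import Optional

AgentName = str  # illustrative stand-in for the project's type alias

def act(
    obs_dict: dict[AgentName, object],
    deterministic: bool = False,
    get_value: bool = False,
    state_dict: Optional[dict[AgentName, tuple]] = None,
):
    # `Optional[X]` means `X | None`; without it, mypy (with its default
    # no-implicit-Optional behavior) flags the `= None` default.
    if state_dict is None:
        state_dict = {}  # treat a missing state dict as "no recurrent state"
    return {}, {}, {}
```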
@@ -174,7 +174,7 @@ def evaluate(
         self,
         obs_batch: dict[PolicyName, Observation],
         action_batch: dict[PolicyName, Action],
-        state: dict[PolicyName, tuple] = None,
+        state: Optional[dict[PolicyName, tuple]] = None,
     ) -> dict[PolicyName, Tuple[Tensor, Tensor, Tensor]]:

         obs = obs_batch[self.policy_name]
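The same widening applies to `evaluate`'s `state` parameter. A hedged sketch of the usual handling at the definition and call sites (`PolicyName` below is an illustrative stand-in, not the project's alias):

```python
from typing import Optional

PolicyName = str  # illustrative stand-in for the project's alias

def evaluate(state: Optional[dict[PolicyName, tuple]] = None) -> None:
    # Normalize once so the body never has to branch on None again.
    state = state if state is not None else {}

evaluate()                      # state omitted -> defaults to None -> {}
evaluate(state={"policy": ()})  # explicit per-policy recurrent state
```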
@@ -381,13 +381,13 @@ def value(
         family_actions, crowd_actions = split_dict(action_batch)

         family_obs, family_keys = pack(family_obs)
-        family_values = self.family_agent.value(family_obs)
+        family_values = self.family_agent.value(family_obs, ())

         augment_observations(crowd_obs, family_actions)

         crowd_obs, crowd_keys = pack(crowd_obs)

-        crowd_values = self.agent.value(crowd_obs)
+        crowd_values = self.agent.value(crowd_obs, ())

         crowd_values = unpack(crowd_values, crowd_keys)
         family_values = unpack(family_values, family_keys)
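The last hunk suggests the agents' `value` method now takes the recurrent state as a second positional argument, with an empty tuple `()` meaning "stateless". A minimal sketch of that calling convention; the `Agent` class below is an assumption for illustration, not the repository's implementation:

```python
import torch
from torch import Tensor

class Agent:
    """Illustrative feed-forward agent with a `value(obs, state)` head."""

    def __init__(self, obs_size: int):
        self.critic = torch.nn.Linear(obs_size, 1)

    def value(self, obs: Tensor, state: tuple) -> Tensor:
        # An empty tuple stands for "no recurrent state": a feed-forward
        # critic ignores it, while a recurrent critic would unpack hidden
        # tensors from it and thread an updated state back to the caller.
        return self.critic(obs).squeeze(-1)

obs = torch.randn(4, 8)           # a batch of 4 packed observations
values = Agent(8).value(obs, ())  # mirrors the updated call sites above
```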