
Commit 8e3bacb: Type fixes
Parent: 72a7385

2 files changed: +18 additions, -10 deletions


coltra/groups.py (+4, -4)
```diff
@@ -108,7 +108,7 @@ def act(
         obs_dict: dict[AgentName, Observation],
         deterministic: bool = False,
         get_value: bool = False,
-        state_dict: dict[AgentName, tuple] = None,
+        state_dict: Optional[dict[AgentName, tuple]] = None,
     ):
         if len(obs_dict) == 0:
             return {}, {}, {}
@@ -174,7 +174,7 @@ def evaluate(
         self,
         obs_batch: dict[PolicyName, Observation],
         action_batch: dict[PolicyName, Action],
-        state: dict[PolicyName, tuple] = None,
+        state: Optional[dict[PolicyName, tuple]] = None,
     ) -> dict[PolicyName, Tuple[Tensor, Tensor, Tensor]]:
 
         obs = obs_batch[self.policy_name]
@@ -381,13 +381,13 @@ def value(
         family_actions, crowd_actions = split_dict(action_batch)
 
         family_obs, family_keys = pack(family_obs)
-        family_values = self.family_agent.value(family_obs)
+        family_values = self.family_agent.value(family_obs, ())
 
         augment_observations(crowd_obs, family_actions)
 
         crowd_obs, crowd_keys = pack(crowd_obs)
 
-        crowd_values = self.agent.value(crowd_obs)
+        crowd_values = self.agent.value(crowd_obs, ())
 
         crowd_values = unpack(crowd_values, crowd_keys)
         family_values = unpack(family_values, family_keys)
```
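The two annotation changes address implicit Optional: PEP 484 deprecated treating `arg: T = None` as shorthand for `Optional[T]`, and recent type checkers (e.g. mypy, which enables `no_implicit_optional` by default) flag it as an error. The `value(obs, ())` calls pass an empty state tuple, matching what appears to be the updated `value(obs, state)` signature. A minimal sketch of the annotation fix, with illustrative names rather than coltra's real classes:

```python
from typing import Optional

AgentName = str  # illustrative stand-in for coltra's AgentName alias


def act(state_dict: Optional[dict[AgentName, tuple]] = None) -> dict[AgentName, tuple]:
    # `state_dict: dict[AgentName, tuple] = None` claims a dict is required
    # while defaulting to None; Optional[...] makes the None default explicit.
    if state_dict is None:
        state_dict = {}
    return state_dict
```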

coltra/wrappers/agent_wrappers.py (+14, -6)
```diff
@@ -66,8 +66,12 @@ def act(
         norm_obs = self.normalize_obs(obs_batch)
         return self.agent.act(norm_obs, state_batch, deterministic, get_value, **kwargs)
 
-    def value(self, obs_batch: Observation, **kwargs) -> Tensor:
-        return self.agent.value(self.normalize_obs(obs_batch))
+    def value(
+        self, obs_batch: Observation, state_batch: tuple = (), **kwargs
+    ) -> tuple[Tensor, tuple]:
+        return self.agent.value(
+            self.normalize_obs(obs_batch), state_batch=state_batch, **kwargs
+        )
 
     def evaluate(
         self, obs_batch: Observation, action_batch: Action
@@ -113,9 +117,13 @@ def unnormalize_value(self, value: Tensor):
         return self._ret_var * value + self._ret_mean
 
     def value(
-        self, obs_batch: Observation, real_value: bool = False, **kwargs
-    ) -> Tensor:
-        value = self.agent.value(obs_batch)
+        self,
+        obs_batch: Observation,
+        state_batch: tuple = (),
+        real_value: bool = False,
+        **kwargs,
+    ) -> tuple[Tensor, tuple]:
+        value, state = self.agent.value(obs_batch, state_batch=state_batch, **kwargs)
         if real_value:
             value = self.unnormalize_value(value)
-        return value
+        return value, state
```
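The wrapper methods now thread `state_batch` through to the underlying agent and return the full `(value, state)` pair instead of a bare tensor, keeping wrapped and unwrapped agents call-compatible. A minimal sketch of this pass-through pattern, assuming a simplified stateless agent with the post-commit signature (names are illustrative, not coltra's actual API):

```python
import torch
from torch import Tensor


class Agent:
    def value(
        self, obs_batch: Tensor, state_batch: tuple = (), **kwargs
    ) -> tuple[Tensor, tuple]:
        # Post-commit contract (assumed): take a recurrent-state tuple
        # (empty for stateless agents) and return (values, new_state).
        return obs_batch.sum(dim=-1), state_batch


class ObsNormWrapper:
    def __init__(self, agent: Agent):
        self.agent = agent

    def normalize_obs(self, obs_batch: Tensor) -> Tensor:
        return (obs_batch - obs_batch.mean()) / (obs_batch.std() + 1e-8)

    def value(
        self, obs_batch: Tensor, state_batch: tuple = (), **kwargs
    ) -> tuple[Tensor, tuple]:
        # Forward the state and propagate (value, state) unchanged, so the
        # wrapper preserves the wrapped agent's signature.
        return self.agent.value(
            self.normalize_obs(obs_batch), state_batch=state_batch, **kwargs
        )


values, state = ObsNormWrapper(Agent()).value(torch.randn(4, 8))
```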
