pe.py
"""Defines the Policy Evaluation solver."""
from typing import Optional
from xaddpy.xadd.xadd import DeltaFunctionSubstitution
from pyRDDLGym_symbolic.mdp.policy import Policy
from pyRDDLGym_symbolic.solver.base import SymbolicSolver
class PolicyEvaluation(SymbolicSolver):
"""Policy Evaluation solver."""
def __init__(self, policy: Policy, *args, **kwargs):
super().__init__(*args, **kwargs)
self.policy = policy
self.reward: Optional[int] = None
self._embed_policy_to_cpfs()
def _embed_policy_to_cpf(self, cpf: int, policy: Policy) -> int:
"""Embeds the policy into a single CPF."""
cpf_ = cpf
var_set = self.context.collect_vars(cpf)
for a_var in policy.actions:
pi_a = policy[a_var] # pi(a|s).
# Boolean actions.
if a_var in self.mdp.bool_a_vars:
# The decision ID of the boolean variable node.
var_id = self.context._expr_to_id[a_var]
# Make leaf values numbers.
pi_a = self.context.unary_op(pi_a, 'float')
# Marginalize out the boolean action variable.
cpf_ = self.context.apply(cpf_, pi_a, op='prod')
restrict_high = self.context.op_out(cpf_, var_id, op='restrict_high')
restrict_low = self.context.op_out(cpf_, var_id, op='restrict_low')
cpf_ = self.context.apply(restrict_high, restrict_low, op='add')
else:
if a_var not in var_set:
continue
leaf_op = DeltaFunctionSubstitution(
a_var, cpf_, self.context, is_linear=self.mdp.is_linear)
cpf_ = self.context.reduce_process_xadd_leaf(pi_a, leaf_op, [], [])
return cpf_
def _embed_policy_to_cpfs(self):
"""Embeds the policy into the CPFs."""
cpfs = self.mdp.cpfs
policy = self.policy
# Update next state and interm variable CPFs.
for v, cpf in cpfs.items():
cpfs[v] = self._embed_policy_to_cpf(cpf, policy)
# Update the reward CPF.
self.reward = self._embed_policy_to_cpf(self.mdp.reward, policy)
def bellman_backup(self, dd: int) -> int:
"""Performs the policy evaluation Bellman backup.
Args:
dd: The current value function XADD ID to which Bellman back is applied.
Returns:
The new value function XADD ID.
"""
# Regress the value function.
regr = self.regress(dd, reward=self.reward, cpfs=self.mdp.cpfs)
return regr
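
A minimal usage sketch (not part of pe.py): the constructor embeds the policy into the CPFs and the reward once, after which repeated calls to bellman_backup implement iterative policy evaluation. The mdp and policy objects, the keyword arguments accepted by the SymbolicSolver base constructor, the starting zero value function, and the iteration count below are all assumptions for illustration, not the library's documented API.

# Illustrative sketch only. `mdp` and `policy` are assumed to be constructed
# elsewhere (e.g. from a parsed RDDL model); constructor kwargs are hypothetical.
from pyRDDLGym_symbolic.solver.pe import PolicyEvaluation

solver = PolicyEvaluation(policy=policy, mdp=mdp)

# Start from an (assumed) all-zero value function XADD node and apply the
# policy evaluation backup for a fixed number of iterations.
value_dd = zero_value_dd  # XADD node ID of the zero function; construction omitted.
for _ in range(10):
    value_dd = solver.bellman_backup(value_dd)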