-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmpc_controller.py
115 lines (86 loc) · 4.1 KB
/
mpc_controller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import time
from decimal import Decimal
from typing import List
from scipy.integrate import quad
from scipy.optimize import minimize
from .controller import ControlAction
from .objective.control_objective import ControlObjective
from .controller import Controller
from .problem.mpc_problem import MPCProblem
from ..model.model import Model
from ..simulation.world_state import WorldState
import numpy as np
def current_milli_time():
    """Return the current ``time.time_ns()`` reading as whole milliseconds."""
    nanoseconds = time.time_ns()
    # divmod floors exactly like ``//``; the sub-millisecond remainder is discarded.
    milliseconds, _ = divmod(nanoseconds, 1_000_000)
    return milliseconds
class TookTooLong(Exception):
    """Exception carrying a value ``x`` supplied by whoever raises it.

    Within this module it is raised by ``MinimizeStopper`` with the optimiser's
    current iterate, so the caller can recover the best point found so far.
    """

    def __init__(self, x):
        # ``args`` ends up as ``(x,)`` either way; the payload is also kept on ``.x``.
        super().__init__(x)
        self.x = x
class MinimizeStopper:
    """Callback object that aborts an optimisation once a time budget is exceeded.

    The budget starts counting when the instance is constructed. Elapsed time is
    measured with the monotonic clock, so system clock adjustments (NTP, manual
    changes) can neither trigger a spurious timeout nor suppress a real one.
    """

    def __init__(self, max_sec):
        """
        :param max_sec: time budget in seconds (fractions allowed).
        """
        self.max_milli = max_sec * 1000
        # Monotonic milliseconds; only meaningful as a difference, not as a date.
        self.start = time.monotonic_ns() // 1_000_000

    def __call__(self, xk):
        """Raise ``TookTooLong(xk)`` if the budget has been exceeded.

        :param xk: current iterate handed over by the optimiser.
        :raises TookTooLong: carrying ``xk`` once elapsed time exceeds the budget.
        """
        elapsed = time.monotonic_ns() // 1_000_000 - self.start
        if elapsed > self.max_milli:
            raise TookTooLong(xk)
class MPCController(Controller):
    """Controller that picks new MV values by numerically minimising ``cost_function``.

    Every step starts from the current values of the world's ``mvs`` (presumably the
    manipulated variables) and runs ``scipy.optimize.minimize`` under an iteration
    limit and a time budget. If the time budget runs out, the last iterate reported
    by the optimiser is used instead of a converged result.
    """

    def __init__(self, mpc_problem: MPCProblem, model: Model, fps: int = 1, *,
                 max_solve_seconds: float = 0.5, max_iterations: int = 15,
                 tolerance: float = 0.1):
        """
        :param mpc_problem: forwarded to the base ``Controller`` and to ``cost_function``.
        :param model: forwarded to ``cost_function``.
        :param fps: forwarded to the base ``Controller``.
        :param max_solve_seconds: time budget handed to ``MinimizeStopper``.
        :param max_iterations: ``maxiter`` option for ``scipy.optimize.minimize``.
        :param tolerance: ``tol`` argument for ``scipy.optimize.minimize``.
        """
        super().__init__(mpc_problem, fps)
        self.mpc_problem = mpc_problem
        self.model = model
        self.max_solve_seconds = max_solve_seconds
        self.max_iterations = max_iterations
        self.tolerance = tolerance

    def calculate_control_actions(self, time_delta: Decimal, latest_world: WorldState) -> List[ControlAction]:
        """Optimise the MVs of ``latest_world`` and return one action per MV.

        :param time_delta: not used by this implementation.
        :param latest_world: world state providing ``mvs`` and ``variables``.
        :return: a ``ControlAction(name, value)`` per MV, with ``Decimal`` values;
            an empty list when the world has no MVs.
        :raises KeyError: if an MV has no entry in ``latest_world.variables``.
        """
        logging.debug("================================= Controller starting step.")
        # The optimisation vector is positional. It must follow the order of
        # ``latest_world.mvs`` because both this method and ``cost_function`` zip the
        # vector against ``mvs``; following ``variables`` order could pair a value
        # with the wrong MV.
        mv_names = list(latest_world.mvs)
        initial_guess = [float(latest_world.variables[name]) for name in mv_names]
        logging.debug("Optimization initial guess %s", initial_guess)

        solution = []
        if mv_names:
            # min objective function
            try:
                returned = minimize(cost_function, np.array(initial_guess),
                                    args=(latest_world, self.model, self.mpc_problem),
                                    tol=self.tolerance,
                                    options={"maxiter": self.max_iterations},
                                    callback=MinimizeStopper(self.max_solve_seconds))  # todo specify constraints
                logging.debug("Optimization result %s", returned)
                solution = returned.x
            except TookTooLong as e:
                # Out of time: fall back to the last iterate seen by the callback.
                solution = e.x

        # ``ndarray.astype(Decimal)`` only produces an object array of Python floats,
        # so every element is converted explicitly (via ``str`` to keep the short repr).
        # TODO simulate, validating constraints are not violated
        new_mvs = {name: Decimal(str(float(value))) for name, value in zip(mv_names, solution)}
        logging.debug("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Controller step done.")
        return [ControlAction(k, v) for k, v in new_mvs.items()]
def cost_function(mv_values: np.ndarray, latest_world: WorldState, model: Model,
                  mpc_problem: MPCProblem):
    """Score a candidate MV vector; this is the objective handed to the optimiser.

    :param mv_values: candidate values, positionally matched to ``latest_world.mvs``.
    :param latest_world: world state the candidate assignment is applied to.
    :param model: passed through to ``evaluate_world_state``.
    :param mpc_problem: passed through to ``evaluate_world_state``.
    :return: whatever ``evaluate_world_state`` returns for the updated world.
    """
    # ``mv_values.astype(Decimal)`` would only yield Python floats in an object array,
    # so each element is converted to a real Decimal explicitly.
    new_mvs = {name: Decimal(str(float(value)))
               for name, value in zip(latest_world.mvs, mv_values)}
    logging.debug("\t ========= Cost function being computed for mvs: %s =========", new_mvs)
    # Create world state with same cvs, but different mvs
    updated_control = latest_world.apply_assignment(new_mvs)
    logging.debug("\tupdated world: \n %s", updated_control)
    value = evaluate_world_state(updated_control, model, mpc_problem)
    logging.debug("\tcost of world: %s", value)
    return value
def evaluate_world_state(world_state: WorldState, model: Model, mpc_problem: MPCProblem):
    """
    Evaluates a proposed world state, to see how close to objective we are. Higher is worse.

    For every entry of ``world_state.cvs`` the integrand ``f`` is integrated with
    ``scipy.integrate.quad`` over ``[0, mpc_problem.optimisation_horizon]``. Each
    integral is scaled by that CV's flag and weight, and the results are summed.
    CVs whose flag or weight is zero are skipped without integrating: they cannot
    contribute, and skipping keeps an infinite or NaN integral from turning the
    whole sum into NaN.

    :param world_state: world state to score; its ``cvs`` are iterated.
    :param model: passed through to ``f``.
    :param mpc_problem: supplies ``active_flags``, ``weights``, ``control_objectives``
        (all indexed by CV) and ``optimisation_horizon``.
    :return: the weighted sum as a float (``0.0`` when nothing contributes).
    """
    flags = mpc_problem.active_flags
    weights = mpc_problem.weights
    hz = mpc_problem.optimisation_horizon
    logging.debug("\tEvaluating world.")
    obj = 0.0
    for cv in world_state.cvs:
        logging.debug("\t\tCV: %s", cv)
        scale = float(int(flags[cv])) * float(weights[cv])
        if scale == 0.0:
            # Inactive or zero-weighted CV: avoid the costly numerical integration.
            logging.debug("\t\tSkipping CV with zero flag/weight.")
            continue
        control_objective = mpc_problem.control_objectives[cv]
        integration = quad(f, 0, hz, args=(control_objective, model, world_state))[0]
        logging.debug("\t\tIntegration value: %s", integration)
        obj += scale * float(integration)
    return obj
def f(t: float, control_objective: ControlObjective, model: Model, world_state: WorldState):
    """Integrand for the cost: squared objective distance ``t`` seconds ahead.

    :param t: time offset; converted to ``Decimal`` before being given to the model.
    :param control_objective: provides ``distance_until_satisfied``.
    :param model: provides ``progress(offset, world_state)``.
    :param world_state: starting world state for the prediction.
    :return: ``max(0, distance ** 2)`` for the predicted world.
    """
    offset = Decimal(t)
    future_world = model.progress(offset, world_state)
    gap = control_objective.distance_until_satisfied(future_world)
    logging.debug("\t\tPredicted world as a result has a distance of %s after %s seconds", gap, t)
    penalty = gap ** 2
    # NOTE(review): a square is never negative for real numbers, so this clamp looks
    # like a no-op; confirm whether ``max(0, distance) ** 2`` was the intent.
    return max(0, penalty)