
Commit 3d96dcd

nrontsis and kyr-pol authored
Update to GPflow v2 and include safety extensions (#45)
* No special subsequent calls to scipy optimiser
* Added safety-specific code in safe_pilco folder
* Added master_alg, experiments and post_process for more systematic evaluation (still under development)
* First 4 experiments run successfully
* Linear cars done too
* Syncing utils
* Cleaned main repo
* Re-enabled rendering in examples, deleted a few unnecessary things
* Running safe_cars when run directly, deleted bash script
* Updated requirements/minor fixes
* Updated mgpr to 2.0
* Updated smgpr, test_sparse_predictions passes
* Test cascade passes, updated pilco, rewards and controllers
* Fixed test_rewards
* Fixed controllers and test_controllers
* Fixed rbf controller bug. Mountain car and inverted pendulum run successfully with occasional numerical errors
* Got rid of the autoflow wrapper functions, which are not useful anymore
* Updated inverted pendulum and pendulum swingup examples. Fixed noise variance for RBF controllers. Modified mgpr priors
* Updated swimmer and double pendulum. Fix in rbf controllers.
* Minor fixes, ready to test all plain pilco envs
* Small updates to improve numerical stability and identify operations that could cause failures
* Updated safe pilco - linear cars seem to work fine now
* Cleaned up unnecessary comments and logging
* Updated requirements
* Adding matplotlib dependency, apparently required by Tensorflow; see Travis output on this branch before this commit
* Remove comments
* Various fixes
* Change coverage settings
* Attempt to clean up swimmer
* Remove safe pilco include
* Remove self.t from LinearController
* Move examples and utils to examples folder
* Remove utils import
* Updated imports in safe pilco examples; renamed safe-pilco-extension to safe_pilco_extension
* README update for the new version
* Update READMEs
* Update README

Co-authored-by: kyr-pol <[email protected]>
1 parent 7125590 commit 3d96dcd

29 files changed: 994 additions, 501 deletions
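A recurring theme across the diffs below is the move from TensorFlow 1/GPflow 1 idioms (sessions, autoflow wrappers, direct parameter assignment) to the GPflow 2 API. The following sketch condenses that pattern; it is illustrative only and not part of the commit. The random data and variable names are assumptions made just to keep the snippet self-contained, and it presumes the repository is installed as described in the README.

```python
# Illustrative sketch (not part of the commit): the GPflow 2 idioms used in the diffs below.
# The data here is random noise purely to make the snippet self-contained.
import numpy as np
from gpflow import set_trainable
from pilco.models import PILCO
from pilco.controllers import RbfController

state_dim, control_dim = 4, 1
X = np.random.randn(40, state_dim + control_dim)   # [state, action] pairs
Y = 0.1 * np.random.randn(40, state_dim)           # observed state differences

controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=5)

# v1 took two positional arrays, PILCO(X, Y, ...); v2 takes a single data tuple.
pilco = PILCO((X, Y), controller=controller, horizon=40)

# v1: model.likelihood.variance = 0.001; model.likelihood.variance.trainable = False
# v2: assign values and freeze parameters explicitly.
for model in pilco.mgpr.models:
    model.likelihood.variance.assign(0.001)
    set_trainable(model.likelihood.variance, False)

# v1: pilco.mgpr.set_XY(X, Y)  ->  v2: set_data with a tuple.
pilco.mgpr.set_data((X, Y))
```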

.coveragerc

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [report]
-omit = *tests*, *examples*, setup.py
+omit = *tests*, *examples*, setup.py, *safe-pilco-extension*
 exclude_lines =
     pragma: no cover
     raise AssertionError

README.md

Lines changed: 11 additions & 13 deletions
@@ -2,39 +2,37 @@
 [![Build Status](https://travis-ci.org/nrontsis/PILCO.svg?branch=master)](https://travis-ci.org/nrontsis/PILCO)
 [![codecov](https://codecov.io/gh/nrontsis/PILCO/branch/master/graph/badge.svg)](https://codecov.io/gh/nrontsis/PILCO)
 
-A modern \& clean implementation of the [PILCO](https://ieeexplore.ieee.org/abstract/document/6654139/) Algorithm in `TensorFlow`.
+A modern \& clean implementation of the [PILCO](https://ieeexplore.ieee.org/abstract/document/6654139/) Algorithm in `TensorFlow v2`.
 
 Unlike PILCO's [original implementation](http://mlg.eng.cam.ac.uk/pilco/) which was written as a self-contained package of `MATLAB`, this repository aims to provide a clean implementation by heavy use of modern machine learning libraries.
 
-In particular, we use `TensorFlow` to avoid the need for hardcoded gradients and scale to GPU architectures. Moreover, we use [`GPflow`](https://github.com/GPflow/GPflow) for Gaussian Process Regression.
+In particular, we use `TensorFlow v2` to avoid the need for hardcoded gradients and scale to GPU architectures. Moreover, we use [`GPflow v2`](https://github.com/GPflow/GPflow) for Gaussian Process Regression.
 
 The core functionality is tested against the original `MATLAB` implementation.
 
 ## Example of usage
-Before using, or installing, PILCO, you need to have `Tensorflow 1.13.1` installed (either the gpu or the cpu version). It is recommended to install everything in a fresh `conda` environment with `python>=3.7`. Given `Tensorflow`, PILCO can be installed as follows
+Before using `PILCO`, you have to install it by running:
 ```
 git clone https://github.com/nrontsis/PILCO && cd PILCO
 python setup.py develop
 ```
+It is recommended to install everything in a fresh conda environment with `python>=3.7`.
 
-The examples included in this repo use [`OpenAI gym 0.15.3`](https://github.com/openai/gym#installation) and [`mujoco-py 2.0.2.7`](https://github.com/openai/mujoco-py#install-mujoco). Once these dependencies are installed, you can run one of the examples as follows
+The examples included in this repo use [`OpenAI gym 0.15.3`](https://github.com/openai/gym#installation) and [`mujoco-py 2.0.2.7`](https://github.com/openai/mujoco-py#install-mujoco). These dependencies should be installed manually. Then, you can run one of the examples as follows
 ```
 python examples/inverted_pendulum.py
 ```
-While running an example, `Tensorflow` might print a lot of warnings, [some of which are deprecated](https://github.com/tensorflow/tensorflow/issues/25996). If necessary, you can suppress them by running
-```python
-tf.logging.set_verbosity(tf.logging.ERROR)
-```
-right after including `TensorFlow` in Python.
+
+## Example Extension: Safe PILCO
+As an example of the extensibility of the framework, we include in the folder `safe_pilco_extension` an extension of the standard PILCO algorithm that takes safety constraints (defined on the environment's state space) into account, as in [https://arxiv.org/abs/1712.05556](https://arxiv.org/pdf/1712.05556.pdf). The `safe_swimmer_run.py` and `safe_cars_run.py` in the `examples` folder demonstrate the use of this extension.
 
 ## Credits:
 
 The following people have been involved in the development of this package:
 * [Nikitas Rontsis](https://github.com/nrontsis)
-* [Kyriakos Polymenakos](https://github.com/kyr-pol)
+* [Kyriakos Polymenakos](https://github.com/kyr-pol/)
 
 ## References
 
-See the following publications for a description of the algorithm: [1](https://ieeexplore.ieee.org/abstract/document/6654139/), [2](http://mlg.eng.cam.ac.uk/pub/pdf/DeiRas11.pdf),
-[3](https://pdfs.semanticscholar.org/c9f2/1b84149991f4d547b3f0f625f710750ad8d9.pdf)
-
+See the following publications for a description of the algorithm: [1](https://ieeexplore.ieee.org/abstract/document/6654139/), [2](http://mlg.eng.cam.ac.uk/pub/pdf/DeiRas11.pdf),
+[3](https://pdfs.semanticscholar.org/c9f2/1b84149991f4d547b3f0f625f710750ad8d9.pdf)
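For orientation, here is a condensed version of the training loop that the README's usage section points to. It follows the updated `examples/inverted_pendulum.py` shown later in this commit and should be read as a sketch, not a definitive script: it assumes `gym`, `mujoco-py` and this repository are installed, and that it is run from the `examples` folder so that `utils.rollout` is importable.

```python
# Condensed sketch of the usage the README describes, following the updated
# examples/inverted_pendulum.py in this commit (run from the examples folder).
import numpy as np
import gym
from pilco.models import PILCO
from pilco.controllers import RbfController
from utils import rollout  # examples/utils.py in this repo

env = gym.make('InvertedPendulum-v2')

# rollout() now returns four values and accepts a render flag.
X, Y, _, _ = rollout(env=env, pilco=None, random=True, timesteps=40, render=False)

state_dim = Y.shape[1]
control_dim = X.shape[1] - state_dim
controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=10)
pilco = PILCO((X, Y), controller=controller, horizon=40)

for _ in range(3):
    pilco.optimize_models()
    pilco.optimize_policy()
    X_new, Y_new, _, _ = rollout(env=env, pilco=pilco, timesteps=100, render=False)
    X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new))
    pilco.mgpr.set_data((X, Y))   # keep the GP models in sync with the growing dataset
```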

examples/inv_double_pendulum.py

Lines changed: 29 additions & 31 deletions
@@ -5,8 +5,8 @@
 from pilco.controllers import RbfController, LinearController
 from pilco.rewards import ExponentialReward
 import tensorflow as tf
-from tensorflow import logging
 from utils import rollout, policy
+from gpflow import set_trainable
 np.random.seed(0)
 
 # Introduces a simple wrapper for the gym environment
@@ -29,7 +29,7 @@ def state_trans(self, s):
 
     def step(self, action):
         ob, r, done, _ = self.env.step(action)
-        if np.abs(ob[0])> 0.98 or np.abs(ob[-3]) > 0.1 or np.abs(ob[-2]) > 0.1 or np.abs(ob[-1]) > 0.1:
+        if np.abs(ob[0])> 0.90 or np.abs(ob[-3]) > 0.15 or np.abs(ob[-2]) > 0.15 or np.abs(ob[-1]) > 0.15:
             done = True
         return self.state_trans(ob), r, done, {}
 
@@ -41,32 +41,32 @@ def render(self):
         self.env.render()
 
 
-SUBS = 1
-bf = 40
-maxiter=80
-state_dim = 6
-control_dim = 1
-max_action=1.0 # actions for these environments are discrete
-target = np.zeros(state_dim)
-weights = 3.0 * np.eye(state_dim)
-weights[0,0] = 0.5
-weights[3,3] = 0.5
-m_init = np.zeros(state_dim)[None, :]
-S_init = 0.01 * np.eye(state_dim)
-T = 40
-J = 1
-N = 12
-T_sim = 130
-restarts=True
-lens = []
-
-with tf.Session() as sess:
+if __name__=='__main__':
+    SUBS = 1
+    bf = 40
+    maxiter=10
+    state_dim = 6
+    control_dim = 1
+    max_action=1.0 # actions for these environments are discrete
+    target = np.zeros(state_dim)
+    weights = 5.0 * np.eye(state_dim)
+    weights[0,0] = 1.0
+    weights[3,3] = 1.0
+    m_init = np.zeros(state_dim)[None, :]
+    S_init = 0.005 * np.eye(state_dim)
+    T = 40
+    J = 5
+    N = 12
+    T_sim = 130
+    restarts=True
+    lens = []
+
     env = DoublePendWrapper()
 
     # Initial random rollouts to generate a dataset
-    X,Y = rollout(env, None, timesteps=T, random=True, SUBS=SUBS)
+    X, Y, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, render=True)
     for i in range(1,J):
-        X_, Y_ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True)
+        X_, Y_, _, _ = rollout(env, None, timesteps=T, random=True, SUBS=SUBS, verbose=True, render=True)
         X = np.vstack((X, X_))
         Y = np.vstack((Y, Y_))
 
@@ -77,21 +77,19 @@ def render(self):
 
     R = ExponentialReward(state_dim=state_dim, t=target, W=weights)
 
-    pilco = PILCO(X, Y, controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init)
+    pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init)
 
     # for numerical stability
     for model in pilco.mgpr.models:
-        # model.kern.lengthscales.prior = gpflow.priors.Gamma(1,10) priors have to be included before
-        # model.kern.variance.prior = gpflow.priors.Gamma(1.5,2) before the model gets compiled
-        model.likelihood.variance = 0.001
-        model.likelihood.variance.trainable = False
+        model.likelihood.variance.assign(0.001)
+        set_trainable(model.likelihood.variance, False)
 
     for rollouts in range(N):
        print("**** ITERATION no", rollouts, " ****")
        pilco.optimize_models(maxiter=maxiter, restarts=2)
        pilco.optimize_policy(maxiter=maxiter, restarts=2)
 
-       X_new, Y_new = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS)
+       X_new, Y_new, _, _ = rollout(env, pilco, timesteps=T_sim, verbose=True, SUBS=SUBS, render=True)
 
        # Since we had decide on the various parameters of the reward function
        # we might want to verify that it behaves as expected by inspection
@@ -102,7 +100,7 @@ def render(self):
 
        # Update dataset
        X = np.vstack((X, X_new[:T, :])); Y = np.vstack((Y, Y_new[:T, :]))
-       pilco.mgpr.set_XY(X, Y)
+       pilco.mgpr.set_data((X, Y))
 
        lens.append(len(X_new))
        print(len(X_new))

examples/inverted_pendulum.py

Lines changed: 25 additions & 30 deletions
@@ -4,41 +4,36 @@
 from pilco.controllers import RbfController, LinearController
 from pilco.rewards import ExponentialReward
 import tensorflow as tf
-from tensorflow import logging
+from gpflow import set_trainable
+# from tensorflow import logging
 np.random.seed(0)
 
 from utils import rollout, policy
 
-with tf.Session(graph=tf.Graph()) as sess:
-    env = gym.make('InvertedPendulum-v2')
-    # Initial random rollouts to generate a dataset
-    X,Y = rollout(env=env, pilco=None, random=True, timesteps=40)
-    for i in range(1,3):
-        X_, Y_ = rollout(env=env, pilco=None, random=True, timesteps=40)
-        X = np.vstack((X, X_))
-        Y = np.vstack((Y, Y_))
+env = gym.make('InvertedPendulum-v2')
+# Initial random rollouts to generate a dataset
+X,Y, _, _ = rollout(env=env, pilco=None, random=True, timesteps=40, render=True)
+for i in range(1,5):
+    X_, Y_, _, _ = rollout(env=env, pilco=None, random=True, timesteps=40, render=True)
+    X = np.vstack((X, X_))
+    Y = np.vstack((Y, Y_))
 
 
-    state_dim = Y.shape[1]
-    control_dim = X.shape[1] - state_dim
-    controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=5)
-    #controller = LinearController(state_dim=state_dim, control_dim=control_dim)
+state_dim = Y.shape[1]
+control_dim = X.shape[1] - state_dim
+controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=10)
+# controller = LinearController(state_dim=state_dim, control_dim=control_dim)
 
-    pilco = PILCO(X, Y, controller=controller, horizon=40)
-    # Example of user provided reward function, setting a custom target state
-    # R = ExponentialReward(state_dim=state_dim, t=np.array([0.1,0,0,0]))
-    # pilco = PILCO(X, Y, controller=controller, horizon=40, reward=R)
+pilco = PILCO((X, Y), controller=controller, horizon=40)
+# Example of user provided reward function, setting a custom target state
+# R = ExponentialReward(state_dim=state_dim, t=np.array([0.1,0,0,0]))
+# pilco = PILCO(X, Y, controller=controller, horizon=40, reward=R)
 
-    # Example of fixing a parameter, optional, for a linear controller only
-    #pilco.controller.b = np.array([[0.0]])
-    #pilco.controller.b.trainable = False
-
-    for rollouts in range(3):
-        pilco.optimize_models()
-        pilco.optimize_policy()
-        import pdb; pdb.set_trace()
-        X_new, Y_new = rollout(env=env, pilco=pilco, timesteps=100)
-        print("No of ops:", len(tf.get_default_graph().get_operations()))
-        # Update dataset
-        X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new))
-        pilco.mgpr.set_XY(X, Y)
+for rollouts in range(3):
+    pilco.optimize_models()
+    pilco.optimize_policy()
+    import pdb; pdb.set_trace()
+    X_new, Y_new, _, _ = rollout(env=env, pilco=pilco, timesteps=100, render=True)
+    # Update dataset
+    X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new))
+    pilco.mgpr.set_data((X, Y))

examples/linear_cars_env.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import numpy as np
+from gym import spaces
+from gym.core import Env
+
+class LinearCars(Env):
+    def __init__(self):
+        self.action_space = spaces.Box(low=-0.4, high=0.4, shape=(1,))
+        self.observation_space = spaces.Box(low=-100, high=100, shape=(4,))
+        self.M = 1 # car mass [kg]
+        self.b = 0.001 # friction coef [N/m/s]
+        self.Dt = 0.50 # timestep [s]
+
+        self.A = np.array([[0, self.Dt, 0, 0],
+                           [0, -self.b*self.Dt/self.M, 0, 0],
+                           [0, 0, 0, self.Dt],
+                           [0, 0, 0, 0]])
+
+        self.B = np.array([0,self.Dt/self.M, 0, 0]).reshape((4,1))
+
+        self.initial_state = np.array([-6.0, 1.0, -5.0, 1.0]).reshape((4,1))
+
+    def step(self, action):
+        self.state += self.A @ self.state + self.B * action
+        #0.1 * np.random.normal(scale=[[1e-3], [1e-3], [1e-3], [0.001]], size=(4,1))
+
+        if self.state[0] < 0:
+            reward = -1
+        else:
+            reward = 1
+        return np.reshape(self.state[:], (4,)), reward, False, None
+
+    def reset(self):
+        self.state = self.initial_state + 0.03 * np.random.normal(size=(4,1))
+        return np.reshape(self.state[:], (4,))
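The new `LinearCars` environment above can be exercised on its own. A minimal sketch, assuming it is run from the `examples` folder so the module is importable:

```python
# Minimal sketch: drive the new LinearCars environment with random actions.
# Assumes this is run from the examples folder so linear_cars_env is importable.
from linear_cars_env import LinearCars

env = LinearCars()
state = env.reset()                         # initial_state plus small Gaussian noise
for _ in range(10):
    action = env.action_space.sample()      # force in [-0.4, 0.4]
    state, reward, done, _ = env.step(action)
    print(state, reward)                    # reward is +1 once the first car's position is non-negative, else -1
```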

examples/mountain_car.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import numpy as np
+import gym
+from pilco.models import PILCO
+from pilco.controllers import RbfController, LinearController
+from pilco.rewards import ExponentialReward
+import tensorflow as tf
+np.random.seed(0)
+from utils import policy, rollout, Normalised_Env
+
+
+SUBS = 5
+T = 25
+env = gym.make('MountainCarContinuous-v0')
+# Initial random rollouts to generate a dataset
+X1,Y1, _, _ = rollout(env=env, pilco=None, random=True, timesteps=T, SUBS=SUBS, render=True)
+for i in range(1,5):
+    X1_, Y1_,_,_ = rollout(env=env, pilco=None, random=True, timesteps=T, SUBS=SUBS, render=True)
+    X1 = np.vstack((X1, X1_))
+    Y1 = np.vstack((Y1, Y1_))
+env.close()
+
+env = Normalised_Env('MountainCarContinuous-v0', np.mean(X1[:,:2],0), np.std(X1[:,:2], 0))
+X = np.zeros(X1.shape)
+X[:, :2] = np.divide(X1[:, :2] - np.mean(X1[:,:2],0), np.std(X1[:,:2], 0))
+X[:, 2] = X1[:,-1] # control inputs are not normalised
+Y = np.divide(Y1 , np.std(X1[:,:2], 0))
+
+state_dim = Y.shape[1]
+control_dim = X.shape[1] - state_dim
+m_init = np.transpose(X[0,:-1,None])
+S_init = 0.5 * np.eye(state_dim)
+controller = RbfController(state_dim=state_dim, control_dim=control_dim, num_basis_functions=25)
+
+R = ExponentialReward(state_dim=state_dim,
+                      t=np.divide([0.5,0.0] - env.m, env.std),
+                      W=np.diag([0.5,0.1])
+                      )
+pilco = PILCO((X, Y), controller=controller, horizon=T, reward=R, m_init=m_init, S_init=S_init)
+
+best_r = 0
+all_Rs = np.zeros((X.shape[0], 1))
+for i in range(len(all_Rs)):
+    all_Rs[i,0] = R.compute_reward(X[i,None,:-1], 0.001 * np.eye(state_dim))[0]
+
+ep_rewards = np.zeros((len(X)//T,1))
+
+for i in range(len(ep_rewards)):
+    ep_rewards[i] = sum(all_Rs[i * T: i*T + T])
+
+for model in pilco.mgpr.models:
+    model.likelihood.variance.assign(0.05)
+    set_trainable(model.likelihood.variance, False)
+
+r_new = np.zeros((T, 1))
+for rollouts in range(5):
+    pilco.optimize_models()
+    pilco.optimize_policy(maxiter=100, restarts=3)
+    import pdb; pdb.set_trace()
+    X_new, Y_new,_,_ = rollout(env=env, pilco=pilco, timesteps=T, SUBS=SUBS, render=True)
+
+    for i in range(len(X_new)):
+        r_new[:, 0] = R.compute_reward(X_new[i,None,:-1], 0.001 * np.eye(state_dim))[0]
+    total_r = sum(r_new)
+    _, _, r = pilco.predict(m_init, S_init, T)
+
+    print("Total ", total_r, " Predicted: ", r)
+    X = np.vstack((X, X_new)); Y = np.vstack((Y, Y_new));
+    all_Rs = np.vstack((all_Rs, r_new)); ep_rewards = np.vstack((ep_rewards, np.reshape(total_r,(1,1))))
+    pilco.mgpr.set_data((X, Y))
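One observation on the file above, not a change contained in this commit: `mountain_car.py` calls `set_trainable` but, as shown, never imports it. The other updated examples in this commit use the GPflow 2 import below, which would presumably be needed here as well.

```python
# Presumably missing from examples/mountain_car.py as shown above:
# set_trainable() is used in the script but not imported.
# The other updated examples import it from GPflow 2 like this.
from gpflow import set_trainable
```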
