-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathcommon.py
80 lines (62 loc) · 2.77 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import cPickle
import tensorflow as tf
import numpy as np
def save_params(fname, saver, session):
    """Checkpoint the session's variables through the given saver.

    Args:
        fname: checkpoint path (prefix) to save to.
        saver: object exposing save(session, fname) -- presumably a
            tf.train.Saver; confirm against callers.
        session: the TensorFlow session holding the variable values.

    Returns:
        Whatever saver.save returns (for tf.train.Saver this is the path
        the checkpoint was written to). The original discarded this value,
        so returning it is backward-compatible for existing callers.
    """
    return saver.save(session, fname)
def load_er(fname, batch_size, history_length, traj_length):
    """Load a pickled experience-replay buffer and attach batch statistics.

    Args:
        fname: path of the cPickle file holding the saved buffer.
        batch_size: batch size to stamp onto the loaded buffer.
        history_length: forwarded to set_er_stats (staging-buffer depth).
        traj_length: forwarded to set_er_stats (trajectory length).

    Returns:
        The loaded buffer object with batch_size set and min/max/mean/std
        statistics populated by set_er_stats.
    """
    # NOTE(review): unpickling executes arbitrary code -- only load trusted
    # files.
    # 'with' guarantees the handle is closed; the original used file() and
    # never closed it (resource leak).
    with open(fname, 'rb') as f:
        er = cPickle.load(f)
    er.batch_size = batch_size
    er = set_er_stats(er, history_length, traj_length)
    return er
def set_er_stats(er, history_length, traj_length):
    """Attach batch staging buffers and normalization statistics to a buffer.

    Allocates empty float32 arrays sized for er.batch_size, and computes
    per-dimension min/max/mean/std over the first er.count stored states
    and actions.

    Args:
        er: replay-buffer object with .states, .actions, .count, .batch_size.
        history_length: time depth of the pre/post state staging buffers.
        traj_length: number of states per trajectory (actions get one fewer).

    Returns:
        The same er object, mutated in place.
    """
    n = er.count
    state_dim = er.states.shape[-1]
    action_dim = er.actions.shape[-1]

    # Pre-allocated staging buffers for sampled batches.
    buffer_shapes = (
        ('prestates', (er.batch_size, history_length, state_dim)),
        ('poststates', (er.batch_size, history_length, state_dim)),
        ('traj_states', (er.batch_size, traj_length, state_dim)),
        ('traj_actions', (er.batch_size, traj_length - 1, action_dim)),
    )
    for attr, shape in buffer_shapes:
        setattr(er, attr, np.empty(shape, dtype=np.float32))

    # Statistics over the filled portion of the buffer only.
    states = er.states[:n]
    actions = er.actions[:n]
    er.states_min = np.min(states, axis=0)
    er.states_max = np.max(states, axis=0)
    er.actions_min = np.min(actions, axis=0)
    er.actions_max = np.max(actions, axis=0)
    er.states_mean = np.mean(states, axis=0)
    er.actions_mean = np.mean(actions, axis=0)
    er.states_std = np.std(states, axis=0)
    # Guard against divide-by-zero when a state dimension is constant.
    # NOTE(review): actions_std gets no such guard in the original --
    # preserved as-is.
    er.states_std[er.states_std == 0] = 1
    er.actions_std = np.std(actions, axis=0)
    return er
def re_parametrization(state_e, state_a):
    """Straight-through re-parametrization of state_a toward state_e.

    Computes a gradient-blocked residual nu = stop_gradient(state_e -
    state_a), so the returned sum equals state_e in the forward pass while
    gradients flow only through state_a.

    Returns:
        A tuple (state_a + nu, nu).
    """
    residual = tf.stop_gradient(state_e - state_a)
    return state_a + residual, residual
def normalize(x, mean, std):
    """Standardize x: subtract mean, then divide by std (elementwise)."""
    centered = x - mean
    return centered / std
def denormalize(x, mean, std):
    """Invert normalize(): scale x by std, then shift by mean."""
    scaled = x * std
    return scaled + mean
def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF trick -log(-log(U)) on U ~ Uniform(0, 1); eps
    keeps both log calls away from log(0).
    """
    uniform = tf.random_uniform(shape, minval=0, maxval=1)
    gumbel = -tf.log(-tf.log(uniform + eps) + eps)
    return gumbel
def gumbel_softmax_sample(logits, temperature):
    """Draw one soft sample from the Gumbel-Softmax distribution.

    Perturbs the logits with Gumbel(0, 1) noise and applies a
    temperature-scaled softmax; lower temperature sharpens the sample.
    """
    perturbed = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(perturbed / temperature)
def gumbel_softmax(logits, temperature, hard=True):
    """Sample from the Gumbel-Softmax distribution, optionally discretized.

    Args:
        logits: [batch_size, n_class] unnormalized log-probabilities.
        temperature: non-negative scalar; lower values give sharper samples.
        hard: if True, return a one-hot forward value but differentiate
            w.r.t. the underlying soft sample (straight-through estimator).

    Returns:
        [batch_size, n_class] tensor: one-hot when hard=True, otherwise a
        probability distribution summing to 1 across classes.
    """
    soft = gumbel_softmax_sample(logits, temperature)
    if not hard:
        return soft
    # Forward pass emits the discretized sample; stop_gradient routes the
    # backward pass through the soft sample only.
    # NOTE(review): equality-with-row-max can mark multiple entries on a
    # tie, unlike an argmax/one_hot formulation -- preserved from the
    # original.
    row_max = tf.reduce_max(soft, 1, keep_dims=True)
    discretized = tf.cast(tf.equal(soft, row_max), soft.dtype)
    return tf.stop_gradient(discretized - soft) + soft