Skip to content

Commit

Permalink
small style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Nov 4, 2019
1 parent be99ac7 commit 535f405
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mdn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def non_trainable_weights(self):
def call(self, x, mask=None):
with tf.name_scope('MDN'):
mdn_out = layers.concatenate([self.mdn_mus(x),
self.mdn_sigmas(x),
self.mdn_pi(x)],
name='mdn_outputs')
self.mdn_sigmas(x),
self.mdn_pi(x)],
name='mdn_outputs')
return mdn_out

def compute_output_shape(self, input_shape):
Expand Down Expand Up @@ -245,8 +245,8 @@ def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
# m = np.random.choice(range(len(pis)), p=pis)
mus_vector = mus[m * output_dim:(m + 1) * output_dim]
sig_vector = sigs[m * output_dim:(m + 1) * output_dim]
scale_matrix = np.identity(output_dim) * sig_vector # scale matrix from diag
cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared.
cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature
scale_matrix = np.identity(output_dim) * sig_vector # scale matrix from diag
cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared.
cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature
sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
return sample

0 comments on commit 535f405

Please sign in to comment.