diff --git a/mdn/__init__.py b/mdn/__init__.py
index 25719db..b568856 100644
--- a/mdn/__init__.py
+++ b/mdn/__init__.py
@@ -62,9 +62,9 @@ def non_trainable_weights(self):
     def call(self, x, mask=None):
         with tf.name_scope('MDN'):
             mdn_out = layers.concatenate([self.mdn_mus(x),
-                self.mdn_sigmas(x),
-                self.mdn_pi(x)],
-                name='mdn_outputs')
+                                          self.mdn_sigmas(x),
+                                          self.mdn_pi(x)],
+                                          name='mdn_outputs')
         return mdn_out
 
     def compute_output_shape(self, input_shape):
@@ -245,8 +245,8 @@ def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
     # m = np.random.choice(range(len(pis)), p=pis)
     mus_vector = mus[m * output_dim:(m + 1) * output_dim]
     sig_vector = sigs[m * output_dim:(m + 1) * output_dim]
-    scale_matrix = np.identity(output_dim) * sig_vector # scale matrix from diag
-    cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared.
-    cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature
+    scale_matrix = np.identity(output_dim) * sig_vector  # scale matrix from diag
+    cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared.
+    cov_matrix = cov_matrix * sigma_temp  # adjust for sigma temperature
     sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
     return sample
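
For context on the second hunk: sample_from_output picks one mixture component, builds a diagonal scale matrix from that component's sigmas, squares it to get the covariance, scales by the sigma temperature, and draws from the resulting multivariate normal. The following is a minimal standalone sketch of that step, assuming the flat parameter layout [mus..., sigmas..., pi_logits...] used by the surrounding code (not fully shown in this hunk); the function name sample_one, the plain softmax, and the dummy parameters are illustrative only, not the library's API.

    import numpy as np

    def sample_one(params, output_dim, num_mixes, sigma_temp=1.0):
        # Slice the flat parameter vector into means, sigmas, and mixture logits
        # (this layout mirrors sample_from_output; temperature on pi is omitted).
        mus = params[:num_mixes * output_dim]
        sigs = params[num_mixes * output_dim:2 * num_mixes * output_dim]
        pi_logits = params[-num_mixes:]
        pis = np.exp(pi_logits - np.max(pi_logits))
        pis = pis / pis.sum()
        m = np.random.choice(num_mixes, p=pis)  # pick a mixture component
        mus_vector = mus[m * output_dim:(m + 1) * output_dim]
        sig_vector = sigs[m * output_dim:(m + 1) * output_dim]
        scale_matrix = np.identity(output_dim) * sig_vector   # scale matrix from diag
        cov_matrix = np.matmul(scale_matrix, scale_matrix.T)  # cov is scale squared
        cov_matrix = cov_matrix * sigma_temp                  # adjust for sigma temperature
        return np.random.multivariate_normal(mus_vector, cov_matrix, 1)

    # Example: 2-D output, 3 mixtures, dummy parameters (zero means, unit sigmas).
    params = np.concatenate([np.zeros(6), np.ones(6), np.zeros(3)])
    print(sample_one(params, output_dim=2, num_mixes=3))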