def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
        y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y
I have a question here: why is there a stop_gradient around y_hard - y?
y_hard comes from tf.equal, and as I understand it, gradients could be backpropagated through the equal function just as they are through the max function.
Am I wrong?
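
For context, here is a minimal pure-Python sketch of what the stop_gradient line appears to do (often called the straight-through trick). It is an illustration under assumptions, not TensorFlow code: the `Dual` class, `softmax2`, `hard_one_hot`, and this `stop_gradient` are hypothetical helpers that track a value together with its derivative with respect to a single logit, and the one-hot step is assumed to contribute zero gradient because a comparison's output is piecewise constant.

```python
import math


class Dual:
    """A value paired with its derivative w.r.t. one chosen input
    (a tiny forward-mode autodiff stand-in, hypothetical helper)."""

    def __init__(self, val, grad=0.0):
        self.val = val
        self.grad = grad

    def __add__(self, other):
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.grad - other.grad)


def stop_gradient(x):
    """Keep the forward value, drop the derivative."""
    return Dual(x.val, 0.0)


def softmax2(a, b, temperature):
    """Softmax over two logits; derivative is tracked w.r.t. logit a only."""
    ea = math.exp(a / temperature)
    eb = math.exp(b / temperature)
    p = ea / (ea + eb)
    dp = p * (1.0 - p) / temperature  # d p_a / d a
    return [Dual(p, dp), Dual(1.0 - p, -dp)]


def hard_one_hot(y):
    """Mimics cast(equal(y, max(y))): assumed to carry zero derivative."""
    m = max(d.val for d in y)
    return [Dual(1.0 if d.val == m else 0.0, 0.0) for d in y]


y = softmax2(1.0, 0.0, temperature=1.0)  # soft sample
y_hard = hard_one_hot(y)                 # one-hot, no gradient

# With stop_gradient: forward value is y_hard, derivative is that of y.
y_st = [stop_gradient(h - s) + s for h, s in zip(y_hard, y)]

# Without stop_gradient: the -y and +y derivatives cancel, leaving zero.
y_naive = [(h - s) + s for h, s in zip(y_hard, y)]

for name, out in (("straight-through", y_st), ("naive", y_naive)):
    print(name, [(round(d.val, 6), round(d.grad, 6)) for d in out])
```

In this sketch both variants output the one-hot vector in the forward pass, but only the stop_gradient variant keeps the soft sample's derivative. Without it, the subtraction and addition of y cancel and nothing would flow back to the logits.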