Question on the code of masked cross entropy loss #11

@LuckyJinging


When I use the Player2Vec algorithm, I am confused by the masked cross-entropy loss. mask /= tf.reduce_sum(mask) already normalizes the mask to sum to 1, so multiplying the per-example loss by the mask and summing would give the average over the items whose mask is 1. Why does the function then take another global average with tf.reduce_mean(loss) instead of summing with tf.reduce_sum(loss)?

import tensorflow as tf


def masked_softmax_cross_entropy(preds: tf.Tensor, labels: tf.Tensor,
                                 mask: tf.Tensor) -> tf.Tensor:
    """
    Softmax cross-entropy loss with masking.
    :param preds: the last layer logits of the input data
    :param labels: the labels of the input data
    :param mask: the mask for train/val/test data
    """
    # Per-example cross-entropy computed from the raw logits.
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    # Normalize the mask so its entries sum to 1 (guarding against an all-zero mask).
    mask /= tf.maximum(tf.reduce_sum(mask), tf.constant([1.]))
    # Zero out unmasked examples and weight the remaining ones.
    loss *= mask
    return tf.reduce_mean(loss)
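
For context, here is a minimal numerical sketch of the difference the question points at, using hypothetical per-example losses and a toy mask (none of these values come from the repository):

import tensorflow as tf

# Toy setup: 4 examples, only the first 2 are masked in (hypothetical values).
per_example_loss = tf.constant([2.0, 4.0, 6.0, 8.0])
mask = tf.constant([1.0, 1.0, 0.0, 0.0])

norm_mask = mask / tf.reduce_sum(mask)      # [0.5, 0.5, 0.0, 0.0]
weighted = per_example_loss * norm_mask     # [1.0, 2.0, 0.0, 0.0]

print(tf.reduce_sum(weighted).numpy())      # 3.0  -> mean loss over the masked items
print(tf.reduce_mean(weighted).numpy())     # 0.75 -> the same mean divided by N = 4

So with the sum-normalized mask, tf.reduce_sum would return the masked mean directly, while tf.reduce_mean additionally divides by the total number of examples N, which is the extra averaging the question asks about.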
