Training conditional transformer #16

@radiradev

Description

Hello,

I am trying to understand these lines. Could you elaborate on how the transformer is trained here?

```python
# target includes all sequence elements (no need to handle first one
# differently because we are conditioning)
target = z_indices

# in the case we do not want to encode condition anyhow (e.g. inputs are features)
if isinstance(self.transformer, (GPTFeats, GPTClass, GPTFeatsClass)):
    # make the prediction
    logits, _, _ = self.transformer(z_indices[:, :-1], c)
    # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
    if isinstance(self.transformer, GPTFeatsClass):
        cond_size = c['feature'].size(-1) + c['target'].size(-1)
    else:
        cond_size = c.size(-1)
    logits = logits[:, cond_size-1:]
```

Using the features and all of the indices, what exactly are we trying to predict? Isn't the target the same `z_indices` that we are already feeding to the transformer? Or are we only predicting the last `z_index` given the features and the previous `z_indices`?
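For what it's worth, the index bookkeeping behind the `logits[:, cond_size-1:]` slice can be sketched with plain Python (toy sizes `cond_size=3`, `T=5` are my own assumptions, not the repo's real shapes): the transformer sees the condition followed by all-but-the-last index, each output position predicts the *next* token, and the slice keeps exactly one prediction per element of `z_indices`.

```python
# Index-alignment sketch (hypothetical toy sizes, not the actual model).
# Input to the transformer: [c_1..c_C, z_1..z_{T-1}] (teacher forcing).
# Output at position i predicts the token that follows input position i,
# so slicing from cond_size-1 leaves exactly T predictions, one per z.

cond_size = 3   # C: number of conditioning positions (assumed)
T = 5           # number of VQ indices z_1..z_T (assumed)

# positions fed to the transformer: condition + all-but-last z
inputs = [f"c{i}" for i in range(1, cond_size + 1)] \
       + [f"z{i}" for i in range(1, T)]

# output i is a distribution over the token following input i
predictions = [f"p(next | ..., {tok})" for tok in inputs]

# keep only the outputs that predict z tokens: logits[:, cond_size-1:]
kept = predictions[cond_size - 1:]

assert len(kept) == T  # one prediction per target index in z_indices
# kept[0] has seen only the condition, so it predicts z_1;
# kept[-1] has seen c and z_1..z_{T-1}, so it predicts z_T.
```

So the target really is all of `z_indices`: every position is predicted, each from the condition plus only the *earlier* indices, not just the last one.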
