SimCTG BART training #27

@rahulseetharaman

Hi @yxuansu, thanks for the wonderful library.

I am trying to use the SimCTG framework to train a BART model for a question generation task. Training with the SimCTG loss fails with the following error:

  File "experiment-5/simctg_train.py", line 224, in <module>
    train()
  File "experiment-5/simctg_train.py", line 122, in train
    mle_loss, cl_loss = simctgloss(last_hidden_states=last_hidden_states, logits=logits,
  File "/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "simctg/lossfunction.py", line 93, in forward
    assert labels.size() == input_ids.size()
AssertionError

Looking at the loss function, I can see why this happens. Is the loss function designed to support only decoder-only models such as GPT? How can it be adapted for BART and T5? For encoder-decoder models like these, input_ids feeds the encoder and labels are the decoder targets, so the assertion that the two have the same size need not hold.
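To make the question concrete, here is a minimal sketch of what a decoder-side variant could look like. It is not the library's implementation: `seq2seq_simctg_loss` and its parameters are hypothetical names, the averaging over token pairs is my own choice and may differ from what simctg does, and it assumes the labels mark ignored positions with a single `ignore_index` value. The idea is to compute both the MLE term and the token-level contrastive term from the decoder hidden states and the labels only, so the encoder `input_ids` never enter the shape check.

```python
import torch
import torch.nn.functional as F


def seq2seq_simctg_loss(decoder_hidden, logits, labels, ignore_index, margin=0.5):
    """Hypothetical helper (not part of simctg): MLE loss plus a token-level
    contrastive loss, both computed on the decoder side only."""
    bsz, tgt_len, vocab = logits.shape
    # Only decoder-side tensors have to agree; the encoder input length is free.
    assert labels.shape == (bsz, tgt_len)
    assert decoder_hidden.shape[:2] == (bsz, tgt_len)

    # MLE term: cross-entropy over target tokens, skipping ignored positions.
    mle_loss = F.cross_entropy(
        logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=ignore_index
    )

    # Cosine similarity between every pair of decoder positions.
    normed = F.normalize(decoder_hidden, dim=-1)
    scores = torch.matmul(normed, normed.transpose(1, 2))  # (bsz, tgt_len, tgt_len)

    # Hinge term per pair: max(0, margin - (s(i, i) - s(i, j))).
    gold = torch.diagonal(scores, dim1=1, dim2=2).unsqueeze(-1)  # (bsz, tgt_len, 1)
    hinge = F.relu(margin - (gold - scores))

    # Keep only pairs where both positions are real tokens and i != j.
    valid = (labels != ignore_index).float()
    pair_mask = valid.unsqueeze(2) * valid.unsqueeze(1)
    pair_mask = pair_mask * (1.0 - torch.eye(tgt_len, device=labels.device))

    # Average over valid pairs (an assumption of this sketch).
    cl_loss = (hinge * pair_mask).sum() / pair_mask.sum().clamp(min=1.0)
    return mle_loss, cl_loss


# Toy check with random tensors: source and target lengths differ on purpose.
torch.manual_seed(0)
bsz, src_len, tgt_len, hidden, vocab = 2, 7, 4, 8, 11
input_ids = torch.randint(1, vocab, (bsz, src_len))  # encoder side, unused by the loss
labels = torch.randint(1, vocab, (bsz, tgt_len))
labels[1, -1] = 0                                    # 0 plays the role of padding here
decoder_hidden = torch.randn(bsz, tgt_len, hidden)
logits = torch.randn(bsz, tgt_len, vocab)

mle_loss, cl_loss = seq2seq_simctg_loss(decoder_hidden, logits, labels, ignore_index=0)
```

With Hugging Face transformers, the decoder-side hidden states of a BART model should be available as the last entry of `decoder_hidden_states` when the model is called with `output_hidden_states=True`. Whether applying the contrastive term only to decoder representations matches the intent of SimCTG is exactly what I would like to confirm.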
