
Commit

Update
Cyrilvallez committed Oct 16, 2024
1 parent 296a4c8 commit 79f0cc6
Showing 2 changed files with 145 additions and 50 deletions.
179 changes: 134 additions & 45 deletions src/transformers/models/roberta/modular_roberta.py
@@ -42,6 +42,10 @@
BertAttention,
BertEmbeddings,
BertEncoder,
BertForMultipleChoice,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertIntermediate,
BertLayer,
BertModel,
@@ -50,9 +54,6 @@
BertSdpaSelfAttention,
BertSelfAttention,
BertSelfOutput,
BertForMultipleChoice,
BertForTokenClassification,
BertForQuestionAnswering,
)
from .configuration_roberta import RobertaConfig

@@ -374,27 +375,6 @@ def forward(
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
@@ -529,24 +509,14 @@ def _tie_weights(self):
self.bias = self.decoder.bias


@add_start_docstrings(
"""
RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
ROBERTA_START_DOCSTRING,
)
class RobertaForSequenceClassification(RobertaPreTrainedModel):
class RobertaForSequenceClassification(BertForSequenceClassification):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config

self.roberta = RobertaModel(config, add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)

# Initialize weights and apply final processing
self.post_init()
del classifier_dropout # noqa: F821
del self.dropout

@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -631,14 +601,84 @@ def forward(
class RobertaForMultipleChoice(BertForMultipleChoice):
def __init__(self, config):
super().__init__(config)
self.roberta = RobertaModel(config)
del classifier_dropout # noqa: F821
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
del classifier_dropout

@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
flat_inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)

outputs = self.roberta(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)

class RobertaForTokenClassification(BertForTokenClassification):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(reshaped_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)

if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class RobertaForTokenClassification(BertForTokenClassification):
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="Jean-Baptiste/roberta-large-ner-english",
@@ -647,8 +687,59 @@ class RobertaForTokenClassification(BertForTokenClassification):
expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
expected_loss=0.01,
)
def forward(**super_kwargs) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
super().forward()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class RobertaClassificationHead(nn.Module):
@@ -674,7 +765,6 @@ def forward(self, features, **kwargs):


class RobertaForQuestionAnswering(BertForQuestionAnswering):

@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint="deepset/roberta-base-squad2",
@@ -685,4 +775,3 @@ class RobertaForQuestionAnswering(BertForQuestionAnswering):
)
def forward(**super_kwargs) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
super().forward()

16 changes: 11 additions & 5 deletions utils/modular_model_converter.py
@@ -439,7 +439,7 @@ def update_body(self, existing_body, new_statements):
break

return deduplicated_new_body

def _fix_init_location(self, new_body):
"""Fix the location of the super()__init__ in the new body, if we had new statements before it."""
start_index = 0
@@ -472,7 +472,7 @@ def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CS
for i, expr in enumerate(node.body):
if is_call_to_super(expr, func_name):
has_super_call = True
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i+1:]))
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
new_body = self._fix_init_location(new_body)
else:
expr = expr.visit(self.transformer)
@@ -565,7 +565,9 @@ def replace_call_to_super(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
# Keep decorators in `modular_xxx.py` if any, else original decorators
new_decorators = updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
class_finder.python_module.code_for_node(updated_methods[name]),
@@ -723,7 +725,9 @@ def visit_ClassDef(self, node):

def visit_SimpleStatementLine(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
simple_top_level_assign_structure = m.SimpleStatementLine(body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])])
simple_top_level_assign_structure = m.SimpleStatementLine(
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
)
if m.matches(parent_node, m.Module()) and m.matches(node, simple_top_level_assign_structure):
self.top_level_functions_classes_assignments[node.body[0].targets[0].target.value] = node

@@ -746,7 +750,9 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module):
# Find any class/function/assignment that was mistakenly added as part of the dependencies and remove it
unused = self.added_dependencies - self.all_used_functions_classes_assignments
nodes_to_remove = [
self.top_level_functions_classes_assignments[name] for name in unused if name in self.top_level_functions_classes_assignments
self.top_level_functions_classes_assignments[name]
for name in unused
if name in self.top_level_functions_classes_assignments
]
new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove]
# Return a new module with the updated body

0 comments on commit 79f0cc6