-
Notifications
You must be signed in to change notification settings - Fork 438
[feat]: Add support for gpt-oss #949
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yeshsurya
wants to merge
6
commits into
linkedin:main
Choose a base branch
from
yeshsurya:yeshwanth/gpt_oss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
25eb714
[Feat]: Adding support for gpt-oss
yeshsurya f2e25e3
[feat]: completing test invocation
yeshsurya a9e7ff5
Merge branch 'linkedin:main' into yeshwanth/gpt_oss
yeshsurya 6eef643
[chrome]: style compliance
yeshsurya 6bc1158
[doc]: Adding to readme
yeshsurya f68437e
[update]: unpack result into tuple
yeshsurya File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| from typing import List | ||
| from typing import Optional | ||
| from typing import Union | ||
|
|
||
| import torch | ||
|
|
||
| from transformers.modeling_outputs import MoeCausalLMOutputWithPast | ||
| from transformers.modeling_outputs import MoeModelOutputWithPast | ||
| from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func | ||
|
|
||
| from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss | ||
|
|
||
|
|
||
| def lce_forward( | ||
| self, | ||
| input_ids: Optional[torch.LongTensor] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| use_cache: Optional[bool] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| output_router_logits: Optional[bool] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, | ||
| logits_to_keep: Union[int, torch.Tensor] = 0, | ||
| skip_logits: Optional[bool] = None, | ||
| **kwargs, | ||
| ) -> MoeCausalLMOutputWithPast: | ||
| r""" | ||
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | ||
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | ||
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. | ||
|
|
||
| logits_to_keep (`int` or `torch.Tensor`, *optional*): | ||
| If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all | ||
| `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that | ||
| token can save memory, which becomes pretty significant for long sequences or large vocabulary size. | ||
| If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. | ||
| This is useful when using packed tensor format (single dimension for batch and sequence length). | ||
|
|
||
| Returns: | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| >>> from transformers import AutoTokenizer, GptOssForCausalLM | ||
|
|
||
| >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b") | ||
| >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b") | ||
|
|
||
| >>> prompt = "Hey, are you conscious? Can you talk to me?" | ||
| >>> inputs = tokenizer(prompt, return_tensors="pt") | ||
|
|
||
| >>> # Generate | ||
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) | ||
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||
| "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | ||
| ```""" | ||
|
|
||
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
| output_router_logits = ( | ||
| output_router_logits if output_router_logits is not None else self.config.output_router_logits | ||
| ) | ||
|
|
||
| output_hidden_states = ( | ||
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
| ) | ||
|
|
||
| # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
| outputs: MoeModelOutputWithPast = self.model( | ||
| input_ids=input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| inputs_embeds=inputs_embeds, | ||
| use_cache=use_cache, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| output_router_logits=output_router_logits, | ||
| cache_position=cache_position, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| hidden_states = outputs.last_hidden_state | ||
| # Only compute necessary logits, and do not upcast them to float if we are not computing the loss | ||
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep | ||
| kept_hidden_states = hidden_states[:, slice_indices, :] | ||
|
|
||
| shift_labels = kwargs.pop("shift_labels", None) | ||
| logits = None | ||
| loss = None | ||
|
|
||
| if skip_logits is None: | ||
| skip_logits = self.training and (labels is not None or shift_labels is not None) | ||
|
|
||
| if skip_logits: | ||
| loss = LigerForCausalLMLoss( | ||
yeshsurya marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| hidden_states=kept_hidden_states, | ||
| lm_head_weight=self.lm_head.weight, | ||
| labels=labels, | ||
| shift_labels=shift_labels, | ||
| hidden_size=self.config.hidden_size, | ||
| **kwargs, | ||
| ) | ||
| else: # if in inference model materialize logits | ||
| logits = self.lm_head(kept_hidden_states) | ||
| if labels is not None or shift_labels is not None: | ||
| loss = self.loss_function( | ||
| logits=logits, | ||
| labels=labels, | ||
| shift_labels=shift_labels, | ||
| vocab_size=self.vocab_size, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| aux_loss = None | ||
| if output_router_logits: | ||
| aux_loss = load_balancing_loss_func( | ||
| outputs.router_logits, | ||
| self.num_experts, | ||
| self.num_experts_per_tok, | ||
| attention_mask, | ||
| ) | ||
| if labels is not None: | ||
| loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device | ||
|
|
||
| return MoeCausalLMOutputWithPast( | ||
| loss=loss, | ||
| aux_loss=aux_loss, | ||
| logits=logits, | ||
| past_key_values=outputs.past_key_values, | ||
| hidden_states=outputs.hidden_states, | ||
| attentions=outputs.attentions, | ||
| router_logits=outputs.router_logits, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this does not seem to match the output of this function