
modular_model_converter bugfix on assignments #35642

Merged
merged 3 commits into huggingface:main on Jan 21, 2025

Conversation

nikosanto13 (Contributor)

What does this PR do?

This PR improves the logic of the modular_model_converter.py script so that assignments from modular files are kept.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker, @Cyrilvallez

Additional details

Besides the changes to the utils/modular_model_converter.py script (see the sketch after this list):

  • Regenerated modeling_*.py files, which resolved some bugs that affected docstrings, e.g. IJEPA's _IMAGE_CLASS checkpoint and EXPECTED_OUTPUT, and Phi's checkpoint (which was inherited from Mistral, since the corresponding modular file lacked a checkpoint definition)
  • Had to copy-paste the Mistral docstring into modular_starcoder2.py. Apparently this is needed because the forward pass of the modular-defined Starcoder2Model is decorated with STARCODER2_INPUTS_DOCSTRING. Copy-pasting somewhat contradicts the "modular" logic, but if I'm not mistaken there was no easier solution.
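
As a rough sketch of the idea (hypothetical code, not the converter's actual implementation; the helper name should_keep_assignment is made up and the pattern list is abbreviated):

import re

# Sketch: keep an assignment when its name matches any regex pattern,
# instead of only when it appears verbatim in a hard-coded list.
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED"]

def should_keep_assignment(name: str) -> bool:
    return any(re.search(pattern, name) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)

assert should_keep_assignment("_IMAGE_CLASS_CHECKPOINT")  # variant name now kept
assert should_keep_assignment("_EXPECTED_OUTPUT_SHAPE")
assert not should_keep_assignment("logger")               # unrelated names still dropped

This is why variant names like IJEPA's _IMAGE_CLASS_CHECKPOINT, which an exact-match list misses, are now preserved.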

@nikosanto13 force-pushed the modular-bugfix-assignments branch 2 times, most recently from 7be3ec8 to 2580dc6 on January 13, 2025 15:36
@Cyrilvallez (Member)

Hey! Thanks for the contribution! However, the merging rule for assignments was specifically chosen to avoid having to redefine the big docstrings (what you did with Starcoder2) or very common variables while we think of a better solution for automatic docstrings.
That said, using regex patterns instead of hard matching for ASSIGNMENTS_TO_KEEP may be a good idea, as I think it's not always exactly "_CHECKPOINT_FOR_DOC" for older models.
Is there a specific reason why you need this change, BTW? 🤗

@nikosanto13 (Contributor, Author) commented Jan 16, 2025

@Cyrilvallez hey, thanks for the update. Well, I observed the erroneous documentation in the Phi and IJEPA modular files and figured it should be fixed with the regex patterns (also to avoid any potential issues in the future). The fix had the byproduct of having to redefine the docstring (since the docstring variable name now matches the regex pattern).

I guess if the documentation errors are less of an issue, then this could be skipped.

@Cyrilvallez (Member)

I see! Super cool initiative! Let's do it then!

@Cyrilvallez (Member) left a comment

See the comment for some guidance on how to proceed! Basically, I am against adding the DOCSTRING pattern as explained (it should not be needed, see comments, and will make our lives harder while figuring out a better way for docstrings), but the rest is very nice!

Comment on lines 524 to 523
ASSIGNMENTS_REGEX_TO_KEEP = [
r"_CHECKPOINT",
r"_EXPECTED",
r"DOCSTRING",
]

@Cyrilvallez (Member)

Suggested change
ASSIGNMENTS_REGEX_TO_KEEP = [
r"_CHECKPOINT",
r"_EXPECTED",
r"DOCSTRING",
]
# Similar to the above list, but for regex patterns
ASSIGNMENTS_REGEX_TO_KEEP = [
r"_CHECKPOINT",
r"_EXPECTED",
]

Here let's remove the docstring part; as I explained, I am against it as it forces us to redefine annoyingly long docstrings! For Emu3, it is actually an issue because of a very slight oversight: switching

class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
    def __init__(self, config: Emu3Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

in modular_emu3.py to

class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
    def __init__(self, config: Emu3Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
    def forward(self, **super_kwargs):
        super().forward(**super_kwargs)

automatically takes care of the docstring based on current rules!

@nikosanto13 (Contributor, Author)

Ah, now I see. Thanks for clarifying!

utils/modular_model_converter.py (outdated review thread, resolved)
src/transformers/models/starcoder2/modular_starcoder2.py (outdated review thread, resolved)
@nikosanto13 force-pushed the modular-bugfix-assignments branch 2 times, most recently from 8d51d2f to 7d1294a on January 17, 2025 21:47
@nikosanto13 force-pushed the modular-bugfix-assignments branch from 7d1294a to 74ce50a on January 17, 2025 22:00
…cstring assignment, remove verbatim assignments in modular converter
@nikosanto13 force-pushed the modular-bugfix-assignments branch from 74ce50a to eeea23a on January 17, 2025 22:05
@nikosanto13 (Contributor, Author)

@Cyrilvallez thanks for your help. I didn't like my change for the docstring either (nor did I know about the future plans for automatic docstrings), glad to get feedback there.

@ArthurZucker (Collaborator) left a comment

LGTM, cc @Cyrilvallez

@@ -640,8 +642,7 @@ def forward(
)


# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224"
_IMAGE_CLASS_CHECKPOINT = "facebook/ijepa_vith14_1k"
@ArthurZucker (Collaborator)

Regarding _IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224": the modular should be changed as well.

@nikosanto13 (Contributor, Author)

The modular is correct, but the assignment wasn't passing the filter because the name wasn't strictly matching ASSIGNMENTS_TO_KEEP (is this what you meant?)

Although now I see that there is a leftover _IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224" in configuration_ijepa.py

@ArthurZucker (Collaborator) commented Jan 21, 2025

Ah yeah, I had one wrong in my modular.

}
ASSIGNMENTS_REGEX_TO_KEEP = [
r"_CHECKPOINT",
r"_EXPECTED",
@ArthurZucker (Collaborator)

Might want to add _FOR_DOC?

@nikosanto13 (Contributor, Author)

_CONFIG_FOR_DOC is already handled by VARIABLES_AT_THE_BEGINNING, but I suppose I should add it to cover a few edge cases (like _TOKENIZER_FOR_DOC, although there are currently no assignments in modular files with that name on the LHS)?
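
For illustration (a hypothetical snippet, not the converter's code), adding r"_FOR_DOC" would make the whole *_FOR_DOC family match:

import re

# With r"_FOR_DOC" added, every *_FOR_DOC name is kept, including
# _TOKENIZER_FOR_DOC, which the other two patterns alone would miss.
patterns = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
for name in ["_CONFIG_FOR_DOC", "_TOKENIZER_FOR_DOC", "_CHECKPOINT_FOR_DOC"]:
    assert any(re.search(p, name) for p in patterns)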

@ArthurZucker (Collaborator)

yep

@qubvel removed their request for review on January 20, 2025 18:12
@ArthurZucker merged commit 920f34a into huggingface:main on Jan 21, 2025
14 checks passed
@nikosanto13 (Contributor, Author)

@ArthurZucker I'm frequently using the modeling files of speech SSL models (wav2vec2, hubert, wavlm, etc.) and I see that they don't have modular files yet (despite a lot of duplication). Would you welcome a PR on that?

@ArthurZucker (Collaborator)

Yeah for sure! 🤗

@Cyrilvallez (Member)

If needed, you can get more details about modular from #35737 (I rewrote the doc recently) 🤗

@nikosanto13 (Contributor, Author)

awesome, I'll definitely use it
