Initial NaFlex ViT model and training support #2466


Merged — 34 commits merged into main on Jun 5, 2025
Conversation

rwightman
Collaborator

@rwightman rwightman commented Apr 8, 2025

Working:

  • 'flex' ViT w/ NaFlex position embedding resize, pre-patchified input, and attention padding masks
  • Single-node train.py works with a custom NaFlex data pipeline: a dataset wrapper handles random seq-len & batch-size selection and constrains images to the seq-len budget while keeping aspect ratio (with randomizations)
  • A much faster patch embed kernel resample, torch only, that can be used in forward()
  • A 4-GPU distributed training run completed with decent results
  • NaFlex patch-mode compatible mixup & cutmix and a random erasing implementation
  • Randomization of the patch_size along with seq_len
  • SigLIP-2 NaFlex vision encoder weight port (tested with un-pushed OpenCLIP mods), matches expected results
  • Weight loading / translation for existing ViTs (all but 2 existing vision_transformer.py ViTs load with the use_naflex=True flag in create_model())
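Two of the pieces above — per-sample position-embedding resize and attention padding masks for variable sequence lengths — can be sketched in plain PyTorch. This is a minimal illustration, not the PR's actual implementation; `resize_pos_embed` and `build_padding_mask` are hypothetical helper names:

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_hw: tuple) -> torch.Tensor:
    """Bilinearly resample a learned (1, H*W, C) pos-embed grid to new_hw."""
    _, n, c = pos_embed.shape
    h = w = int(n ** 0.5)  # assumes a square source grid
    grid = pos_embed.reshape(1, h, w, c).permute(0, 3, 1, 2)  # (1, C, H, W)
    grid = F.interpolate(grid, size=new_hw, mode='bilinear', align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, -1, c)


def build_padding_mask(seq_lens: torch.Tensor, max_len: int) -> torch.Tensor:
    """Boolean (B, L) mask, True where a token is real (not padding)."""
    return torch.arange(max_len)[None, :] < seq_lens[:, None]
```

A (B, L) mask like this can be expanded to (B, 1, 1, L) and passed as `attn_mask` to `F.scaled_dot_product_attention` so attention ignores padded tokens.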

Not tested / not completed:

  • bigger distributed runs needed
  • dataset wrapper for iterable datasets (wds, tfds, iterable hfds) needs to be added
  • more model definitions
  • Integration of naflex data pipeline components into OpenCLIP

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rwightman rwightman marked this pull request as draft April 8, 2025 04:39
rwightman added 26 commits April 8, 2025 07:59
… add, classic vit weight loading for naflex model
…king loader based patch compatible RandomErasing for NaFlex mode.
…w. Remove subregion mode, not going to be worth it.
… embeds and 'aspect preserving mode' to Flex Embeds. Some more docstrings and typing.
… creating classic vits as naflex. Cleanup, improvements.
rwightman added 2 commits June 4, 2025 17:03
…ndling from train.py onwards. Add docstrings and type annotations (thanks Claude).
@rwightman rwightman marked this pull request as ready for review June 5, 2025 03:56
@stas-sl

stas-sl commented Jun 5, 2025

Hi Ross,

I've been following your progress a bit, as I'm also interested in the FlexiViT/Navit/Naflex architectures. I'm still trying to wrap my head around all the details, but I had a question regarding patch embedding resizing methods.

How essential do you think the PI variant of patch embedding resizing is, compared to the other variants mentioned in the FlexiViT paper (Appendix A.1), such as Vanilla, Token-LN, Image-LN, or Untied?

From my understanding, if you're using pretrained patch embeddings with a new patch size without introducing additional layers or operations in the transformer, then yes, PI resizing seems to work best. However, if you're training from scratch, their figure suggests that simple bilinear resizing performs quite well - except for very large patch sizes (like 48). So I'm wondering: is the added complexity of the PI method really worth it in that case, or can we just use bilinear resizing?

[Figure from the FlexiViT paper (Appendix A.1): accuracy vs. patch size for the different patch-embedding resize methods]

I also noticed you recently added the PatchEmbedInterpolator class, which I assume implements the PI method. Do you think it’s worth supporting other resizing methods (like plain bilinear), or is that already possible and I’m just missing something?
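For reference, the "vanilla" variant in question amounts to treating the patch-embed kernel as a small image and resizing it directly — a minimal sketch with an illustrative function name, not anything from the PR:

```python
import torch
import torch.nn.functional as F


def bilinear_resize_kernel(w: torch.Tensor, new_size: int) -> torch.Tensor:
    """Naively resize a (C_out, C_in, P, P) patch-embed kernel to new_size.

    Simple to apply, but unlike PI-resize it does not preserve the token
    values produced on correspondingly resized input patches, so token
    magnitudes drift as the patch size changes.
    """
    return F.interpolate(
        w, size=(new_size, new_size), mode='bilinear', align_corners=False)
```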

@rwightman
Collaborator Author

rwightman commented Jun 5, 2025

@stas-sl If you train/fine-tune at a different patch size using basic interpolation then, as you say, yeah, I imagine it will be fine. And if you train while resizing to different patch sizes with the simple interpolation and don't get crazy with the range of sizes covered, I'd expect it to be robust at inference time to sizes in that range. But I haven't tried this extensively.

However, using the simple resize on existing model weights yields pretty poor results compared to the PI method. The PI implementation I originally based on the JAX impl was damned slow, but I completely redid it with native torch tensors and a WAY faster basis vector computation, and it runs quite nicely at train time, so that's why I decided to just support the PI mode. I was testing this yesterday and the NaFlex pipeline appears to work well when randomizing both sequence length AND patch size at train time, neat.

@rwightman rwightman merged commit a5e551b into main Jun 5, 2025
26 checks passed