
Add TimmWrapper #34564

Draft · wants to merge 40 commits into main
Conversation

@qubvel (Member) commented Nov 1, 2024

What does this PR do?

Adds a TimmWrapper set of classes so that timm models can be loaded and used as transformers models within the library.

Continuation of a previous PR.

General Usage

import torch
from urllib.request import urlopen
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor

checkpoint = "timm/resnet50.a1_in1k"
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

image_processor = AutoImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(img, return_tensors="pt")
model = AutoModelForImageClassification.from_pretrained(checkpoint)

with torch.no_grad():
    logits = model(**inputs).logits

top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
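
As a small follow-up, a minimal sketch of mapping the top-5 indices back to label names (it assumes the wrapper populates config.id2label for this ImageNet checkpoint):

# Map the top-5 indices back to human-readable labels (assumes config.id2label is filled in)
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"{model.config.id2label[idx.item()]}: {prob.item():.1f}%")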

Pipeline

Timm models can now be used in the image-classification pipeline (for classification models) and in the image-feature-extraction pipeline.

import torch
from urllib.request import urlopen
from PIL import Image

from transformers import pipeline

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
pipe = pipeline("image-classification", model="timm/resnet18.a1_in1k")
print(pipe(img))
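
The feature-extraction side can be exercised the same way; a minimal sketch (the nesting of the returned lists depends on the model's feature shape):

# Image feature extraction with the same timm checkpoint
feature_pipe = pipeline("image-feature-extraction", model="timm/resnet18.a1_in1k")
features = feature_pipe(img)
print(type(features), len(features))  # nested lists of floats; shape depends on the model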

Trainer

Timm models can now be loaded and trained with the Trainer class.

Example model trained with the Trainer by running the command below:
https://huggingface.co/qubvel-hf/vit-base-beans

python run_image_classification.py \
    --dataset_name beans \
    --output_dir ./beans_outputs/ \
    --remove_unused_columns False \
    --label_column_name labels \
    --do_train \
    --do_eval \
    --push_to_hub \
    --push_to_hub_model_id vit-base-beans \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337 \
    --model_name_or_path timm/resnet18.a1_in1k \
    --ignore_mismatched_sizes

Other features enabled

  • Device map:
    model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map="auto")
  • Torch dtype:
    model = TimmWrapperForImageClassification.from_pretrained(checkpoint, torch_dtype="bfloat16")
  • Quantization:
    model = TimmWrapperForImageClassification.from_pretrained(checkpoint, load_in_4bit=True)
  • Intermediate hidden states: output_hidden_states=True or output_hidden_states=[1, 2, 3] (to select specific hidden states); see the sketch after this list:
    model = TimmWrapperForImageClassification.from_pretrained(checkpoint)
    output = model(**inputs, output_hidden_states=True)
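
For illustration, a small sketch continuing from the snippet above (it assumes the output exposes the standard hidden_states field and reuses inputs from the usage example):

import torch

# Inspect the features returned above
print(len(output.hidden_states))        # number of intermediate feature maps exposed by timm
print(output.hidden_states[-1].shape)   # shape of the last feature map

# Request only a subset of the intermediate states by index
with torch.no_grad():
    output = model(**inputs, output_hidden_states=[1, 2, 3])
print([h.shape for h in output.hidden_states])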

TODO

  • Gamma/beta renaming issue
  • Update timm in CI 0.9.6 -> 1.0.11 to enable output_hidden_states tests
  • Weights are loaded by transformers instead of timm; which architectures are affected?
  • Tests for image processor

@qubvel marked this pull request as draft November 1, 2024 15:52
@qubvel added the run-slow label Nov 1, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

qubvel (Member Author):

There are 2 changes in this file:

  1. state_dict key renaming moved into a separate method so it can be overridden for TimmWrapper (disable gamma/beta renaming + add prefix); see the sketch after this list
  2. metadata is None for timm checkpoints -> assuming these are pytorch checkpoints
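
A rough sketch of the kind of key fix point 1 refers to (illustrative only; the exact signature in the PR may differ, and the "timm_model." prefix is assumed from the discussion below):

def _fix_state_dict_key(key: str) -> str:
    # Illustrative only: skip the legacy gamma/beta -> weight/bias renaming the base
    # implementation applies, and prepend the wrapper's submodule prefix instead so that
    # original timm checkpoint keys map onto the wrapped module.
    prefix = "timm_model."  # assumed prefix
    return key if key.startswith(prefix) else prefix + key

print(_fix_state_dict_key("layer1.0.conv1.weight"))  # -> "timm_model.layer1.0.conv1.weight"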

Comment on lines +417 to +420
default_image_processor_filename = (
    "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)
kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)
qubvel (Member Author):

timm checkpoints store image processor config in config.json

@qubvel (Member Author) commented Nov 4, 2024

@rwightman could you please make the first review in case you have bandwidth

@qubvel requested a review from rwightman November 5, 2024 00:00

@rwightman commented Nov 5, 2024

@qubvel I'm starting to work through it now, wanted to get eval working to check some familiar models and just wasted way too much time realizing I needed --remove_unused_columns False for anything to work at all :/ ... that's a really poor setup when most datasets have an 'image' column and not a 'pixel_values' column (realize that's nothing to do with this PR, heh) :/

Annnyways, first pass of the code things looked sane but need to spend some time looking closer at the details and testing some cases.

@rwightman

A few high level q...

  • Does 'Wrapper' add any worthwhile value/info in the name vs

    • TimmPreTrainedModel(PreTrainedModel)
    • TimmModel(TimmPretrainedModel)
    • TimmModelForImageClassification(..)
  • Is there a reason TimmWrapperModel has .timm_model, instead of something more generic like .model?

  • Any reservations in changing TimmWrapperModelForImageClassification to not use TimmWrapperModel? There are a few issues with the handling details for classifier, possible optimizations for forward call sequence and it'd probably be a bit cleaner to just duplicate a bit of redundant code and keep the two implementations separate and a bit different.

  • I thought we were going to set pretrained=True so timm can load the weights; there are a number of weight adaptation / translation things that don't run if this isn't used, cannot change num_classes cleanly for instance.

  • What happens if we try to push these models to the hub? Do they get uploaded/written in a form that timm can read?

@qubvel (Member Author) commented Nov 5, 2024

@rwightman thanks for the review! Indeed there are some default params I'm also confused about; there are even more for object detection 🥲

Does 'Wrapper' add any worthwhile value/info in the name vs
TimmPreTrainedModel(PreTrainedModel)
TimmModel(TimmPretrainedModel)
TimmModelForImageClassification(..)

I left it as it was in the previous PR; however, TimmModelForImageClassification sounds better to me, so I can rename it.

Is there a reason TimmWrapperModel has .timm_model, instead of something more generic like .model?

The prefix "timm_model" is unique and is used in certain tests to identify when weights come from a timm model. It's also utilized in the _fix_state_dict_key method to determine whether to add the prefix when loading weights from the original checkpoint. For these reasons, I would prefer to keep it as "timm_model"

Any reservations in changing TimmWrapperModelForImageClassification to not use TimmWrapperModel? There are a few issues with the handling details for classifier, possible optimizations for forward call sequence and it'd probably be a bit cleaner to just duplicate a bit of redundant code and keep the two implementations separate and a bit different.

Originally, this was implemented without TimmWrapperModel in TimmWrapperModelForImageClassification, but I introduced it to reduce code repetition. This approach also aligns better with common patterns in the transformers repo. Could you provide more details on the issues with the classifier? In any case, the code will remain the same for both models if we aim to maintain output_hidden_states functionality across them.

I thought we were going to set pretrained=True so timm can load the weights; there are a number of weight adaptation / translation things that don't run if this isn't used, cannot change num_classes cleanly for instance.

I left a comment in the thread about this. We use transformers for weight loading to leverage features like device_map, torch_dtype, and quantization. I’m also unsure how to disable weight loading through transformers if it’s handled by timm, as I haven’t seen any examples of this in the repo. I can dig into it further if needed.

Do you have an estimate of how many models involve weight renaming? Is there a way to update checkpoints without breaking older versions of timm? Alternatively, could we manage similar weight renaming directly in transformers? (Though I think this approach may be less robust.)

What happens if we try to push these models to the hub? Do they get uploaded/written in a form that timm can read?

I’m not sure if it's currently compatible, but it would be great to enable it! The config should be compatible, however, the weights state dict will have the "model.timm_model." prefix. I can look into removing this prefix before saving. I will try to enable it and add a separate test for a few checkpoints. Thanks for bringing this up!

@rwightman commented Nov 5, 2024

Originally, this was implemented without TimmWrapperModel in TimmWrapperModelForImageClassification, but I introduced it to reduce code repetition. This approach also aligns better with common patterns in the transformers repo. Could you provide more details on the issues with the classifier? In any case, the code will remain the same for both models if we aim to maintain output_hidden_states functionality across them.

I feel in this case there is a difference in alignment with other models, because both ImageClassification and base model wrap timm model instances that differ, instead of adding their own head to the same timm model. I see some options, depending on the mix of hidden state flags, head vs no head where it'd be appropriate to go through different forward calls, where more than just one argument might be appropriate to change on creation, etc.

Also, thinking about the future and other tasks: I feel there is a high probability that, say, for supporting native timm object detection, more flexibility will be desired, so it'd be safer to have them uncoupled.

That, and there's a resistance to significant changes in transformers after something is in there, so I feel it's better to leave them uncoupled, to have additional flexibility and avoid being stuck with tricky decisions that might impact timm moving forward.

@rwightman commented Nov 5, 2024

I thought we were going to set pretrained=True so timm can load the weights; there are a number of weight adaptation / translation things that don't run if this isn't used, cannot change num_classes cleanly for instance.

I left a comment in the thread about this. We use transformers for weight loading to leverage features like device_map, torch_dtype, and quantization. I’m also unsure how to disable weight loading through transformers if it’s handled by timm, as I haven’t seen any examples of this in the repo. I can dig into it further if needed.

Do you have an estimate of how many models involve weight renaming? Is there a way to update checkpoints without breaking older versions of timm? Alternatively, could we manage similar weight renaming directly in transformers? (Though I think this approach may be less robust.)

Aside from renaming, it just doesn't work right now, you can't load weights for a model if the head size changes for a new classification task. The wrapper only works if you use the imagenet classifier.

I realize timm doesn't have the dtype or lazy-init features, but it's better to have it work, I feel. Can potentially look at supporting some of that through timm. It hasn't been a priority, as there aren't too many very large models in timm.

If there's no way to use pretrained=True on creation, then we will probably need to figure out how to add & call a method in timm once transformers has the state_dict and before it's loaded into the model. Not sure if there's a spot for such a call in transformers?

    Empty init weights function to ensure compatibility of the class in the library.
    """
    if isinstance(module, (nn.Linear)):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)


It's not clear why this is here or what it's attempting to do. timm has model-specific init fns, though they aren't separately callable right now; doing something like this, which could overwrite timm defaults, would change model behaviour.

qubvel (Member Author):

Added this to initialize the classifier; without it, weights are not properly initialized, probably due to how the model is created in transformers:

from transformers import TimmWrapperForImageClassification

# --------------
# With init
# --------------

model = TimmWrapperForImageClassification.from_pretrained("timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True)
print(model.timm_model.fc.weight[:3, :3])
# tensor([[-0.2117, -0.2422, -0.2540],
#         [-0.1106, -0.1856, -0.0152],
#         [-0.3430, -0.6446, -0.0530]], grad_fn=<SliceBackward0>)


# --------------
# Without init
# --------------

# patch with empty init weight function to check
def empty_init(self, module):
    pass
TimmWrapperForImageClassification._init_weights = empty_init

model = TimmWrapperForImageClassification.from_pretrained("timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True)
print(model.timm_model.fc.weight[:3, :3])
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]], grad_fn=<SliceBackward0>)

Ideally, we should get rid of this, but it's not common for transformers to load external models, so it might require more code changes. For now, it's a simple fix to enable model loading with an initialized classifier when shapes are mismatched.


if is_dir and os.path.exists(os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)):
    # timm models don't have a preprocessor_config.json file saved out
    return False


I don't think the absence of a file is a good check.

@qubvel (Member Author) commented Nov 5, 2024

Aside from renaming, it just doesn't work right now, you can't load weights for a model if the head size changes for a new classification task. The wrapper only works if you use the imagenet classifier.

Fixed!

@qubvel (Member Author) commented Nov 5, 2024

If there's no way to use pretrained=True on creation, then we will probably need to figure out how to add & call a method in timm once transformers has the state_dict and before it's loaded into the model. Not sure if there's a spot for such a call in transformers?

If there were any timm model/class-specific function that fixes the state_dict, I suppose we could call it before loading weights, similar to _fix_state_dict_key in the current implementation. However, it would be supported only for newer timm versions.

@rwightman

Thinking about this a bit more, there is another issue with the hub / transformers-first weight loading. timm wasn't originally hub-first, so the library itself is still the primary source of truth for some models, and doing a hub-based load won't work.

For example, this model https://huggingface.co/laion/CLIP-ViT-B-16-laion2B-s34B-b88K is an OpenCLIP-first model, but timm can load it if you use the model name 'vit_base_patch16_clip_224.laion2b'.

Indeed, I feel that the wrapper should support all timm model names that work in timm, but right now, if the model isn't on the hub with a timm primary config, it isn't usable. Ideally it should work with both a hub model name OR any timm model name. The timm model name would require timm to do the pretrained loading, to resolve any translation to another hub name or weight source.

Some very popular examples of this

https://github.com/huggingface/pytorch-image-models/blob/51ac8d2efb926c6b7c34eeb1dc52bcf57999e2de/timm/models/vision_transformer.py#L1580-L1716

@rwightman

Aside from renaming, it just doesn't work right now, you can't load weights for a model if the head size changes for a new classification task. The wrapper only works if you use the imagenet classifier.

Fixed!

Yay, I was able to run a fine-tune using run_image_classification.py after this fix. One observation: the output files for that script aren't directly usable for pushing to the hub.

There is no config matching the timm format or name; there is a config output by the image preprocessor save process, under a different filename, that's a jumble of the timm config.

Also, the state dict for the model has the timm_model prefix, so it's not loadable in timm. Is there any way to remove that prefix in the checkpoints? This would also make local dir loads more seamless if someone already had timm weights and config files checked out.
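
A hypothetical workaround sketch for the prefix question (the checkpoint filename and the exact prefix are assumptions based on this thread; a saved checkpoint may also be in safetensors format):

import torch

# Strip the wrapper's submodule prefix ("timm_model." / "model.timm_model." are both
# mentioned in this thread) so the remaining keys match timm's own naming.
state_dict = torch.load("beans_outputs/pytorch_model.bin", map_location="cpu")
timm_state_dict = {
    key.split("timm_model.", 1)[-1]: value
    for key, value in state_dict.items()
    if "timm_model." in key
}
torch.save(timm_state_dict, "beans_outputs/timm_compatible.bin")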
