Add Phi3.5 Vision Model #41977
Conversation
cc @zucchini-nlp when you get a chance!

Phi3.5 💀 I will take a look some time this week

@zucchini-nlp The PR should be ready for review next week. I will ping you then.

@zucchini-nlp The PR is ready for initial review 🤗. CI has been broken since the weekend, but it was all green previously, so no issues there.
```python
def get_image_features(self, pixel_values: torch.Tensor, image_sizes, num_images, num_crops):
    # Process the image crops with the CLIP vision model.
    vision_outputs = self.vision_model(pixel_values, output_hidden_states=True)

    # Extract the hidden states from the second-to-last layer,
    # dropping the leading CLS token ([:, 1:]).
    hidden_state = vision_outputs.hidden_states[-2][:, 1:]
    hidden_state = hidden_state.reshape(num_images, num_crops, -1, self.image_dim_out)

    # Transform the image features to the text embedding space.
    image_features = self.transform_image_embeds(hidden_state, image_sizes)
    return image_features
```
Mainly because of this function, where the image features are transformed and projected in a somewhat non-standard way. The use of image sizes also makes it difficult to run inference with `num_return_sequences > 1`, because the image inputs are then no longer in sync with the repeated `input_ids`. The relevant tests, such as beam search, are therefore skipped.
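To make the sync issue concrete, here is a minimal, self-contained sketch; the expansion step mirrors what `generate()` does internally, while all concrete shapes and numbers are made up for illustration:

```python
import torch

# Illustrative shapes only: 1 prompt with 1 image split into 5 crops.
input_ids = torch.randint(0, 100, (1, 8))
pixel_values = torch.randn(1, 5, 3, 336, 336)

# With num_return_sequences=2, generate() repeats input_ids along the
# batch dimension, but the image inputs are not repeated with them,
# so text rows and image inputs no longer line up 1:1.
expanded_ids = input_ids.repeat_interleave(2, dim=0)
print(expanded_ids.shape)   # torch.Size([2, 8])  -> two text rows
print(pixel_values.shape)   # torch.Size([1, 5, 3, 336, 336])  -> one image
```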
```python
for prompt in text:
    prompt_splits = re.split(r"(\<\|image\|\>)", prompt)

    tokenized_outputs = []
    for split in prompt_splits:
        if split == "<|image|>":
            if image_token_counter >= len(num_image_tokens):
                raise ValueError("More image placeholders in the text than images provided.")
            image_tokens = [self.image_token_id] * num_image_tokens[image_token_counter]
            tokenized_outputs.extend(image_tokens)
            image_token_counter += 1
        else:
            text_tokens = self.tokenizer(split)["input_ids"]
            tokenized_outputs.extend(text_tokens)

    tokenized_prompts.append(tokenized_outputs)
```
Because of the custom tokenization and the attention mask that is built alongside it, it's difficult to support assisted decoding, and IMO it's not a particularly important generation method to support in this case.
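As an aside, here is a toy, self-contained run of the same interleaving logic as the loop above; the token IDs and the fake vocabulary are made up for illustration:

```python
import re

image_token_id = 32000          # made-up ID for the image placeholder token
num_image_tokens = [4]          # pretend the single image expands to 4 tokens
fake_vocab = {"Describe ": [10, 11], " briefly.": [12, 13]}

prompt = "Describe <|image|> briefly."
tokenized_output, image_token_counter = [], 0

for split in re.split(r"(<\|image\|>)", prompt):
    if split == "<|image|>":
        tokenized_output.extend([image_token_id] * num_image_tokens[image_token_counter])
        image_token_counter += 1
    elif split:  # skip empty strings the split can produce
        tokenized_output.extend(fake_vocab[split])

print(tokenized_output)  # [10, 11, 32000, 32000, 32000, 32000, 12, 13]
```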
zucchini-nlp left a comment
@yaswanth19 thanks for the PR, looks much cleaner already!
I left some comments, mostly nit-picking for better standardization. Also I believe there's one test failing with Phi3.5V :)
```python
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

@unittest.skip("Not possible as processor creates a custom attention mask.")
```
The mask format doesn't look custom, even though it is prepared manually instead of being returned by the tokenizer.
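Put differently (my reading of the comment, with made-up values): a manually built mask over the final token sequence has the same standard format the tokenizer would return anyway:

```python
# All real tokens get 1; only padding would get 0, matching the tokenizer's output.
input_ids = [10, 11, 32000, 32000, 12, 13]
attention_mask = [1] * len(input_ids)
print(attention_mask)  # [1, 1, 1, 1, 1, 1]
```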
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am skipping this test because it requires offset mapping, which is quite difficult to fetch given the way we tokenize the prompt.
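For context, offset mapping is normally available directly from a fast tokenizer; the sketch below (the checkpoint name is just an example) shows the standard path and why segment-wise tokenization breaks it:

```python
from transformers import AutoTokenizer

# return_offsets_mapping requires a fast (Rust-backed) tokenizer.
tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
enc = tok("Describe the image briefly.", return_offsets_mapping=True)
print(enc["offset_mapping"])  # character spans relative to the full string

# When the prompt is split on <|image|> and each piece is tokenized
# separately, each piece's offsets restart at 0 and can no longer be
# mapped back to positions in the original prompt.
```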
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, phi3_v
Closes #36036