Conversation

@zhang-prog commented Nov 13, 2025

What does this PR do?

This PR adds the PaddleOCR-VL model from PaddleOCR to Hugging Face Transformers.

Relevant Links:

PaddleOCR
https://huggingface.co/PaddlePaddle/PaddleOCR-VL

Usage

Use a pipeline

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
result = pipe(text=messages)
print(result)

Load model directly

from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
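# drop the prompt tokens and the trailing token (typically EOS) before decoding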
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@zucchini-nlp self-requested a review November 13, 2025 09:07
@zucchini-nlp (Member) left a comment

hey @zhang-prog , thanks for the PR! Great model to have in transformers!

The main thing to fix first is the naming: it should clearly include "PaddlePaddleOCR" and follow the usual pattern for the modality. The config format also isn't right; it needs to be fully nested, with the text and vision configs inside. Additionally, there are no tests or docs, and several files are missing. You can run transformers add-new-model-like, which generates a placeholder with all the necessary files. I also left some smaller comments here and there. Let me know if you hit any issues.
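
For the nested-config point, the target shape is roughly the following (a sketch; the class names, model_type strings, and init logic here are illustrative, not the PR's actual code):

from transformers import PretrainedConfig


class PaddleOCRVLVisionConfig(PretrainedConfig):
    model_type = "paddleocr_vl_vision"


class PaddleOCRVLTextConfig(PretrainedConfig):
    model_type = "paddleocr_vl_text"


class PaddleOCRVLConfig(PretrainedConfig):
    model_type = "paddleocr_vl"
    sub_configs = {"text_config": PaddleOCRVLTextConfig, "vision_config": PaddleOCRVLVisionConfig}

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        # build the nested configs from dicts so the saved config.json
        # carries full text_config / vision_config sections
        self.text_config = PaddleOCRVLTextConfig(**(text_config or {}))
        self.vision_config = PaddleOCRVLVisionConfig(**(vision_config or {}))
        super().__init__(**kwargs)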

Comment on lines +91 to +98
if height < factor:
    width = round((width * factor) / height)
    height = factor

if width < factor:
    height = round((height * factor) / width)
    width = factor

Same as Qwen, but with support for H/W smaller than the factor. I think we made Qwen-VL support small images as well, so directly importing it will probably give the expected result?
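
For reference, the full Qwen2-VL-style helper then looks roughly like this (a sketch with Qwen2-VL's illustrative defaults, adding the small-image branch from the hunk above):

import math

def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280):
    # scale tiny images up so both sides reach the patch factor,
    # instead of raising as the original Qwen helper did
    if height < factor:
        width = round((width * factor) / height)
        height = factor
    if width < factor:
        height = round((height * factor) / width)
        width = factor
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    # snap both sides to multiples of the factor
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    # rescale into the [min_pixels, max_pixels] budget while keeping aspect ratio
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar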

return h_bar, w_bar


class PaddleOCRVLImageProcessor(Qwen2VLImageProcessor):

We currently recommend adding a Fast Image Processor first for new models, with the slow version only as a complementary fallback.

Can you add a fast processor as well? There is some info about fast processors in #36978.
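
A minimal fast-processor skeleton could look like this (a sketch assuming the Qwen2-VL fast implementation can be reused; the class name and overridden defaults are hypothetical):

from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast


class PaddleOCRVLImageProcessorFast(Qwen2VLImageProcessorFast):
    # torch/torchvision-backed processor; batched resize/rescale/normalize
    # come from the parent, only model-specific defaults change here
    image_mean = [0.5, 0.5, 0.5]  # hypothetical defaults
    image_std = [0.5, 0.5, 0.5]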

Comment on lines +182 to +183
self.min_pixels = min_pixels
self.max_pixels = max_pixels

Let's use size instead of min_pixels/max_pixels. We've been trying to standardize attribute naming lately, and size is the common argument for this.
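
For context, the standardized form folds the pixel bounds into the size dict, e.g. (illustrative values, following the Qwen2-VL convention):

# instead of min_pixels / max_pixels as top-level attributes:
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}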

Comment on lines +346 to +360
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
    "chat_template",
    "image_std",
    "min_pixels",
    "image_mean",
    "merge_size",
    "image_processor_type",
    "temporal_patch_size",
    "patch_size",
    "max_pixels",
]

image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"

we don't need these anymore with the recent change in v5
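
If I read the v5 change correctly, the declaration shrinks to roughly this (a sketch, assuming attributes and valid kwargs are now inferred from the __init__ signature):

from transformers.processing_utils import ProcessorMixin


class PaddleOCRVLProcessor(ProcessorMixin):
    # attributes and valid_kwargs are inferred automatically in v5
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"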

tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
    self.image_token = "<|IMAGE_PLACEHOLDER|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token

can you add the token in the tokenizer so we can assume it's always available?
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#multimodal-tokenizer
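
Following the linked docs, that would look roughly like this (a sketch):

from transformers import AutoTokenizer

# register the placeholder as a multimodal special token so that
# tokenizer.image_token is always available to the processor
tokenizer = AutoTokenizer.from_pretrained(
    "PaddlePaddle/PaddleOCR-VL",
    extra_special_tokens={"image_token": "<|IMAGE_PLACEHOLDER|>"},
)
print(tokenizer.image_token)  # <|IMAGE_PLACEHOLDER|>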

Comment on lines +1728 to +1741
loss = None
if labels is not None:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

let's use self.loss_fn here
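
That is, the whole block above collapses to the standard helper (a sketch; self.loss_fn presumably refers to the loss_function attribute on PreTrainedModel):

loss = None
if labels is not None:
    # upcasting, shifting, flattening and device placement are handled internally
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)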

Comment on lines +1743 to +1746
if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output

not needed as long as the forward is decorated with can_return_tuple
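
Schematically (a sketch, not the PR's actual code):

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import can_return_tuple

@can_return_tuple
def forward(self, input_ids=None, labels=None, **kwargs):
    loss, logits = ..., ...  # model computation elided
    # always return the ModelOutput; the decorator converts it to a tuple
    # automatically when the caller passes return_dict=False
    return CausalLMOutputWithPast(loss=loss, logits=logits)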

rope_deltas=self.rope_deltas,
)

def prepare_inputs_for_generation(

same as qwen2-5-vl, can be deleted after inheriting from it
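
In the modular file that means the class can simply inherit and drop the copied methods, roughly (a sketch; the class and parent names are an assumption):

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration


class PaddleOCRVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    # prepare_inputs_for_generation, _get_image_nums_and_video_nums and
    # _expand_inputs_for_generation are all inherited unchanged
    pass

The same applies to the two methods flagged below.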


return model_inputs

def _get_image_nums_and_video_nums(

same as qwen2-5-vl, can be deleted after inheriting from it


return image_nums, video_nums

def _expand_inputs_for_generation(

same as qwen2-5-vl, can be deleted after inheriting from it
