Consolidate all CLI/UI/checkpointing functionality into a single file, add a fix for issue #702, and update the instructions in local_inference/README.md #757

Open · wants to merge 12 commits into `main`
2 changes: 1 addition & 1 deletion recipes/quickstart/finetuning/LLM_finetuning_overview.md
@@ -61,4 +61,4 @@ To boost the performance of fine-tuning with FSDP, we can make use of a number of features:

- **Activation Checkpointing**, a technique that saves memory by discarding the intermediate activations in the forward pass instead of keeping them in memory, at the cost of recomputing them in the backward pass. FSDP activation checkpointing is shard-aware, meaning it must be applied after wrapping the model with FSDP. Our script makes use of this.

- **auto_wrap_policy**, which specifies how FSDP should partition the model; there is default support for a transformer wrapping policy. This allows FSDP to form each FSDP unit (a partition of the model) based on the transformer class in the model. To identify this layer, look for the layer that wraps both the attention layer and the MLP. This gives FSDP more fine-grained units for communication, which helps optimize communication cost.
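As a rough illustration of how these two features combine, here is a minimal sketch (not the recipe's actual training code) that wraps a model with a transformer auto-wrap policy and then applies shard-aware activation checkpointing. It assumes an already-initialized process group, uses `LlamaDecoderLayer` as the transformer block class, and the helper name is illustrative:

```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def shard_with_activation_checkpointing(model):
    # Each LlamaDecoderLayer (the block wrapping attention + MLP) becomes one FSDP unit.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )
    model = FSDP(model, auto_wrap_policy=wrap_policy)

    # Activation checkpointing is shard-aware, so it is applied *after* FSDP wrapping.
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, LlamaDecoderLayer),
    )
    return model
```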
46 changes: 33 additions & 13 deletions recipes/quickstart/inference/local_inference/README.md
@@ -3,26 +3,46 @@
## Hugging Face setup
**Important Note**: Before running inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and paste your access token to complete the login, so the scripts can download Hugging Face models when needed.
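If you prefer to authenticate from Python instead of the CLI, a minimal equivalent using `huggingface_hub` is sketched below (reading the token from an `HF_TOKEN` environment variable is an assumption for illustration, not part of this recipe):

```python
import os

from huggingface_hub import login

# Equivalent to `huggingface-cli login`; HF_TOKEN is an assumed variable name
login(token=os.environ["HF_TOKEN"])
```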

## Multimodal Inference and CLI inference with or without PEFT LoRA weights

For multimodal inference we have added [multi_modal_infer.py](multi_modal_infer.py), which uses the transformers library and supports CLI inference, a Gradio UI, and optional PEFT LoRA weights.

### Model Overview
- Base model: `meta-llama/Llama-3.2-11B-Vision-Instruct`
- Uses the PEFT library (v0.13.1) for efficient LoRA fine-tuning (see the adapter sketch below)
- Supports vision-language tasks with instruction-following capabilities
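
For context, the LoRA weights passed via `--finetuning_path` below are expected to be a standard PEFT adapter directory, i.e. one written by `save_pretrained`. A minimal, purely illustrative sketch of producing such a directory follows; the target modules and hyperparameters are assumptions, not values used by this recipe:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import MllamaForConditionalGeneration

base = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype=torch.bfloat16
)
# Hyperparameters and target modules below are illustrative only
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05)
peft_model = get_peft_model(base, lora_config)

# ... fine-tune peft_model here ...

# Writes adapter_config.json and the adapter weights; this directory is
# what --finetuning_path expects.
peft_model.save_pretrained("path/to/lora/weights")
```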

### Features in `multi_modal_infer.py`

All functionality has been consolidated into a single file with three main modes. Make sure you have access to the Llama 3.2 vision models, then run one of the commands below.

### Steps to run

1. **Basic Inference**
```bash
python multi_modal_infer.py \
--image_path "path/to/image.jpg" \
--prompt_text "Describe this image" \
--model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
--hf_token "your_token"
```

2. **Gradio UI Mode**
```bash
python multi_modal_infer.py \
--model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
--hf_token "your_token" \
--gradio_ui
```

3. **LoRA Fine-tuning Integration**
```bash
python multi_modal_infer.py \
--image_path "path/to/image.jpg" \
--prompt_text "Describe this image" \
--model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
--hf_token "your_token" \
--finetuning_path "path/to/lora/weights"
```


## Text-only Inference
For local text-only inference we have provided an [inference script](inference.py). Depending on the type of fine-tuning performed during training, the [inference script](inference.py) takes different arguments.

239 changes: 161 additions & 78 deletions recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -1,108 +1,191 @@
import argparse
import os
import sys

import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from peft import PeftModel
import gradio as gr

# Initialize accelerator
accelerator = Accelerator()

device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)


def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16,
        )
        print("LoRA adapter merged successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor

def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")


def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from image using model"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    return processor.decode(output[0])[len(prompt):]

def gradio_interface(model_name: str, hf_token: str):
    """Create Gradio UI with LoRA support"""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name,
            hf_token,
            lora_path if enable_lora else None,
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt",
                    lines=2,
                )

                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[enable_lora],
            outputs=[lora_path],
        )

        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )

        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input, user_prompt, temperature,
                top_k, top_p, max_tokens, chat_history,
            ],
            outputs=[chat_history],
        )

        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)
    return demo

def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name, args.hf_token)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name,
            args.hf_token,
            args.finetuning_path,
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image,
            args.prompt_text,
            args.temperature,
            args.top_p,
        )
        print("Generated Text:", result)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")

    args = parser.parse_args()
    main(args)