-
Notifications
You must be signed in to change notification settings - Fork 378
Expand file tree
/
Copy pathtest_rec_r1.py
More file actions
executable file
·253 lines (199 loc) · 8.74 KB
/
test_rec_r1.py
File metadata and controls
executable file
·253 lines (199 loc) · 8.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random
from PIL import Image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
def setup_distributed():
    """Pin this process to its local GPU and join the NCCL process group.

    The local GPU index comes from the ``LOCAL_RANK`` environment variable
    (presumably exported by the launcher, e.g. torchrun) and defaults to 0.

    Returns:
        tuple: ``(local_rank, world_size, rank)``.
    """
    gpu_index = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(gpu_index)
    dist.init_process_group(backend="nccl")
    # Tuple elements are evaluated left to right: world size first, then rank.
    return gpu_index, dist.get_world_size(), dist.get_rank()
# --- Distributed context ---------------------------------------------------
local_rank, world_size, rank = setup_distributed()
device = f"cuda:{local_rank}"
print(f"Process {rank} using {device}")
main_rank = 0  # rank that reports progress/scores and writes the result files

# --- Checkpoint selection --------------------------------------------------
steps = 100  # selects the checkpoint-<steps> directory under evaluation
if rank == main_rank:
    print("Steps: ", steps)
RUN_NAME = "Qwen2.5-VL-3B-Instruct-rec"
MODEL_PATH = f"/training/shz/project/vlm-r1/VLM-R1/checkpoints/rl/{RUN_NAME}/checkpoint-{steps}"

# --- Evaluation I/O ----------------------------------------------------------
# Placeholders are filled per dataset via str.format, so this is deliberately
# a plain string rather than an f-string.
OUTPUT_PATH = "./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
BSZ = 2  # prompts per generate() call on each rank
DATA_ROOT = "/training/shz/dataset/vlm-r1/rec_jsons_processed"
# Alternative RefCOCO-family configuration:
# TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
# IMAGE_ROOT = "/training/shz/dataset/coco"
IMAGE_ROOT = "/training/shz/dataset/lisa"
TEST_DATASETS = ['lisa_test']
# flash_attention_2 is recommended for faster inference and lower memory use,
# especially with multi-image or video inputs.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    device_map={"": local_rank},  # whole model on this rank's GPU
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Processor saved alongside the checkpoint; it prepares both the chat text
# and the images for the model.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
def extract_bbox_answer(content):
    """Parse the predicted bounding box out of a raw model response.

    Only the text between the first ``<answer>`` and the next ``</answer>``
    is inspected. Inside it, the box must be a list of four non-negative
    integers that sits within a ``{...}`` object, e.g.
    ``{"bbox_2d": [x1, y1, x2, y2]}``. Because the leading ``.*`` is greedy,
    the last matching list wins when several are present.

    Args:
        content: Decoded model output.

    Returns:
        list[int]: ``[x1, y1, x2, y2]``, or ``[0, 0, 0, 0]`` when the tags or
        a parsable box are missing.
    """
    answer = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    if not answer:
        return [0, 0, 0, 0]
    coords = re.search(
        r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}',
        answer.group(1).strip(),
        re.DOTALL,
    )
    if not coords:
        return [0, 0, 0, 0]
    return [int(value) for value in coords.groups()]
def iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Boxes are treated as continuous regions, so the intersection and both
    areas use the same ``(x2 - x1) * (y2 - y1)`` convention. The previous
    code mixed an inclusive-pixel intersection with exclusive areas, which
    zeroed any overlap of width/height <= 1 (two identical 1-px-wide boxes
    scored 0).

    Degenerate or inverted boxes (``x2 <= x1`` or ``y2 <= y1``) contribute
    zero area. An inverted prediction can never intersect anything.

    Returns:
        float: IoU in ``[0, 1]``. Returns ``0.0`` when the union is empty,
        instead of raising ZeroDivisionError (e.g. an inverted prediction
        against a ground-truth box of the same size).
    """
    def _area(box):
        # Clamp so inverted boxes cannot produce negative area and cancel
        # out the other box's area in the union.
        return max(0, box[2] - box[0]) * max(0, box[3] - box[1])

    inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = inter_w * inter_h
    union = _area(box1) + _area(box2) - inter
    if union <= 0:
        return 0.0
    return float(inter) / union
def resize_bbox(bbox, input_height, input_width, image_height, image_width):
    """Rescale a box from the model-input frame to the original image frame.

    x coordinates (indices 0 and 2) are mapped from ``input_width`` to
    ``image_width``. y coordinates (indices 1 and 3) are mapped from
    ``input_height`` to ``image_height``. The list is modified in place and
    also returned.
    """
    # Even indices are x, odd indices are y: (source size, target size).
    axis_sizes = ((input_width, image_width), (input_height, image_height))
    for idx in range(4):
        src_size, dst_size = axis_sizes[idx % 2]
        # Divide first, then multiply, matching the original float rounding.
        bbox[idx] = bbox[idx] / src_size * dst_size
    return bbox
num_samples = 2000  # at most this many shuffled examples are evaluated per dataset
# Prompt wrapped around every referring expression. The <answer> + JSON
# format it requests is what extract_bbox_answer parses. It is loop-invariant,
# so it is defined once here.
QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

# Evaluate each dataset:
#   1. shard the examples across ranks;
#   2. generate on every rank;
#   3. gather the outputs onto the main rank;
#   4. score with IoU > 0.5 and write a JSON report.
for ds in TEST_DATASETS:
    if rank == main_rank:
        print(f"Processing {ds}...")
    ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
    # Use a context manager so the annotation file handle is closed.
    with open(ds_path, "r") as f:
        data = json.load(f)
    # Same seed on every rank -> identical order, so the index-based split
    # below is consistent across ranks.
    random.seed(42)
    random.shuffle(data)
    data = data[:num_samples]
    if not data:
        # Every rank loads the same file, so all ranks skip together and no
        # collective call (all_gather_object / barrier) is left unmatched.
        if rank == main_rank:
            print(f"No samples found for {ds}, skipping.")
        continue

    # Contiguous split: each rank gets len(data) // world_size examples, and
    # the last rank also takes the remainder (so it is never empty here).
    per_rank_data = len(data) // world_size
    start_idx = rank * per_rank_data
    end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
    rank_data = data[start_idx:end_idx]

    image_paths = [os.path.join(IMAGE_ROOT, x['image']) for x in rank_data]
    messages = [
        [{
            "role": "user",
            "content": [
                {"type": "image", "image": f"file://{image_path}"},
                {"type": "text", "text": QUESTION_TEMPLATE.format(Question=x['problem'])},
            ],
        }]
        for image_path, x in zip(image_paths, rank_data)
    ]

    # One tuple per example on this rank:
    # (output_text, input_height, input_width, image_height, image_width)
    rank_outputs = []
    for batch_start in tqdm(range(0, len(messages), BSZ), disable=rank != main_rank):
        batch_messages = messages[batch_start:batch_start + BSZ]
        batch_paths = image_paths[batch_start:batch_start + BSZ]

        text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
        image_inputs, video_inputs = process_vision_info(batch_messages)
        inputs = processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Greedy decoding (do_sample=False) keeps the evaluation deterministic.
        generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Use `j` so the batch index `batch_start` is not shadowed.
        for j, output_text in enumerate(batch_output_text):
            # Assumes one image per message and that image_grid_thw holds
            # (t, h, w) counts of 14-px patches. Scaling back by 14 gives the
            # resolution the model actually saw. TODO confirm the patch size
            # against the processor config.
            input_height = int(inputs['image_grid_thw'][j][1] * 14)
            input_width = int(inputs['image_grid_thw'][j][2] * 14)
            # Only the size is needed. The context manager releases the file
            # handle that Image.open keeps open for lazy loading.
            with Image.open(batch_paths[j]) as image:
                image_width, image_height = image.size
            rank_outputs.append((output_text, input_height, input_width, image_height, image_width))

    print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")

    # Tag each output with its global index so the main rank can restore order.
    rank_results = [(start_idx + k, output) for k, output in enumerate(rank_outputs)]
    gathered_results = [None] * world_size
    dist.all_gather_object(gathered_results, rank_results)
    # The last rank owns the tail of the split, so it must hold the final index.
    assert gathered_results[-1][-1][0] == len(data) - 1

    if rank == main_rank:
        all_outputs = [None] * len(data)
        for results in gathered_results:
            for idx, output in results:
                assert idx < len(all_outputs)
                all_outputs[idx] = output
        assert all_outputs[-1] is not None

        final_output = []
        correct_number = 0
        for input_example, model_output in zip(data, all_outputs):
            original_output, input_height, input_width, image_height, image_width = model_output
            ground_truth = input_example['solution']
            # extract_bbox_answer always returns a list ([0, 0, 0, 0] on parse
            # failure), so no None check is needed.
            model_answer = extract_bbox_answer(original_output)
            # Map the box from the model's resized input frame back to original
            # image pixels (presumably the frame 'solution' is expressed in).
            resized_model_answer = resize_bbox(model_answer, input_height, input_width, image_height, image_width)
            correct = int(iou(resized_model_answer, ground_truth) > 0.5)
            correct_number += correct
            final_output.append({
                'image': input_example['image'],
                'question': input_example['problem'],
                'ground_truth': ground_truth,
                'model_output': original_output,
                'input_size': (input_height, input_width),
                'image_size': (image_height, image_width),
                'extracted_answer': resized_model_answer,
                'correct': correct,
            })

        # data is non-empty here (empty datasets were skipped above).
        accuracy = correct_number / len(data) * 100
        print(f"\nAccuracy of {ds}: {accuracy:.2f}%")

        output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
        output_dir = os.path.dirname(output_path)
        # exist_ok avoids the check-then-create race. Skip when the path has
        # no directory component.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump({
                'accuracy': accuracy,
                'results': final_output
            }, f, indent=2)
        print(f"Results saved to {output_path}")
        print("-" * 100)

    # Keep ranks in lockstep: nobody moves on to the next dataset until the
    # main rank has finished scoring and writing this one.
    dist.barrier()