Commit 3710e28

Add Video-ChatGPT benchmarks (NVlabs#144)
Co-authored-by: Zhijian Liu <[email protected]>
1 parent 616dbe7 commit 3710e28

13 files changed: +610, -67 lines
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
import argparse
import json
import math
import os

import shortuuid
import torch
from decord import VideoReader
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from llava import conversation as conversation_lib
from llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.data.dataset import LazySupervisedDataset
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    is_gemma_tokenizer,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # integer division
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        if hasattr(model_config, "num_video_frames") and model_config.num_video_frames is not None:
            self.num_video_frames = model_config.num_video_frames
        else:
            self.num_video_frames = 8

        if hasattr(model_config, "fps") and model_config.fps is not None:
            self.fps = model_config.fps
        else:
            self.fps = 0.0

    def __getitem__(self, index):
        line = self.questions[index]

        # load visual
        video_name = line["video_name"]
        video_formats = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
        prepend = ["", "v_"]
        video_path = None
        for fmt in video_formats:
            for pre in prepend:
                temp_path = os.path.join(self.image_folder, f"{pre}{video_name}{fmt}")
                if os.path.exists(temp_path):
                    video_path = temp_path
                    break
            if video_path is not None:
                break

        images, frames_loaded = LazySupervisedDataset._load_video(video_path, self.num_video_frames, self.fps, args)
        image_tensor = process_images(images, self.image_processor, self.model_config)
        num_frames_loaded_successfully = len(images)

        if "Q" in line:
            questions = [line["Q"]]
        elif "Q1" in line:
            questions = [line["Q1"], line["Q2"]]

        input_ids_list = []
        for qs in questions:
            qs = qs.replace("<image>\n", "").replace("\n<image>", "").replace("<image>", "")
            qs = qs.replace("<video>\n", "").replace("\n<video>", "").replace("<video>", "")
            qs = "<image>\n" * num_frames_loaded_successfully + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            input_ids_list.append(input_ids)

        return input_ids_list, image_tensor

    def __len__(self):
        return len(self.questions)


def collate_fn(batch):
    input_ids, image_tensors = zip(*batch)
    input_ids = list(input_ids)
    image_tensors = torch.stack(image_tensors, dim=0)
    return input_ids, image_tensors


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn
    )
    return data_loader


def get_key(sample_set):
    question = sample_set["Q"] if "Q" in sample_set else (sample_set["Q1"] + sample_set["Q2"])
    k = question + sample_set["A"] + sample_set["video_name"]
    return k


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_name, args.model_base)
    args.image_processor = image_processor

    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv_mode]

    gt_questions = json.load(open(args.gt_file))
    gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.join(args.output_dir, f"{args.output_name}.jsonl")
    os.makedirs(args.output_dir, exist_ok=True)
    if os.path.exists(answers_file):
        with open(answers_file) as f:
            cache_ans = f.readlines()
        cache_set = list(json.loads(line) for line in cache_ans)
        cache_set = {get_key(line) for line in cache_set}
    else:
        cache_set = set()

    ans_file = open(answers_file, "a")
    data_loader = create_data_loader(gt_questions, args.image_folder, tokenizer, image_processor, model.config)

    for (input_ids_list, image_tensor), sample_q in tqdm(zip(data_loader, gt_questions), total=len(gt_questions)):
        input_ids_list = input_ids_list[0]
        sample_set = sample_q
        if get_key(sample_set) in cache_set:
            print(f"skip exist answer")
            continue
        outputs_list = []
        for input_ids in input_ids_list:
            input_ids = input_ids.to(device="cuda", non_blocking=True).unsqueeze(0)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True),
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                )

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            outputs_list.append(outputs)

        if len(outputs_list) == 1:
            sample_set["pred"] = outputs_list[0]
        elif len(outputs_list) == 2:
            sample_set["pred1"] = outputs_list[0]
            sample_set["pred2"] = outputs_list[1]

        ans_file.write(json.dumps(sample_set) + "\n")
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument(
        "--gt_file", help="Path to the ground truth file containing question and answer.", required=True
    )
    parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True)
    parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True)
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    args = parser.parse_args()

    eval_model(args)
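
Editor's note on how this new script is driven: --gt_file points at a JSON list of Video-ChatGPT question records, and each answered record is appended as one JSON line to <output_dir>/<output_name>.jsonl. The sketch below only illustrates the record shapes implied by the code above; the field names (video_name, Q/A or Q1/Q2, pred/pred1/pred2) come from the script, while the concrete values are invented for illustration.

    import json

    # Single-question entry, as read by CustomDataset.__getitem__ ("Q"/"A" plus "video_name").
    gt_entry = {"video_name": "v_abc123", "Q": "What is the person doing?", "A": "Cooking."}  # hypothetical values

    # Consistency-style entry carrying two paraphrased questions ("Q1"/"Q2").
    gt_entry_pair = {
        "video_name": "v_abc123",  # hypothetical value
        "Q1": "What activity is shown?",
        "Q2": "Describe what happens in the video.",
        "A": "Cooking.",
    }

    # eval_model copies the input record, adds "pred" (or "pred1"/"pred2"),
    # and writes it as one line of the answers .jsonl file.
    pred_line = dict(gt_entry, pred="The person is cooking in a kitchen.")
    print(json.dumps(pred_line))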

llava/eval/model_vqa_video.py

Lines changed: 2 additions & 1 deletion
@@ -89,6 +89,7 @@ def get_model_output(model, image_processor, tokenizer, video_path, qs, args):
     if conv.sep_style == SeparatorStyle.LLAMA_3:
         keywords = [conv.sep, conv.sep2]
         stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)]
+        stop_str = None
     else:
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -111,7 +112,7 @@ def get_model_output(model, image_processor, tokenizer, video_path, qs, args):
 
     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
     outputs = outputs.strip()
-    if outputs.endswith(stop_str):
+    if stop_str is not None and outputs.endswith(stop_str):
         outputs = outputs[: -len(stop_str)]
     outputs = outputs.strip()
     return outputs

llava/eval/video/eval_benchmark_correctness.py renamed to llava/eval/video/eval_benchmark_1_correctness.py

Lines changed: 39 additions & 15 deletions
@@ -1,21 +1,26 @@
-# import openai
 import argparse
 import ast
 import json
 import os
 from multiprocessing.pool import Pool
 
 import openai
+from openai import BadRequestError
+
+from .utils import get_client
+
+client = None
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
     parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
     parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
     parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
-    parser.add_argument("--api_key", help="OpenAI API key.")
-    parser.add_argument("--api_base", default="", type=str, help="OpenAI API base.")
+    parser.add_argument("--api_key", required=True, help="OpenAI API key.")
+    parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base.")
     parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
+    parser.add_argument("--model", default="gpt-3.5-turbo", type=str, help="OpenAI model.")
     args = parser.parse_args()
     return args
 
@@ -26,6 +31,9 @@ def annotate(prediction_set, caption_files, output_dir, args):
     Returns a score for correctness.
     """
     # Set the OpenAI API key.
+    openai.api_key = args.api_key
+    if args.api_base is not None:
+        openai.api_base = args.api_base
     for file in caption_files:
         key = file[:-5]  # Strip file extension
         qa_set = prediction_set[key]
@@ -34,8 +42,9 @@
         pred = qa_set["pred"]
         try:
             # Compute the correctness score
-            completion = openai.chat.completions.create(
-                model="gpt-4",
+            # completion = create_chat_completion(
+            completion = client.chat.completions.create(
+                model=args.model,
                 messages=[
                     {
                         "role": "system",
@@ -63,13 +72,20 @@
             )
             # Convert response to a Python dictionary.
             response_message = completion.choices[0].message.content
-            # response_message = completion["choices"][0]["message"]["content"]
             response_dict = ast.literal_eval(response_message)
             result_qa_pair = [response_dict, qa_set]
 
             # Save the question-answer pairs to a json file.
             with open(f"{output_dir}/{key}.json", "w") as f:
                 json.dump(result_qa_pair, f)
+        except BadRequestError as e:
+            print(f"BadRequestError processing file '{key}': {e}")
+            response_dict = {"score": 0}
+            qa_set["pred"] = ""
+            result_qa_pair = [response_dict, qa_set]
+            # Save the question-answer pairs to a json file.
+            with open(f"{output_dir}/{key}.json", "w") as f:
+                json.dump(result_qa_pair, f)
 
         except Exception as e:
             print(f"Error processing file '{key}': {e}")
@@ -83,7 +99,7 @@ def main():
     args = parse_args()
 
     file = open(args.pred_path)
-    pred_contents = json.load(file)
+    pred_contents = [eval(i.strip()) for i in file.readlines()]
 
     # Dictionary to store the count of occurrences for each video_id
     video_id_counts = {}
@@ -122,14 +138,16 @@
         prediction_set[id] = qa_set
 
     # Set the OpenAI API key.
-    # openai.api_key = args.api_key
-    openai.api_key = os.environ["OPENAI_API_KEY"]
+    openai.api_key = args.api_key
     num_tasks = args.num_tasks
 
     # While loop to ensure that all captions are processed.
     while True:
         try:
             # Files that have not been processed yet.
+            global client
+            client = get_client()
+
             completed_files = os.listdir(output_dir)
             print(f"completed_files: {len(completed_files)}")
 
@@ -149,12 +167,8 @@
             task_args = [(prediction_set, part, args.output_dir, args) for part in all_parts]
 
             # Use a pool of workers to process the files in parallel.
-            # with Pool() as pool:
-            #     pool.starmap(annotate, task_args)
-            from tqdm import tqdm
-
-            for task_arg in tqdm(task_args):
-                annotate(*task_arg)
+            with Pool() as pool:
+                pool.starmap(annotate, task_args)
 
         except Exception as e:
             print(f"Error: {e}")
@@ -188,6 +202,16 @@
 
     print("Average score for correctness:", average_score)
 
+    result_file = os.path.join(os.path.dirname(os.path.dirname(args.output_json)), "results.json")
+    sample_set = {"gpt": args.model, "task": "1_correctness", "score": average_score}
+    with open(result_file, "a") as f:
+        f.write(json.dumps(sample_set) + "\n")
+
 
 if __name__ == "__main__":
+    import time
+
+    start_time = time.time()
     main()
+    end_time = time.time()
+    print(f"took {end_time - start_time} seconds")
