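"""Evaluate a LLaVA-style video model on ground-truth video QA and save predictions.

For each ground-truth entry the script loads frames from the matching video,
builds a conversation prompt for the question ("Q") or question pair ("Q1"/"Q2"),
generates an answer with the model, and appends the predictions ("pred" or
"pred1"/"pred2") to a JSONL answers file. Evaluation can be split across chunks
(--num-chunks/--chunk-idx) and skips samples that already have answers.
"""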
import argparse
import json
import math
import os

import shortuuid
import torch
from decord import VideoReader
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from llava import conversation as conversation_lib
from llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.data.dataset import LazySupervisedDataset
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    is_gemma_tokenizer,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so no items are dropped
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th of n chunks of lst."""
    chunks = split_list(lst, n)
    return chunks[k]


# Dataset that loads video frames and builds one tokenized prompt per ground-truth question
class CustomDataset(Dataset):
    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        # Fall back to 8 frames and fps 0.0 when the model config does not specify them.
        if hasattr(model_config, "num_video_frames") and model_config.num_video_frames is not None:
            self.num_video_frames = model_config.num_video_frames
        else:
            self.num_video_frames = 8

        if hasattr(model_config, "fps") and model_config.fps is not None:
            self.fps = model_config.fps
        else:
            self.fps = 0.0

    def __getitem__(self, index):
        line = self.questions[index]

        # Locate the video file: try each container format, with and without a "v_" prefix.
        video_name = line["video_name"]
        video_formats = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
        prepend = ["", "v_"]
        video_path = None
        for fmt in video_formats:
            for pre in prepend:
                temp_path = os.path.join(self.image_folder, f"{pre}{video_name}{fmt}")
                if os.path.exists(temp_path):
                    video_path = temp_path
                    break
            if video_path is not None:
                break
        if video_path is None:
            raise FileNotFoundError(f"No video found for {video_name} in {self.image_folder}")

        # Note: _load_video also reads options from the module-level `args` set in __main__.
        images, frames_loaded = LazySupervisedDataset._load_video(video_path, self.num_video_frames, self.fps, args)
        image_tensor = process_images(images, self.image_processor, self.model_config)
        num_frames_loaded_successfully = len(images)

        # Entries carry either a single question "Q" or a pair "Q1"/"Q2".
        if "Q" in line:
            questions = [line["Q"]]
        elif "Q1" in line:
            questions = [line["Q1"], line["Q2"]]
        else:
            raise KeyError(f"Sample for {video_name} has neither a 'Q' nor a 'Q1' field")

        input_ids_list = []
        for qs in questions:
            # Strip any existing image/video placeholders, then prepend one <image> token per loaded frame.
            qs = qs.replace("<image>\n", "").replace("\n<image>", "").replace("<image>", "")
            qs = qs.replace("<video>\n", "").replace("\n<video>", "").replace("<video>", "")
            qs = "<image>\n" * num_frames_loaded_successfully + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            input_ids_list.append(input_ids)

        return input_ids_list, image_tensor

    def __len__(self):
        return len(self.questions)


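# Keep each sample's list of tokenized prompts as a Python list and stack the frame tensors into a batch dimension.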
def collate_fn(batch):
    input_ids, image_tensors = zip(*batch)
    input_ids = list(input_ids)
    image_tensors = torch.stack(image_tensors, dim=0)
    return input_ids, image_tensors


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn
    )
    return data_loader


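# Build a resume key from the question text, the answer, and the video name so finished samples can be skipped on restart.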
def get_key(sample_set):
    question = sample_set["Q"] if "Q" in sample_set else (sample_set["Q1"] + sample_set["Q2"])
    k = question + sample_set["A"] + sample_set["video_name"]
    return k


def eval_model(args):
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_name, args.model_base)
    args.image_processor = image_processor

    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv_mode]

    with open(args.gt_file) as f:
        gt_questions = json.load(f)
    gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.join(args.output_dir, f"{args.output_name}.jsonl")
    os.makedirs(args.output_dir, exist_ok=True)
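    # Resume support: collect keys of samples that already have answers from a previous run.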
    if os.path.exists(answers_file):
        with open(answers_file) as f:
            cache_set = {get_key(json.loads(line)) for line in f}
    else:
        cache_set = set()

    ans_file = open(answers_file, "a")
    data_loader = create_data_loader(gt_questions, args.image_folder, tokenizer, image_processor, model.config)

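    # batch_size is fixed to 1, so each dataloader item pairs with the corresponding ground-truth entry.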
    for (input_ids_list, image_tensor), sample_q in tqdm(zip(data_loader, gt_questions), total=len(gt_questions)):
        input_ids_list = input_ids_list[0]
        sample_set = sample_q
        if get_key(sample_set) in cache_set:
            print("Skipping sample that already has an answer")
            continue
        outputs_list = []
        for input_ids in input_ids_list:
            input_ids = input_ids.to(device="cuda", non_blocking=True).unsqueeze(0)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True),
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                )

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            outputs_list.append(outputs)

        # Single-question entries get "pred"; question pairs get "pred1"/"pred2".
        if len(outputs_list) == 1:
            sample_set["pred"] = outputs_list[0]
        elif len(outputs_list) == 2:
            sample_set["pred1"] = outputs_list[0]
            sample_set["pred2"] = outputs_list[1]

        ans_file.write(json.dumps(sample_set) + "\n")
        ans_file.flush()
    ans_file.close()


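# Example invocation (paths are illustrative, not from the repo):
#   python <this_script>.py --model-path /path/to/checkpoint --gt_file qa.json \
#       --image-folder /path/to/videos --output_dir results --output_name preds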
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument(
        "--gt_file", help="Path to the ground-truth file containing the questions and answers.", required=True
    )
    parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True)
    parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True)
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    args = parser.parse_args()

    eval_model(args)