added mantis_eval
Oscar Lo committed May 25, 2024
1 parent c1bdb64 commit e83c642
Showing 6 changed files with 110 additions and 33 deletions.
17 changes: 15 additions & 2 deletions open_flamingo/eval/eval_datasets.py
@@ -15,6 +15,7 @@
"vizwiz",
"textvqa",
"gqa",
"mantiseval",
"hateful_memes",
"imagenet",
]
@@ -107,14 +108,26 @@ def get_img_path(self, question):
             return os.path.join(self.image_dir_path, question["image_id"])
         elif self.dataset_name == "textvqa" or self.dataset_name == "gqa":
             return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg")
+        elif self.dataset_name == "mantiseval":
+            img_paths = []
+            for img_id in question['image_id']:
+                img_paths.append(os.path.join(self.image_dir_path, f"{img_id}.jpg"))
+            return img_paths
         else:
             raise Exception(f"Unknown VQA dataset {self.dataset_name}")

     def __getitem__(self, idx):
         question = self.questions[idx]
         img_path = self.get_img_path(question)
-        image = Image.open(img_path)
-        image.load()
+        if self.dataset_name == "mantiseval":
+            image = []
+            for path in img_path:
+                img = Image.open(path)
+                img.load()
+                image.append(img)
+        else:
+            image = Image.open(img_path)
+            image.load()
         results = {
             "image": image,
             "question": question["question"],
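For reference, a minimal sketch of the new mantiseval path through VQADataset: the question's "image_id" field is treated as a list of ids, and __getitem__ returns a list of PIL images instead of a single one. The record contents below are invented, and dummy files are written to a temp directory so the snippet runs on its own:

import os
import tempfile
from PIL import Image

# Hypothetical mantiseval-style question record: "image_id" holds several ids.
question = {"question": "Which image shows more objects?",
            "image_id": ["example_0_a", "example_0_b"]}

# Stand-in image directory populated with dummy .jpg files.
image_dir_path = tempfile.mkdtemp()
for img_id in question["image_id"]:
    Image.new("RGB", (64, 48)).save(os.path.join(image_dir_path, f"{img_id}.jpg"))

# Mirrors get_img_path / __getitem__ for dataset_name == "mantiseval":
img_paths = [os.path.join(image_dir_path, f"{img_id}.jpg") for img_id in question["image_id"]]
image = []
for path in img_paths:
    img = Image.open(path)
    img.load()
    image.append(img)
print(len(image))  # 2 -- the "image" entry is now a list of PIL images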
10 changes: 9 additions & 1 deletion open_flamingo/eval/eval_models/blip.py
@@ -5,7 +5,7 @@

 from transformers import Blip2Processor, Blip2ForConditionalGeneration
 from eval_models.eval_model import BaseEvalModel
-from utils import unwrap_model
+from utils import unwrap_model, combine_images
 from transformers.modeling_outputs import CausalLMOutputWithPast


@@ -27,9 +27,14 @@ def required_args(self):

     def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         batch_images = None
+        for i in range(len(batch)):
+            if len(batch[i]) > 1:
+                batch[i] = combine_images(batch[i])
+        """
         assert all(
             len(example) == 1 for example in batch
         ), "BLIP-2 only supports one image per example"
+        """
         for example in batch:
             if batch_images is None:
                 batch_images = self.processor.image_processor(
@@ -111,6 +116,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:

     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"
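The new prepare_images loop collapses any multi-image example into a single stitched image before it reaches the BLIP-2 image processor (the old one-image-per-example assert is left disabled inside the docstring). A minimal sketch of that step, using a simplified stand-in for combine_images so the snippet runs on its own:

from PIL import Image

def combine_images_stub(images):
    # Stand-in for utils.combine_images (added later in this commit); this version
    # simply pastes the images side by side, without the separator bars.
    total_w = sum(img.size[0] for img in images)
    max_h = max(img.size[1] for img in images)
    canvas = Image.new("RGB", (total_w, max_h))
    x = 0
    for img in images:
        canvas.paste(img, (x, 0))
        x += img.size[0]
    return canvas

# In prepare_images, `batch` is a list of examples, each a list of PIL images.
batch = [[Image.new("RGB", (32, 32))],                              # single-image example
         [Image.new("RGB", (32, 32)), Image.new("RGB", (48, 32))]]  # multi-image example
for i in range(len(batch)):
    if len(batch[i]) > 1:
        batch[i] = combine_images_stub(batch[i])
print(type(batch[0]), type(batch[1]))
# batch[1] is now one stitched PIL image; batch[0] is still a one-element list.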
3 changes: 3 additions & 0 deletions open_flamingo/eval/eval_models/open_flamingo.py
@@ -291,6 +291,9 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
     def get_gqa_prompt(self, question, answer=None) -> str:
         return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
+    def get_mantiseval_prompt(self, question, answer=None) -> str:
+        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
     def get_coco_prompt(self, caption=None) -> str:
         return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

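Both new prompt helpers mirror the existing GQA prompt; a quick sketch of the strings they produce (the question and answer below are invented):

question = "Which of the two images contains a dog?"
answer = "A"

# blip.py style:
blip_prompt = f"Question:{question} Short answer:{answer if answer is not None else ''}"
# open_flamingo.py style, with the <image> / <|endofchunk|> tokens:
of_prompt = f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

print(blip_prompt)  # Question:Which of the two images contains a dog? Short answer:A
print(of_prompt)    # <image>Question:Which of the two images contains a dog? Short answer:A<|endofchunk|>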
70 changes: 40 additions & 30 deletions open_flamingo/eval/evaluate.py
@@ -34,7 +34,7 @@
     HatefulMemesDataset,
 )
 from ok_vqa_utils import postprocess_ok_vqa_generation
-from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
+from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation, compute_mantis_accuracy

parser = argparse.ArgumentParser()
parser.add_argument(
@@ -152,29 +152,30 @@
     default=None,
 )
 
-## VQAV2, OK-VQA, VizWiz, TextVQA, GQA Datasets
-for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa']:
+## VQAV2, OK-VQA, VizWiz, TextVQA, GQA, Mantis-Eval Datasets
+for task in ['vqav2', 'okvqa', 'vizwiz', 'textvqa', 'gqa', 'mantiseval']:
     parser.add_argument(
-        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' else f"--{task}_train_image_dir_path",
+        f"--{task}_image_dir_path" if task=='gqa' or task=='textvqa' or task=='mantiseval' else f"--{task}_train_image_dir_path",
         type=str,
         default=None,
     )
-    if task!='gqa' and task!='textvqa':
-        parser.add_argument(
-            f"--{task}_test_image_dir_path",
-            type=str,
-            default=None,
-        )
-    parser.add_argument(
-        f"--{task}_train_questions_json_path",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        f"--{task}_train_annotations_json_path",
-        type=str,
-        default=None,
-    )
+    if task != 'mantiseval':
+        if task!='gqa' and task!='textvqa':
+            parser.add_argument(
+                f"--{task}_test_image_dir_path",
+                type=str,
+                default=None,
+            )
+        parser.add_argument(
+            f"--{task}_train_questions_json_path",
+            type=str,
+            default=None,
+        )
+        parser.add_argument(
+            f"--{task}_train_annotations_json_path",
+            type=str,
+            default=None,
+        )
     parser.add_argument(
         f"--{task}_test_questions_json_path",
         type=str,
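For orientation, a sketch of which command-line flags the loop above ends up defining for the new task; the test-annotations argument is an assumption (the hunk is cut off before it), inferred from how evaluate_vqa reads it below:

task = "mantiseval"
mantiseval_flags = [
    f"--{task}_image_dir_path",              # shared image dir for demo and test images
    f"--{task}_test_questions_json_path",
    f"--{task}_test_annotations_json_path",  # assumed to follow the truncated hunk
]
# No --mantiseval_train_*_json_path or --mantiseval_test_image_dir_path flags are created,
# because that whole block is guarded by `if task != 'mantiseval':`.
print(mantiseval_flags)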
@@ -315,7 +316,7 @@ def main():
                 }
             )
 
-    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for vqa_task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if var_args[f"eval_{vqa_task}"]:
             print(f"Evaluating on {vqa_task}...")

@@ -601,16 +602,16 @@ def evaluate_vqa(
         float: accuracy score
     """
     var_args = vars(args)
-    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    for task in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         if dataset_name == task:
             task = task
-            train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
-            train_questions_json_path = var_args[f"{task}_train_questions_json_path"]
-            train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"]
-            test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" else f"{task}_image_dir_path"]
+            train_image_dir_path = var_args[f"{task}_train_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
+            train_questions_json_path = var_args[f"{task}_train_questions_json_path"] if task!="mantiseval" else var_args[f"{task}_test_questions_json_path"]
+            train_annotations_json_path = var_args[f"{task}_train_annotations_json_path"] if task!="mantiseval" else var_args[f"{task}_test_annotations_json_path"]
+            test_image_dir_path = var_args[f"{task}_test_image_dir_path" if task!="textvqa" and task!="gqa" and task!="mantiseval" else f"{task}_image_dir_path"]
             test_questions_json_path = var_args[f"{task}_test_questions_json_path"]
             test_annotations_json_path = var_args[f"{task}_test_annotations_json_path"]
-    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa"]:
+    if dataset_name not in ["okvqa", "vqav2", "vizwiz", "textvqa", "gqa", "mantiseval"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")

train_dataset = VQADataset(
@@ -675,7 +676,10 @@ def evaluate_vqa(
                 context_images = [x["image"] for x in batch_demo_samples[i]]
             else:
                 context_images = []
-            batch_images.append(context_images + [batch["image"][i]])
+            if dataset_name == "mantiseval":
+                batch_images.append(context_images + batch["image"][i])
+            else:
+                batch_images.append(context_images + [batch["image"][i]])

context_text = "".join(
[
@@ -703,7 +707,7 @@
num_beams=num_beams,
length_penalty=length_penalty,
)

process_function = (
postprocess_ok_vqa_generation
if dataset_name == "okvqa"
@@ -732,11 +736,17 @@
        f.write(json.dumps(all_predictions, indent=4))
 
    if test_annotations_json_path is not None:
-        acc = compute_vqa_accuracy(
-            f"{dataset_name}results_{random_uuid}.json",
-            test_questions_json_path,
-            test_annotations_json_path,
-        )
+        if dataset_name == "mantiseval":
+            acc = compute_mantis_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_annotations_json_path,
+            )
+        else:
+            acc = compute_vqa_accuracy(
+                f"{dataset_name}results_{random_uuid}.json",
+                test_questions_json_path,
+                test_annotations_json_path,
+            )
        # delete the temporary file
        os.remove(f"{dataset_name}results_{random_uuid}.json")

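A mantiseval example carries a list of query images, so it is concatenated into the prompt's image sequence directly rather than being wrapped in another list. A tiny sketch of the difference, with strings standing in for PIL images:

context_images = ["demo_img_1", "demo_img_2"]    # images from the in-context demos
mantis_example = ["query_img_a", "query_img_b"]  # one mantiseval example = several images
single_example = "query_img"                     # one image for the other VQA tasks

print(context_images + mantis_example)    # ['demo_img_1', 'demo_img_2', 'query_img_a', 'query_img_b']
print(context_images + [single_example])  # ['demo_img_1', 'demo_img_2', 'query_img']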
23 changes: 23 additions & 0 deletions open_flamingo/eval/utils.py
@@ -3,6 +3,7 @@
 import random
 import torch.nn as nn
 from contextlib import suppress
+from PIL import Image


def random_seed(seed=42, rank=0):
@@ -122,3 +123,25 @@ def get_autocast(precision):
         return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
     else:
         return suppress
+
+def combine_images(images):
+    img_heights, _ = zip(*(img.size for img in images))
+    avg_height = sum(img_heights) // len(img_heights)
+    for i, img in enumerate(images):
+        images[i] = img.resize((int(img.size[0] * avg_height / img.size[1]), avg_height))
+    resized_heights, resized_widths = zip(*(img.size for img in images))
+    total_width = sum(resized_widths)
+    max_height = max(resized_heights)
+    new_img = Image.new("RGB", (total_width + 10 * (len(images) - 1), max_height))
+    x_offset = 0
+    for i, img in enumerate(images):
+        if i > 0:
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+            new_img.paste(Image.new("RGB", (8, max_height), (255, 255, 255)), (x_offset, 0))
+            x_offset += 8
+            new_img.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0))
+            x_offset += 1
+        new_img.paste(img, (x_offset, 0))
+        x_offset += img.size[0]
+    return new_img
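
A quick usage sketch for the new combine_images helper: it rescales the inputs to a common height (mutating the passed-in list) and pastes them side by side with thin black-white-black separators. The import assumes open_flamingo/eval is on the path; the sizes and colors are arbitrary:

from PIL import Image
from utils import combine_images  # i.e. open_flamingo/eval/utils.py above

imgs = [Image.new("RGB", (120, 80), (200, 0, 0)),
        Image.new("RGB", (60, 90), (0, 200, 0))]
combined = combine_images(imgs)
print(combined.size)  # one wide RGB canvas holding both inputs plus the separators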
20 changes: 20 additions & 0 deletions open_flamingo/eval/vqa_metric.py
@@ -553,6 +553,26 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path):

     return vqaEval.accuracy["overall"]
 
+def compute_mantis_accuracy(result_json_path, annotation_json_path):
+    dataset = json.load(open(annotation_json_path, "r"))
+    gt_ans = {}
+    for ann in dataset["annotations"]:
+        gt_ans[ann["question_id"]] = {"answer": ann["answers"][0]["answer"], "type": ann["question_type"]}
+    results = json.load(open(result_json_path, "r"))
+    assert type(results) == list, "results is not an array of objects"
+    correct = 0
+    for res in results:
+        res_ans = res["answer"].lower().strip('()\n ')
+        if gt_ans[res["question_id"]]["type"] == "multi-choice":
+            if len(res_ans) > 1:
+                for c in res_ans:
+                    if c.isalpha():
+                        res_ans = c
+                        break
+        if res_ans == gt_ans[res["question_id"]]["answer"].lower().strip('()\n '):
+            correct+=1
+    acc = correct / len(results)
+    return acc
 
 def postprocess_vqa_generation(predictions):
     answer = re.split("Question|Answer|Short", predictions, 1)[0]
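A sketch of the JSON shapes compute_mantis_accuracy expects, with invented ids, answers, and question types; for multi-choice questions the first alphabetic character of a longer prediction is compared against the ground-truth letter:

import json
import os
import tempfile

annotations = {"annotations": [
    {"question_id": 1, "question_type": "multi-choice", "answers": [{"answer": "(A)"}]},
    {"question_id": 2, "question_type": "short-answer", "answers": [{"answer": "two"}]},
]}
results = [{"question_id": 1, "answer": "A) the left image"},
           {"question_id": 2, "answer": "three"}]

tmp = tempfile.mkdtemp()
ann_path = os.path.join(tmp, "ann.json")
res_path = os.path.join(tmp, "res.json")
json.dump(annotations, open(ann_path, "w"))
json.dump(results, open(res_path, "w"))

from vqa_metric import compute_mantis_accuracy  # assumes open_flamingo/eval is on the path
print(compute_mantis_accuracy(res_path, ann_path))  # 0.5 on this made-up data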
