Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
fc1c777
added rices for picking demos for captioning and vqa
anas-awadalla Jul 1, 2023
33bf8c5
add caching
anas-awadalla Jul 2, 2023
50df567
script to cache RICES features; RICES DDP
Jul 2, 2023
e5ea7a0
RICES for Hatefulmemes
Jul 3, 2023
9278322
make RICES a command line argument
Jul 6, 2023
8e0ee06
remove DDP
Jul 7, 2023
cb06a69
refactor classification
i-gao Jul 9, 2023
4fd3988
add prompt ensembling
i-gao Jul 10, 2023
2aba5d7
enforce correct class ordering
Jul 10, 2023
8a677fd
refactor cached classification
Jul 10, 2023
b50a4d9
add waterbirds and celeba
i-gao Jul 11, 2023
da25e19
fix classification caching
Jul 12, 2023
086eb6d
Merge branch 'rices' into wilds
i-gao Jul 12, 2023
9b88177
Merge branch 'main' into wilds
i-gao Aug 5, 2023
8759b5b
waterbirds
i-gao Aug 5, 2023
992e71d
add wandb; refactor
i-gao Aug 6, 2023
cfba7ae
adding camelyon17
ssagawa Aug 16, 2023
fb3afad
class-conditional sampling
i-gao Aug 24, 2023
fef4df2
merge
ssagawa Aug 25, 2023
1d84adf
remove length normalization
i-gao Aug 25, 2023
45cb422
rices + class-conditional
i-gao Aug 28, 2023
b1e6049
fix
ssagawa Aug 29, 2023
615b085
gather on cpu
ssagawa Aug 29, 2023
817eda4
fix memory leak
i-gao Sep 20, 2023
20a1374
rename to eval_models
i-gao Sep 21, 2023
670a171
attempt to merge mllm -> wilds
i-gao Sep 21, 2023
a885c25
Merge branch 'wilds' into merge-wilds-mllm
i-gao Sep 21, 2023
96daeec
fixes
i-gao Sep 21, 2023
3f1d66b
fixes
i-gao Sep 21, 2023
b51caf9
kwarg change
i-gao Oct 9, 2023
17cd546
fixes for num_beams > 1
i-gao Oct 9, 2023
e104a0a
fix for null past_key_values, getting supported tasks
hannahyklee Oct 24, 2023
55a8f7c
add coco dataset name to evaluate.py
hannahyklee Oct 25, 2023
aa995e3
fix generation use_cache issue for mpt
i-gao Nov 1, 2023
5d10063
Merge branch 'merge-wilds-mllm' into mllm-eval-hl
anas-awadalla Nov 29, 2023
427d974
Merge pull request #275 from hannahyklee/mllm-eval-hl
anas-awadalla Nov 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions open_flamingo/eval/classification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,3 +1006,13 @@
"no",
"yes",
]

# Class names for the WILDS Waterbirds benchmark; index matches the
# dataset's integer label (0 = landbird, 1 = waterbird).
WATERBIRDS_CLASSNAMES = ["landbird", "waterbird"]

# Class names for the WILDS Camelyon17 benchmark; index matches the
# dataset's integer label (0 = normal tissue, 1 = tumor tissue).
CAMELYON17_CLASSNAMES = ["normal tissue", "tumor tissue"]
60 changes: 57 additions & 3 deletions open_flamingo/eval/eval_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@
import os

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES
from open_flamingo.eval.classification_utils import (
IMAGENET_CLASSNAMES,
WATERBIRDS_CLASSNAMES,
CAMELYON17_CLASSNAMES,
)


# Evaluation tasks with a registered dataset and prompt implementation.
# Note: the captioning task was renamed "flickr" -> "flickr30"; only the
# new name is listed (the old entry was stale diff residue).
SUPPORTED_TASKS = [
    "coco",
    "flickr30",
    "vqav2",
    "ok_vqa",
    "vizwiz",
    "textvqa",
    "hateful_memes",
    "imagenet",
    "waterbirds",
    "camelyon17",
]


Expand Down Expand Up @@ -58,7 +66,7 @@ def __getitem__(self, idx):
self.image_val_dir_path, self.annotations[idx]["filename"]
)
)
elif self.dataset_name == "flickr":
elif self.dataset_name == "flickr30":
image = Image.open(
os.path.join(
self.image_train_dir_path, self.annotations[idx]["filename"]
Expand Down Expand Up @@ -133,6 +141,7 @@ def __init__(self, root, **kwargs):
self.class_id_to_name = dict(
zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES)
)
self.class_id_array = torch.tensor([y for _, y in self.samples])

def __getitem__(self, idx):
sample, target = super().__getitem__(idx)
Expand All @@ -150,6 +159,7 @@ def __init__(self, image_dir_path, annotations_path):
self.image_dir_path = image_dir_path
with open(annotations_path, "r") as f:
self.annotations = [json.loads(line) for line in f]
self.class_id_array = torch.tensor([y["label"] for y in self.annotations])

def __len__(self):
    # One JSON line per annotated example, so the annotation count is
    # the dataset length.
    return len(self.annotations)
Expand All @@ -166,3 +176,47 @@ def __getitem__(self, idx):
"class_name": "yes" if annotation["label"] == 1 else "no",
"class_id": annotation["label"],
}


class WILDSDataset(Dataset):
    """Wrapper exposing a WILDS dataset split (waterbirds or camelyon17)
    as dicts with a class name and a human-readable domain string.

    Attributes:
        dataset: the WILDS subset for the requested split.
        class_id_to_name: maps integer label -> class-name string.
        grouper: WILDS CombinatorialGrouper used to render domain strings.
        class_id_array: labels for the whole split (used for
            class-conditional demo sampling).
    """

    def __init__(self, dataset_name: str, split: str, root_dir: str):
        # Lazy import: wilds is only required when a WILDS task is run.
        import wilds

        # Import the grouper explicitly: `wilds.common.grouper` is not
        # guaranteed to be reachable as an attribute after `import wilds`
        # (the submodule may not be imported by wilds/__init__).
        from wilds.common.grouper import CombinatorialGrouper

        full_dataset = wilds.get_dataset(
            dataset_name,
            root_dir=root_dir,
            download=True,
        )
        self.dataset = full_dataset.get_subset(split)
        if dataset_name == "waterbirds":
            self.class_id_to_name = dict(enumerate(WATERBIRDS_CLASSNAMES))
            # Waterbirds domains are (background, label) combinations.
            self.grouper = CombinatorialGrouper(
                dataset=full_dataset,
                groupby_fields=["background", "y"],
            )
        elif dataset_name == "camelyon17":
            self.class_id_to_name = dict(enumerate(CAMELYON17_CLASSNAMES))
            # Camelyon17 domains are the source hospitals.
            self.grouper = CombinatorialGrouper(
                dataset=full_dataset,
                groupby_fields=["hospital"],
            )
        else:
            raise ValueError(f"Unimplemented WILDS dataset {dataset_name}")
        # Labels for the whole split; presumably a 1-D tensor -- used by
        # class-conditional sampling (e.g. RICES).
        self.class_id_array = self.dataset.y_array

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # WILDS items are (input, label, metadata) triples.
        x, y, m = self.dataset[idx]
        y = y.item()
        return {
            "id": idx,
            "image": x,
            "class_id": y,
            "class_name": self.class_id_to_name[y],
            # metadata_to_group expects a batch dim, hence unsqueeze(0).
            "domain": self.grouper.group_str(
                self.grouper.metadata_to_group(m.unsqueeze(0)).item()
            ),
            "metadata": m,
        }
13 changes: 6 additions & 7 deletions open_flamingo/eval/eval_models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import torch

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from models.eval_model import BaseEvalModel
from utils import unwrap_model
from transformers.modeling_outputs import CausalLMOutputWithPast
from open_flamingo.eval.eval_models import BaseEvalModel
from open_flamingo.eval.utils import unwrap_model


class EvalModel(BaseEvalModel):
Expand Down Expand Up @@ -99,18 +98,18 @@ def get_outputs(

def get_vqav2_prompt(self, question, answer=None) -> str:
    """VQAv2 prompt; the answer is filled in only for in-context demos."""
    filled = "" if answer is None else answer
    return f"Question:{question} Short answer:{filled}"

def get_ok_vqa_prompt(self, question, answer=None) -> str:
    """OK-VQA prompt; the answer is filled in only for in-context demos."""
    filled = "" if answer is None else answer
    return f"Question:{question} Short answer:{filled}"

def get_vizwiz_prompt(self, question, answer=None) -> str:
    """VizWiz prompt; the answer is filled in only for in-context demos."""
    filled = "" if answer is None else answer
    return f"Question:{question} Short answer:{filled}"

def get_textvqa_prompt(self, question, answer=None) -> str:
    """TextVQA prompt; the answer is filled in only for in-context demos."""
    filled = "" if answer is None else answer
    return f"Question:{question} Short answer:{filled}"

def get_coco_prompt(self, caption=None) -> str:
    """COCO captioning prompt; caption filled in only for demos."""
    filled = "" if caption is None else caption
    return f"A photo of {filled}"

def get_flickr_prompt(self, caption=None) -> str:
def get_flickr30_prompt(self, caption=None) -> str:
    """Flickr30k captioning prompt; caption filled in only for demos."""
    filled = "" if caption is None else caption
    return f"A photo of {filled}"
7 changes: 1 addition & 6 deletions open_flamingo/eval/eval_models/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class BaseEvalModel(abc.ABC):

def __init__(self, model_args: List[str], init_on_device=False):
"""Initialize model.

Args:
args: arguments to model. These should be parsed, or if the model
has no applicable arguments, an error should be thrown if `args`
Expand Down Expand Up @@ -128,13 +127,11 @@ def prepare_text(
):
"""
Prepare text for model. Note that padding is always on the left.

Args:
batch: list of text strings
padding: whether to pad the text
truncation: whether to truncate the text
max_length: maximum length of the text

Returns:
input_ids: tensor of shape (B, T_txt)
attention_mask: tensor of shape (B, T_txt)
Expand All @@ -158,12 +155,10 @@ def get_outputs(
**decode_kwargs,
) -> List[str]:
"""Call generate on a batch of images and text.

Args:
batch_text: list of text strings
batch_images: images to provide to model. Should be a list of lists,
where each list contains the images for a single example.

Returns:
List of decoded output strings.
"""
Expand Down Expand Up @@ -200,7 +195,7 @@ def supported_tasks(self):
Parsed by checking whether the model has a method called `get_{task}_prompt`.
"""
return [
task.split("_")[1]
"_".join(task.split("_")[1:-1])
for task in dir(self)
if task.startswith("get_") and task.endswith("_prompt")
]
Expand Down
2 changes: 1 addition & 1 deletion open_flamingo/eval/eval_models/idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,6 @@ def get_coco_prompt(self, caption=None) -> str:
# TODO: handle prefix prompts
return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

def get_flickr_prompt(self, caption=None) -> str:
def get_flickr30_prompt(self, caption=None) -> str:
    """Flickr30k captioning prompt; demos carry the caption plus an
    <|endofchunk|> terminator, queries end at "Caption: "."""
    # TODO: handle prefix prompts
    if caption is None:
        return "<image>Caption: "
    return f"<image>Caption: {caption}<|endofchunk|>"
30 changes: 17 additions & 13 deletions open_flamingo/eval/eval_models/open_flamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from open_flamingo.eval.eval_models.eval_model import BaseEvalModel
from open_flamingo.src.factory import create_model_and_transforms
from open_flamingo.eval.utils import unwrap_model
from transformers.modeling_outputs import CausalLMOutputWithPast
from open_flamingo.src import VLMOutputWithPast


class EvalModel(BaseEvalModel):
Expand All @@ -33,11 +33,12 @@ def __init__(self, model_args, init_on_device=False):
)

# load the checkpoint
checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu")
if "model_state_dict" in checkpoint:
checkpoint = checkpoint["model_state_dict"]
checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
self.model.load_state_dict(checkpoint, strict=False)
if "checkpoint_path" in model_args:
checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu")
if "model_state_dict" in checkpoint:
checkpoint = checkpoint["model_state_dict"]
checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
self.model.load_state_dict(checkpoint, strict=False)

self._check_init()

Expand All @@ -46,9 +47,8 @@ def required_args(self):
"""Return list of required arguments to initialize model."""
return [
"vision_encoder_path",
"model_familyl",
"model_family",
"lm_path",
"checkpoint_path",
"tokenizer_path",
"cross_attn_every_n_layers",
"vision_encoder_pretrained",
Expand Down Expand Up @@ -92,7 +92,6 @@ def __call__(
vision_x=vision_x,
lang_x=_lang_x,
attention_mask=_attention_mask,
clear_conditioned_layers=False,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
Expand All @@ -105,7 +104,7 @@ def __call__(
logits.append(outputs.logits)

logits = torch.cat(logits, dim=1)
return CausalLMOutputWithPast(
return VLMOutputWithPast(
logits=logits,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
Expand Down Expand Up @@ -195,10 +194,9 @@ def get_rank_classifications(
if use_cache:
with torch.inference_mode():
precomputed = self.__call__(
vision_x=None,
vision_x=batch_images,
lang_x=ctx_input_ids,
attention_mask=ctx_attention_mask,
clear_conditioned_layers=False,
use_cache=True,
)

Expand Down Expand Up @@ -282,11 +280,17 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
def get_coco_prompt(self, caption=None) -> str:
return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

def get_flickr_prompt(self, caption=None) -> str:
def get_flickr30_prompt(self, caption=None) -> str:
    """Flickr30k captioning prompt; demos append the caption and an
    <|endofchunk|> terminator, queries end at "Output:"."""
    if caption is None:
        return "<image>Output:"
    return f"<image>Output:{caption}<|endofchunk|>"

def get_imagenet_prompt(self, label=None) -> str:
    """ImageNet classification prompt; demos append the class name and
    an <|endofchunk|> terminator, queries end at "Output:"."""
    if label is None:
        return "<image>Output:"
    return f"<image>Output:{label}<|endofchunk|>"

def get_hateful_memes_prompt(self, text, label=None) -> str:
    """Hateful Memes yes/no prompt embedding the meme's OCR text;
    demos append the label and an <|endofchunk|> terminator."""
    prefix = f"<image>is an image with: '{text}' written on it. Is it hateful? Answer:"
    if label is None:
        return prefix
    return f"{prefix}{label}<|endofchunk|>"

def get_waterbirds_prompt(self, label=None) -> str:
    """Waterbirds binary classification prompt; demos append the class
    name and an <|endofchunk|> terminator."""
    prefix = "<image>Question: Is this a landbird or waterbird? Answer:"
    if label is None:
        return prefix
    return f"{prefix}{label}<|endofchunk|>"

def get_camelyon17_prompt(self, label=None) -> str:
    """Camelyon17 binary classification prompt.

    Asks "normal tissue or tumor tissue" so the wording matches
    CAMELYON17_CLASSNAMES ("normal tissue" / "tumor tissue"); the previous
    wording "cancer tissue" named an option that is never a ranked class
    name, which can bias the model away from the actual label strings.
    Demos append the class name and an <|endofchunk|> terminator.
    """
    prefix = "<image>Question: Is this a normal tissue or tumor tissue? Answer:"
    if label is None:
        return prefix
    return f"{prefix}{label}<|endofchunk|>"
Loading