"""
main file to call the explanations methods and run experiments to generate saliency heatmap
© copyright 2024 Bytedance Ltd. and/or its affiliates.
Modified from Tyler Lawson, Saeed khorram. https://github.com/saeed-khorram/IGOS
"""
import json
import os
import pickle
import random
import time

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from datasets import load_dataset

from args import init_args
from utils import *
from methods_helper import *
from methods import iGOS_p, iGOS_pp
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_vicuna_v1
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
from cambrian.conversation import conv_llama_3, conv_phi3
from cambrian.mm_utils import process_images_cambrian
from cambrian.model.builder import load_pretrained_model_cambrian
from llava_next.model.builder import load_pretrained_model_next
from llava_next.conversation import conv_qwen
from llava_next.mm_utils import process_images_next
from mgm.model.builder import load_pretrained_model_mgm
from mgm.mm_utils import process_images_mgm

join = os.path.join


def gen_explanations_llava(model, tokenizer, image_processor, data, args):
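    """
    Generate saliency explanations for every sample in `data`.

    For each image/question pair, the model's answer is decoded, keywords in
    the answer are located, and a perturbation mask is optimized with iGOS+
    or iGOS++; heatmaps, deletion/insertion curves, losses, and logs are
    written to the output directory.
    """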
    setting = f'{args.method}_L1_{args.L1}_L2_{args.L2}_L3_{args.L3}_momentum_{args.momentum}_igiter_{args.ig_iter}_lr_{args.lr}_iter_{args.iterations}'
    out_dir = os.path.join(args.output_dir, setting)
    os.makedirs(out_dir, exist_ok=True)
    with open(join(out_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)
    if args.method == "iGOS+":
        method = iGOS_p
    elif args.method == "iGOS++":
        method = iGOS_pp
    else:
        raise ValueError(f"Unknown method {args.method!r}; expected 'iGOS+' or 'iGOS++'.")
    i_obj = 0
    total_del, total_ins, total_time = 0, 0, 0
    all_del_scores = []
    all_ins_scores = []
    save_list = []
    for i_img, row in enumerate(data):
        image, qs, cur_prompt = get_data(args, row)
        image_size = [image.size]
        kernel_size = get_kernel_size(image.size)
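        # Build the perturbation baseline. By default the baseline is a
        # Gaussian-blurred copy of the input; the ablation flags swap in a
        # uniform-noise image or an all-zero (black) image instead. Note that
        # cv2.GaussianBlur requires an odd kernel size, which get_kernel_size
        # is expected to guarantee.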
        if args.ablation_noise:
            noise = np.random.randint(0, 256, size=np.asarray(image).shape, dtype=np.uint8)
            blur = Image.fromarray(noise)
        elif args.ablation_zero:
            blur = np.zeros_like(np.asarray(image))
            blur = Image.fromarray(blur.astype(np.uint8))
        else:
            blur = cv2.GaussianBlur(np.asarray(image), (kernel_size, kernel_size), sigmaX=kernel_size - 1)
            blur = Image.fromarray(blur.astype(np.uint8))
        resolution = None  # only set by the llava_next preprocessor
        if args.model == 'cambrian':
            image_tensor = process_images_cambrian([image], image_processor, model.config)
            blur_tensor = process_images_cambrian([blur], image_processor, model.config)
            # use a different conversation template for each model size
            # conv = conv_vicuna_v1.copy()  # cambrian-13b
            conv = conv_llama_3.copy()  # cambrian-8b
            # conv = conv_phi3.copy()  # cambrian-phi3-3b
        elif args.model == 'llava':
            image_tensor = process_images([image], image_processor, model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).detach().half().cuda()
            blur_tensor = process_images([blur], image_processor, model.config)[0]
            blur_tensor = blur_tensor.unsqueeze(0).detach().half().cuda()
            conv = conv_vicuna_v1.copy()
        elif args.model == 'llava_next':
            image_tensor, resolution = process_images_next([image], image_processor, model.config)
            image_tensor = image_tensor[0].to(model.device).half()
            blur_tensor, _ = process_images_next([blur], image_processor, model.config)
            blur_tensor = blur_tensor[0].to(model.device).half()
            conv = conv_qwen.copy()
        elif args.model == 'mgm':
            image_tensor = process_images_mgm([image], image_processor, model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).detach().half().cuda()
            blur_tensor = process_images_mgm([blur], image_processor, model.config)[0]
            blur_tensor = blur_tensor.unsqueeze(0).detach().half().cuda()
            conv = conv_vicuna_v1.copy()
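        # Each model family ships its own image preprocessor and conversation
        # template; the image and the baseline are preprocessed identically so
        # the optimized mask can interpolate between the two tensors.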
        # print('original size:', image.size, 'processed:', image_tensor.shape)
        # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
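        # tokenizer_image_token tokenizes the prompt and substitutes the
        # <image> placeholder with IMAGE_TOKEN_INDEX so the model knows where
        # to splice in the visual features.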
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        output_ids = generate(args, model, input_ids, image_tensor, image_size)
        output_text = tokenizer.decode(output_ids[0])
        # output_ids_blur = generate(args, model, input_ids, blur_tensor, image_size)
        # output_text_blur = tokenizer.decode(output_ids_blur[0])
        # # print('output text blurred:', output_text_blur)
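        # find_keywords picks the answer tokens (optionally via YAKE keyword
        # extraction) whose logits the attribution is computed against;
        # samples with no usable keyword are skipped.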
        positions, keywords = find_keywords(args, model, input_ids, output_ids, image_tensor, blur_tensor,
                                            image_size, output_text, tokenizer, args.use_yake)
        if len(positions) == 0:
            continue
        # print('----------------')
        # print(cur_prompt)
        # print(output_text)
        # print('detected keywords:', keywords)
        pred_data = dict()
        pred_data['labels'] = output_ids.detach()
        pred_data['keywords'] = positions
        pred_data['boxes'] = np.array([[0, 0, args.input_size, args.input_size]])
        pred_data['no_res'] = False
        pred_data['pred_text'] = output_text
        pred_data['keywords_text'] = keywords
        # compute the initial masks (k diverse initializations at the
        # configured position, value, and size)
        pred_data = get_initial(pred_data, args.diverse_k, args.init_posi,
                                args.init_val, args.input_size, args.size)
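        # one optimization pass per decoded answer sequence; `labels` holds
        # the generated token ids and `keywords` the positions to attribute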
        for l_i, label in enumerate(pred_data['labels']):
            # for l_i, keyword in enumerate(pred_data['keywords']):
            label = label.unsqueeze(0)
            keyword = pred_data['keywords']
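            # Optimize the saliency mask: the iGOS objective trades the
            # deletion loss (masking salient pixels should suppress the
            # keyword logits) against the insertion loss, with L1/TV/L2
            # regularizers keeping the mask sparse and smooth.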
            now = time.time()
            masks, loss_del, loss_ins, loss_l1, loss_tv, loss_l2, loss_comb_del, loss_comb_ins = method(
                args,
                model=model,
                model_name=args.model,
                init_mask=pred_data['init_masks'][0],
                image=image_tensor,
                baseline=blur_tensor,
                label=label,
                size=args.size,
                iterations=args.iterations,
                ig_iter=args.ig_iter,
                L1=args.L1,
                L2=args.L2,
                L3=args.L3,
                lr=args.lr,
                opt=args.opt,
                prompt=input_ids,
                image_size=image_size,
                positions=keyword,
                resolution=resolution
            )
            total_time += time.time() - now
            # score the mask: deletion (lower is better) progressively masks
            # the most salient pixels, insertion (higher is better) reveals them
            del_scores, ins_scores, del_curve, ins_curve, index = metric(
                args,
                image_tensor,
                blur_tensor,
                masks.detach(),
                model,
                args.model,
                label,
                l_i,
                pred_data,
                size=args.size,
                prompt=input_ids,
                image_size=image_size,
                positions=keyword,
                resolution=resolution,
            )
            # save heatmaps, images, and del/ins curves
            classes = None
            cur_prompt = cur_prompt.strip().replace('\n', '_')
            save_heatmaps(masks, image_tensor, args.size, i_img, l_i, out_dir, args.model, None, classes, label)
            save_curves(del_curve, ins_curve, index, i_img, l_i, out_dir)
            save_images(image_tensor, i_img, l_i, out_dir, classes, label, pred_data, text=cur_prompt)
            save_loss(loss_del, loss_ins, loss_l1, loss_tv, loss_l2, i_img, l_i, out_dir, loss_comb_del, loss_comb_ins)
            # log info
            all_del_scores.append(del_scores.sum().item())
            all_ins_scores.append(ins_scores.sum().item())
            i_obj += 1
            with open(join(out_dir, f'{i_img}_{l_i}_output.txt'), 'w', encoding='utf-8') as f:
                f.write('prompt: ' + cur_prompt + '\n')
                f.write('answer: ' + pred_data['pred_text'] + '\n')
                f.write('keywords: ' + ','.join(pred_data['keywords_text']) + '\n')
                # f.write('answer with blurred image: ' + output_text_blur + '\n')
                f.write(f'Deletion (Avg.): {del_scores.sum().item():.05f}' + '\n')
                f.write(f'Insertion (Avg.): {ins_scores.sum().item():.05f}' + '\n')
                f.write(f'Time (Avg.): {total_time / i_obj:.03f}' + '\n')
            eprint(
                f' Deletion (Avg.): {del_scores.sum().item():.05f}'
                f' Insertion (Avg.): {ins_scores.sum().item():.05f}'
                f' Time (Avg.): {total_time / i_obj:.03f}'
            )

if __name__ == "__main__":
    args = init_args()
    eprint(f"args:\n {args}")
    torch.manual_seed(args.manual_seed)
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    disable_torch_init()
    if args.model == 'cambrian':
        model_path = "nyu-visionx/cambrian-8b"
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model_cambrian(model_path, args.model_base, model_name)
    elif args.model == 'llava':
        model_path = 'liuhaotian/llava-v1.5-7b'
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
    elif args.model == 'llava_next':
        model_path = "lmms-lab/llava-onevision-qwen2-72b-ov"
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model_next(model_path, args.model_base, model_name, device_map='auto')
    elif args.model == 'mgm':
        model_path = 'mgm/work_dirs/mgm13bhd'
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model_mgm(model_path, args.model_base, model_name)
    else:
        raise ValueError(f"Unknown model {args.model!r}; expected 'cambrian', 'llava', 'llava_next', or 'mgm'.")
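    # freeze the VLM: gradients are needed only w.r.t. the saliency mask, and
    # gradient checkpointing keeps activation memory manageable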
    for param in model.parameters():
        param.requires_grad = False
    model.gradient_checkpointing = True
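    # the dataset may be a csv / jsonl / json / pkl file, or a HuggingFace
    # dataset name resolvable by datasets.load_dataset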
    if args.data_path.endswith('csv'):
        df = pd.read_csv(args.data_path)
        data = df.to_dict(orient='records')
    elif args.data_path.endswith('jsonl'):
        with open(args.data_path, 'r') as f:
            data = [json.loads(q) for q in f]
    elif args.data_path.endswith('json'):
        with open(args.data_path, 'r') as f:
            data = json.load(f)
    elif args.data_path.endswith('pkl'):
        with open(args.data_path, 'rb') as f:
            data = pickle.load(f)
    else:
        data = load_dataset(args.data_path, "val")["val"].to_pandas()
        data = data.to_dict(orient="records")
    print('total number of data samples:', len(data))
    gen_explanations_llava(model, tokenizer, image_processor, data, args)
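# Example invocation (a sketch; the flag names assume init_args exposes the
# attributes used above as CLI flags, and the paths are placeholders):
#   python main.py --model llava --method iGOS++ --data_path data/val.jsonl \
#       --output_dir outputs --iterations 30 --ig_iter 20 --lr 0.1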