Hi!
I evaluated ROICtrl on MIGBench, but my results do not match those reported in the paper. I also re-sampled MIGBench to produce a new JSON file, but the results were still unsatisfactory.
Could you kindly help me investigate the possible reasons?
My code:

```python
import json
import os

import torch

# setup_pipeline and encode_roi_input come from the ROICtrl inference code
# (import path omitted here).

NEGATIVE_PROMPT = "worst quality, low quality, blurry, low resolution, low quality"

pipe = setup_pipeline(
    pretrained_model_path="CompVis/stable-diffusion-v1-4",  # can be switched to another community fine-tuned model based on SD v1.4 / v1.5
    roictrl_path="ROICtrl_sdv14_30K.safetensors",
)

input_data = {
    "height": 512,
    "width": 512,
    "seed": 1234,
    "roictrl_scheduled_sampling_beta": 1.0,
}

mig_bench_json = "mig_bench_coco.json"
path_name = "roictrl_migbench_coco"
os.makedirs(path_name, exist_ok=True)

with open(mig_bench_json, "r") as file:
    res = json.load(file)

for key, item in res.items():
    # Collect one box and one phrase per instance in this MIGBench entry.
    bboxes = []
    prompts = []
    prompts_global = item["caption"]
    for seg in item["segment"]:
        bboxes.append(seg["bbox"])
        prompts.append(seg["label"])
    input_data["roi_boxes"] = bboxes
    input_data["roi_phrases"] = prompts

    cross_attention_kwargs = {
        "roictrl": encode_roi_input(input_data, pipe, negative_prompt=NEGATIVE_PROMPT)
    }
    image = pipe(
        prompt=prompts_global,
        negative_prompt=NEGATIVE_PROMPT,
        generator=torch.Generator().manual_seed(input_data["seed"]),  # same seed for every entry
        cross_attention_kwargs=cross_attention_kwargs,
        height=input_data["height"],
        width=input_data["width"],
        roictrl_scheduled_sampling_beta=input_data["roictrl_scheduled_sampling_beta"],
    ).images[0]
    image.save(os.path.join(path_name, "{}.jpg".format(key)))
```
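One thing I could rule out on my side is a box-format mismatch between the JSON and `roi_boxes`. I am *assuming* ROICtrl expects normalized `[x1, y1, x2, y2]` boxes in `[0, 1]`. I have not confirmed this against the ROICtrl or MIGBench code. If that assumption holds and the JSON stores pixel coordinates or `[x, y, w, h]` boxes, the layouts would be silently wrong. The helpers `check_migbench_item` and `check_migbench_file` below are hypothetical, written only for this check. They flag some, not all, of these mismatches:

```python
import json
from typing import Any, Dict, List


def check_migbench_item(key: str, item: Dict[str, Any]) -> List[str]:
    """Return warnings for one MIGBench entry (hypothetical helper).

    ASSUMPTION: each box should be [x1, y1, x2, y2] normalized to [0, 1].
    This is a guess at what ROICtrl's `roi_boxes` expects, not a confirmed spec.
    """
    warnings: List[str] = []
    segments = item.get("segment", [])
    if not segments:
        warnings.append(f"{key}: no segments")
    for i, seg in enumerate(segments):
        box = seg.get("bbox")
        label = seg.get("label")
        if not label:
            warnings.append(f"{key}[{i}]: empty label")
        if box is None or len(box) != 4:
            warnings.append(f"{key}[{i}]: bbox must have 4 values, got {box!r}")
            continue
        x1, y1, x2, y2 = box
        # Pixel coordinates (e.g. on a 512x512 canvas) would fall outside [0, 1].
        if not all(0.0 <= v <= 1.0 for v in box):
            warnings.append(f"{key}[{i}]: bbox {box} is not normalized to [0, 1]")
        # A [x, y, w, h] box often has its "x2"/"y2" smaller than x1/y1.
        if x2 <= x1 or y2 <= y1:
            warnings.append(f"{key}[{i}]: bbox {box} is not in [x1, y1, x2, y2] order")
    return warnings


def check_migbench_file(path: str) -> Dict[str, List[str]]:
    """Run check_migbench_item over every entry; keep only entries with warnings."""
    with open(path, "r") as f:
        data = json.load(f)
    report: Dict[str, List[str]] = {}
    for key, item in data.items():
        item_warnings = check_migbench_item(key, item)
        if item_warnings:
            report[key] = item_warnings
    return report
```

Running `check_migbench_file("mig_bench_coco.json")` on both the original and the re-sampled JSON would return an empty dict if every box passes these (assumed) format checks.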
I ran inference in torch.float32 precision, which does improve the results.
Thank you for your time and assistance!