
Commit c1fc690: introduce edit anything
1 parent: 6f90bfd

21 files changed: +133, -20 lines

main.py (+3, -2)

@@ -9,10 +9,11 @@
 parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
 parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
 parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
-parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=False, help='Set this flag to True if you want to use semantic segmentation')
+parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
+parser.add_argument('--region_classify_model', choices=['ssa', 'edit_anything'], dest='region_classify_model', default='edit_anything', help='Select the region classification model: semantic segment anything or edit anything')
 parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
 parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
-parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
+parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended. Make sure this model and the image_caption model are on the same device.')
 parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, <6G GPU is not recommended>')

 args = parser.parse_args()
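The same options are added to main_gradio.py below. A minimal sketch of how the new flags parse (the command-line values are illustrative only; note that with action='store_true' combined with default=True, --semantic_segment is effectively always on unless the default is changed in code):

```python
# Sketch only: mirrors the argparse options added in this commit.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--semantic_segment', action='store_true', default=True)
parser.add_argument('--region_classify_model', choices=['ssa', 'edit_anything'],
                    default='edit_anything')
parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu')

args = parser.parse_args(['--region_classify_model', 'ssa'])
print(args.semantic_segment, args.region_classify_model, args.semantic_segment_device)
# True ssa cpu
```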

main_gradio.py (+3, -2)

@@ -11,10 +11,11 @@
 parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
 parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
 parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
-parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=False, help='Set this flag to True if you want to use semantic segmentation')
+parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
+parser.add_argument('--region_classify_model', choices=['ssa', 'edit_anything'], dest='region_classify_model', default='edit_anything', help='Select the region classification model: semantic segment anything or edit anything')
 parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
 parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
-parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
+parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended. Make sure this model and the image_caption model are on the same device.')
 parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, <6G GPU is not recommended>')

 args = parser.parse_args()
Several binary files changed (not shown).

models/blip2_model.py (+3, -1)

@@ -2,6 +2,7 @@
 import requests
 from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForConditionalGeneration
 import torch
+from utils.util import resize_long_edge


 class ImageCaptioning:
@@ -20,12 +21,13 @@ def initialize_model(self):
         # )
         # for gpu with small memory
         processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=self.data_type)
         model.to(self.device)
         return processor, model

     def image_caption(self, image_src):
         image = Image.open(image_src)
+        image = resize_long_edge(image, 384)
         inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
         generated_ids = self.model.generate(**inputs)
         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
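For reference, a self-contained sketch of what the updated captioning path now does: load the BLIP base checkpoint in the configured dtype and resize the input to a 384-pixel long edge before encoding. The checkpoint name comes from the diff; the image path is a placeholder and a CUDA device with half precision is assumed.

```python
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

def resize_long_edge(image, target_size=384):
    # same resizing rule as the helper added in utils/util.py
    w, h = image.size
    scale = target_size / max(w, h)
    return image.resize((int(w * scale), int(h * scale)))

device, dtype = "cuda", torch.float16  # assumption: a GPU is available
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=dtype).to(device)

image = resize_long_edge(Image.open("example.jpg").convert("RGB"))  # placeholder path
inputs = processor(images=image, return_tensors="pt").to(device, dtype)
caption = processor.batch_decode(model.generate(**inputs), skip_special_tokens=True)[0].strip()
print(caption)
```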

models/grit_src/image_dense_captions.py (+2)

@@ -16,6 +16,7 @@

 from models.grit_src.grit.predictor import VisualizationDemo
 import json
+from utils.util import resize_long_edge_cv2


 # constants
@@ -62,6 +63,7 @@ def image_caption_api(image_src, device):
     demo = VisualizationDemo(cfg)
     if image_src:
         img = read_image(image_src, format="BGR")
+        img = resize_long_edge_cv2(img, 384)
         predictions, visualized_output = demo.run_on_image(img)
         new_caption = dense_pred_to_caption(predictions)
     return new_caption

models/image_text_transformation.py (+4, -2)

@@ -3,7 +3,7 @@
 from models.gpt_model import ImageToText
 from models.controlnet_model import TextToImage
 from models.region_semantic import RegionSemantic
-from utils.util import read_image_width_height, display_images_and_text
+from utils.util import read_image_width_height, display_images_and_text, resize_long_edge
 import argparse
 from PIL import Image
 import base64
@@ -33,13 +33,15 @@ def init_models(self):
         self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
         self.gpt_model = ImageToText(openai_key)
         self.controlnet_model = TextToImage(device=self.args.contolnet_device)
-        self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device)
+        self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device, image_caption_model=self.image_caption_model, region_classify_model=self.args.region_classify_model)
         print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')


     def image_to_text(self, img_src):
         # the information to generate paragraph based on the context
         self.ref_image = Image.open(img_src)
+        # resize image to long edge 384
+        self.ref_image = resize_long_edge(self.ref_image, 384)
         width, height = read_image_width_height(img_src)
         print(self.args)
         if self.args.image_caption:

models/region_semantic.py (+28, -9)

@@ -1,38 +1,57 @@
 from models.segment_models.semgent_anything_model import SegmentAnything
 from models.segment_models.semantic_segment_anything_model import SemanticSegment
+from models.segment_models.edit_anything_model import EditAnything


 class RegionSemantic():
-    def __init__(self, device):
+    def __init__(self, device, image_caption_model, region_classify_model='edit_anything'):
         self.device = device
+        self.image_caption_model = image_caption_model
+        self.region_classify_model = region_classify_model
         self.init_models()

     def init_models(self):
         self.segment_model = SegmentAnything(self.device)
-        self.semantic_segment_model = SemanticSegment(self.device)
-
-    def semantic_prompt_gen(self, anns):
+        if self.region_classify_model == 'ssa':
+            self.semantic_segment_model = SemanticSegment(self.device)
+        elif self.region_classify_model == 'edit_anything':
+            self.edit_anything_model = EditAnything(self.device, self.image_caption_model)
+            print('initialize edit anything model')
+        else:
+            raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
+
+    def semantic_prompt_gen(self, anns, topk=5):
         """
         filter too small objects and objects with low stability score
         anns: [{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'size': [0, 0], 'stability_score': 0.0}, ...]
         semantic_prompt: "person: [0.0, 0.0, 0.0, 0.0]; ..."
         """
         # Sort annotations by area in descending order
         sorted_annotations = sorted(anns, key=lambda x: x['area'], reverse=True)
+        anns_len = len(sorted_annotations)
         # Select the top 10 largest regions
-        top_10_largest_regions = sorted_annotations[:10]
+        top_10_largest_regions = sorted_annotations[:min(anns_len, topk)]
         semantic_prompt = ""
-        print('\033[1;35m' + '*' * 100 + '\033[0m')
-        print("\nStep3, Semantic Prompt:")
         for region in top_10_largest_regions:
             semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
         print(semantic_prompt)
         print('\033[1;35m' + '*' * 100 + '\033[0m')
         return semantic_prompt

-    def region_semantic(self, img_src):
+    def region_semantic(self, img_src, region_classify_model='edit_anything'):
+        print('\033[1;35m' + '*' * 100 + '\033[0m')
+        print("\nStep3, Semantic Prompt:")
         anns = self.segment_model.generate_mask(img_src)
-        anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
+        if region_classify_model == 'ssa':
+            print('generate region supervision with blip2 model....\n')
+            anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
+            print('finished...\n')
+        elif region_classify_model == 'edit_anything':
+            print('generate region supervision with edit anything model....\n')
+            anns_w_class = self.edit_anything_model.semantic_class_w_mask(img_src, anns)
+            print('finished...\n')
+        else:
+            raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
         return self.semantic_prompt_gen(anns_w_class)

     def region_semantic_debug(self, img_src):
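A usage sketch of the reworked class, based only on what the diff shows. It assumes the caption-model object exposes .processor, .model, .device and .data_type (which is what EditAnything reads) and that ImageCaptioning takes a device argument; the image path is a placeholder.

```python
# Sketch only; constructor details beyond the diff are assumptions.
from models.blip2_model import ImageCaptioning
from models.region_semantic import RegionSemantic

caption_model = ImageCaptioning(device='cuda')  # assumed constructor signature
region_model = RegionSemantic(device='cuda',
                              image_caption_model=caption_model,
                              region_classify_model='edit_anything')

# segments the image with SAM, labels each region with BLIP,
# and returns a prompt of the form "class_name: [bbox]; ..."
prompt = region_model.region_semantic('examples/1.jpg')  # placeholder path
print(prompt)
```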
models/segment_models/edit_anything_model.py (new file, +50)

@@ -0,0 +1,50 @@
+import cv2
+import torch
+import mmcv
+import numpy as np
+from PIL import Image
+from utils.util import resize_long_edge
+
+class EditAnything:
+    def __init__(self, device, image_caption_model):
+        self.device = image_caption_model.device
+        self.data_type = image_caption_model.data_type
+        self.image_caption_model = image_caption_model
+
+    # working on parallelizing these images now
+    def region_classify_w_blip2(self, image):
+        inputs = self.image_caption_model.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
+        generated_ids = self.image_caption_model.model.generate(**inputs)
+        generated_text = self.image_caption_model.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        return generated_text
+
+    def region_level_semantic_api(self, image, anns, topk=5):
+        """
+        rank regions by area, and classify each region with blip2
+        Args:
+            image: numpy array
+            topk: int
+        Returns:
+            topk_region_w_class_label: list of dict with key 'class_label'
+        """
+        topk_region_w_class_label = []
+        if len(anns) == 0:
+            return []
+        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+        for i in range(min(topk, len(sorted_anns))):
+            ann = anns[i]
+            m = ann['segmentation']
+            m_3c = m[:, :, np.newaxis]
+            m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
+            bbox = ann['bbox']
+            region = mmcv.imcrop(image * m_3c, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=1)
+            region_class_label = self.region_classify_w_blip2(region)
+            ann['class_name'] = region_class_label
+            # print(ann['class_label'], str(bbox))
+            topk_region_w_class_label.append(ann)
+        return topk_region_w_class_label
+
+    def semantic_class_w_mask(self, img_src, anns):
+        image = Image.open(img_src)
+        image = resize_long_edge(image, 384)
+        return self.region_level_semantic_api(image, anns)
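The core of region_level_semantic_api is masking the image with each SAM segmentation and cropping to the region's bounding box before captioning that crop. A numpy-only stand-in for this step (array slicing instead of mmcv.imcrop), with a toy mask, just to illustrate the shapes involved:

```python
import numpy as np

def crop_masked_region(image, ann):
    """Zero out pixels outside the mask, then crop to the region's bbox ([x, y, w, h])."""
    mask = ann['segmentation']                            # HxW boolean mask (SAM format)
    masked = image * np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    x, y, w, h = (int(v) for v in ann['bbox'])
    return masked[y:y + h, x:x + w]

image = np.random.randint(0, 255, (384, 512, 3), dtype=np.uint8)
mask = np.zeros((384, 512), dtype=bool)
mask[100:200, 50:150] = True
region = crop_masked_region(image, {'segmentation': mask, 'bbox': [50, 100, 100, 100]})
print(region.shape)  # (100, 100, 3); this crop is what gets sent to BLIP for a label
```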

models/segment_models/semantic_segment_anything_model.py (+2)

@@ -10,6 +10,7 @@
 import pycocotools.mask as maskUtils
 from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
 from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
+from utils.util import resize_long_edge, resize_long_edge_cv2
 # from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet

 nlp = spacy.load('en_core_web_sm')
@@ -113,6 +114,7 @@ def semantic_class_w_mask(self, img_src, anns, out_file_name="output/test.json",
         :return: dict('segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box', "class_name", "class_proposals"})
         """
         img = mmcv.imread(img_src)
+        img = resize_long_edge_cv2(img, 384)
         oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
         oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
         bitmasks, class_names = [], []

models/segment_models/semgent_anything_model.py (+2, -1)

@@ -1,6 +1,6 @@
 import cv2
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
-import torch
+from utils.util import resize_long_edge_cv2

 class SegmentAnything:
     def __init__(self, device, arch="vit_h", pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth"):
@@ -16,5 +16,6 @@ def initialize_model(self, arch, pretrained_weights):
     def generate_mask(self, img_src):
         image = cv2.imread(img_src)
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image = resize_long_edge_cv2(image, 384)
         anns = self.model.generate(image)
         return anns
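For context, a sketch of what generate_mask() now does end to end with the segment_anything package: load the ViT-H checkpoint named in the diff, resize the RGB image so its long edge is 384, and run the automatic mask generator. The image path is a placeholder and it assumes the weights have already been downloaded to the path shown.

```python
# Sketch only; checkpoint path comes from the diff, image path is a placeholder.
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="pretrained_models/sam_vit_h_4b8939.pth")
sam.to("cuda")
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("examples/1.jpg"), cv2.COLOR_BGR2RGB)
h, w = image.shape[:2]
scale = 384 / max(h, w)                      # the long-edge resize added in this commit
image = cv2.resize(image, (int(w * scale), int(h * scale)))

anns = mask_generator.generate(image)
print(len(anns), sorted(anns[0].keys()))     # each ann carries 'segmentation', 'bbox', 'area', ...
```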

output/1_result.jpg: binary image changed (-51.2 KB), not shown.

readme.md (+5, -3)

@@ -43,7 +43,8 @@
 <img src="examples/icon/news.gif" alt="Your Image Description" width=100> <strong><span style="font-size: 24px;">News</span></strong>
 </p>

-
+- 17/April/2023. In addition to semantic segment anything, we use [Edit Anything](https://github.com/sail-sg/EditAnything) to get region-level semantics.
+- 17/April/2023. Our project is online on Hugging Face. Have a try! [huggingface](https://huggingface.co/spaces/Awiny/Image2Paragraph/tree/main)
 - 14/April/2023. Our project is very popular on Twitter. See [the posted tweet](https://twitter.com/awinyimgprocess/status/1646225454599372800?s=46&t=HvOe9T2n35iFuCHP5aIHpQ) for details.

 ### To Do List
@@ -57,9 +58,10 @@
 - [x] Integrate GRIT into our code.
 - [x] Support GPT4 API.
 - [x] Notebook/Huggingface Space.
+- [x] Region Semantic Classification from Edit-Anything.
+- [x] Make the model lightweight.

 #### Doing
-- [ ] Make the model lightweight.
 - [ ] Replace ChatGPT with own trained LLM.
 - [ ] Other grounding text2image models instead of Canny ControlNet.
 - [ ] Show retrieval result in gradio.
@@ -179,4 +181,4 @@ If you have more suggestions or functions need to be implemented in this codebas

 ## Acknowledgment

-This work is based on [ChatGPT](http://chat.openai.com), [BLIP2](https://huggingface.co/spaces/Salesforce/BLIP2), [GRIT](https://github.com/JialianW/GRiT), [OFA](https://github.com/OFA-Sys/OFA), [Segment-Anything](https://segment-anything.com), [Semantic-Segment-Anything](https://github.com/fudan-zvg/Semantic-Segment-Anything), [ControlNet](https://github.com/lllyasviel/ControlNet).
+This work is based on [ChatGPT](http://chat.openai.com), [Edit_Anything](https://github.com/sail-sg/EditAnything), [BLIP2](https://huggingface.co/spaces/Salesforce/BLIP2), [GRIT](https://github.com/JialianW/GRiT), [OFA](https://github.com/OFA-Sys/OFA), [Segment-Anything](https://segment-anything.com), [Semantic-Segment-Anything](https://github.com/fudan-zvg/Semantic-Segment-Anything), [ControlNet](https://github.com/lllyasviel/ControlNet).

utils/__pycache__/util.cpython-38.pyc: binary file (677 Bytes), not shown.

utils/util.py (+31)

@@ -14,6 +14,37 @@ def read_image_width_height(image_path):
     width, height = image.size
     return width, height

+def resize_long_edge(image, target_size=384):
+    # Calculate the aspect ratio
+    width, height = image.size
+    aspect_ratio = float(width) / float(height)
+
+    # Determine the new dimensions
+    if width > height:
+        new_width = target_size
+        new_height = int(target_size / aspect_ratio)
+    else:
+        new_width = int(target_size * aspect_ratio)
+        new_height = target_size
+
+    # Resize the image
+    resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
+    return resized_image
+
+def resize_long_edge_cv2(image, target_size=384):
+    height, width = image.shape[:2]
+    aspect_ratio = float(width) / float(height)
+
+    if height > width:
+        new_height = target_size
+        new_width = int(target_size * aspect_ratio)
+    else:
+        new_width = target_size
+        new_height = int(target_size / aspect_ratio)
+
+    resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+    return resized_image
+
 def display_images_and_text(source_image_path, generated_image, generated_paragraph, outfile_name):
     source_image = Image.open(source_image_path)
     # Create a new image that can fit the images and the text
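A quick check of the two new helpers (the image path is a placeholder). One caveat worth noting: Image.ANTIALIAS was removed in Pillow 10, so on newer Pillow versions Image.LANCZOS would be the drop-in replacement in resize_long_edge.

```python
# Sketch only; "examples/1.jpg" is a placeholder path.
import cv2
from PIL import Image
from utils.util import resize_long_edge, resize_long_edge_cv2

pil_img = Image.open("examples/1.jpg")
print(resize_long_edge(pil_img, 384).size)        # long edge is now 384

cv_img = cv2.imread("examples/1.jpg")
print(resize_long_edge_cv2(cv_img, 384).shape)    # long edge is now 384
```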
