|
12 | 12 | import argparse
|
13 | 13 | import inspect
|
14 | 14 |
|
| 15 | +import shutil |
15 | 16 | import torchvision
|
16 | 17 | import whisper
|
17 | 18 | import matplotlib.pyplot as plt
|
@@ -1216,7 +1217,7 @@ def inference_auto_segment_object(self, image_path):
|
1216 | 1217 | text_prompt = generate_tags(caption, split=",")
|
1217 | 1218 | print(f"\nCaption: {caption}")
|
1218 | 1219 | print(f"Tags: {text_prompt}")
|
1219 |
| - updated_image_path, pred_phrases = self._segment_object(image_path, text_prompt, func_name="seg-objects") |
| 1220 | + updated_image_path, pred_phrases = self._segment_object(image_path, text_prompt, func_name="auto-label") |
1220 | 1221 | caption = check_caption(caption, pred_phrases)
|
1221 | 1222 | print(f"Revise caption with number: {caption}")
|
1222 | 1223 | print(f"Processed SegmentMultiObject, Input Image: {image_path}, Caption: {caption}, "
|
@@ -1251,8 +1252,8 @@ def _inpainting(self, image_path, to_be_replaced_txt, replace_with_txt, func_nam
|
1251 | 1252 | )
|
1252 | 1253 | # inpainting pipeline
|
1253 | 1254 | mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
|
1254 |
| - mask_pil = Image.fromarray(mask) |
1255 |
| - image_pil = Image.fromarray(image) |
| 1255 | + mask_pil = Image.fromarray(mask).resize((512, 512)) |
| 1256 | + image_pil = Image.fromarray(image).resize((512, 512)) |
1256 | 1257 | image = self.sd_pipe(prompt=replace_with_txt, image=image_pil, mask_image=mask_pil).images[0]
|
1257 | 1258 | updated_image_path = get_new_image_name(image_path, func_name)
|
1258 | 1259 | image.save(updated_image_path)
|
@@ -1313,18 +1314,23 @@ def run_text(self, text, state):
|
1313 | 1314 | return state, state
|
1314 | 1315 |
|
1315 | 1316 | def run_image(self, image, state, txt, lang):
|
1316 |
| - image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png") |
1317 |
| - print("======>Auto Resize Image...") |
1318 |
| - img = Image.open(image.name) |
1319 |
| - width, height = img.size |
1320 |
| - ratio = min(512 / width, 512 / height) |
1321 |
| - width_new, height_new = (round(width * ratio), round(height * ratio)) |
1322 |
| - width_new = int(np.round(width_new / 64.0)) * 64 |
1323 |
| - height_new = int(np.round(height_new / 64.0)) * 64 |
1324 |
| - img = img.resize((width_new, height_new)) |
1325 |
| - img = img.convert('RGB') |
1326 |
| - img.save(image_filename, "PNG") |
1327 |
| - print(f"Resize image form {width}x{height} to {width_new}x{height_new}") |
| 1317 | + # image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png") |
| 1318 | + # print("======>Auto Resize Image...") |
| 1319 | + # img = Image.open(image.name) |
| 1320 | + # width, height = img.size |
| 1321 | + # ratio = min(512 / width, 512 / height) |
| 1322 | + # width_new, height_new = (round(width * ratio), round(height * ratio)) |
| 1323 | + # width_new = int(np.round(width_new / 64.0)) * 64 |
| 1324 | + # height_new = int(np.round(height_new / 64.0)) * 64 |
| 1325 | + # img = img.resize((width_new, height_new)) |
| 1326 | + # img = img.convert('RGB') |
| 1327 | + # img.save(image_filename) |
| 1328 | + # img.save(image_filename, "PNG") |
| 1329 | + # print(f"Resize image form {width}x{height} to {width_new}x{height_new}") |
| 1330 | + ## Directly use original image for better results |
| 1331 | + suffix = image.name.split('.')[-1] |
| 1332 | + image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.{suffix}") |
| 1333 | + shutil.copy(image.name, image_filename) |
1328 | 1334 | if 'Grounded_dino_sam_inpainting' in self.models:
|
1329 | 1335 | description = self.models['Grounded_dino_sam_inpainting'].inference_caption(image_filename)
|
1330 | 1336 | else:
|
@@ -1388,7 +1394,7 @@ def speech_recognition(speech_file):
|
1388 | 1394 |
|
1389 | 1395 | if __name__ == '__main__':
|
1390 | 1396 | load_dict = {'Grounded_dino_sam_inpainting': 'cuda:0'}
|
1391 |
| -# load_dict = {'ImageCaptioning': 'cuda:0'} |
| 1397 | + # load_dict = {'ImageCaptioning': 'cuda:0'} |
1392 | 1398 |
|
1393 | 1399 | bot = ConversationBot(load_dict)
|
1394 | 1400 |
|
@@ -1451,3 +1457,4 @@ def speech_recognition(speech_file):
|
1451 | 1457 | clear.click(lambda: [], None, state)
|
1452 | 1458 |
|
1453 | 1459 | demo.launch(server_name="0.0.0.0", server_port=10010)
|
| 1460 | + |
0 commit comments