
The result is worse than the original SAM2 from Meta AI Research #160

@173619070

Description


If I use a point as the input prompt, the result is worse than the original SAM2 from Meta AI Research. I want to know whether SAM-HQ only supports boxes as prompts. The left picture is the output of SAM-HQ, and the right picture is the output of SAM2. I have also attached the modified code that uses a point as the prompt, plus the original test picture.

![Image](https://github.com/user-attachments/assets/5b65dc4d-1531-42da-9e20-9f522ac3571e)

Code is here:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import os


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=False):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()


def show_res(masks, scores, input_point, input_label, input_box, filename, image):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        if input_box is not None:
            box = input_box[i]
            show_box(box, plt.gca())
        if (input_point is not None) and (input_label is not None):
            show_points(input_point, input_label, plt.gca())

        print(f"Score: {score:.3f}")
        plt.axis('off')
        plt.savefig(filename + '_' + str(i) + '.png', bbox_inches='tight', pad_inches=-0.1)
        plt.close()


def show_res_multi(masks, scores, input_point, input_label, input_box, filename, image):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask, plt.gca(), random_color=True)
    for box in input_box:
        show_box(box, plt.gca())
    for score in scores:
        print(f"Score: {score:.3f}")
    plt.axis('off')
    plt.savefig(filename + '.png', bbox_inches='tight', pad_inches=-0.1)
    plt.close()


if __name__ == "__main__":
    checkpoint = "D:/huggingface_cache/sam2.1_hq_hiera_large.pt"
    model_cfg = "configs/sam2.1/sam2.1_hq_hiera_l.yaml"
    predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

    for i in range(1, 3):
        print("image:   ", i)
        # hq_token_only: False means use hq output to correct SAM output.
        #                True means use hq output only.
        #                Default: False
        hq_token_only = False
        # To achieve the best visualization effect, for images containing multiple objects (like typical COCO images), we suggest setting hq_token_only=False.
        # For images containing a single object, we suggest setting hq_token_only=True.
        # For quantitative evaluation on COCO/YTVOS/DAVIS/UVO/LVIS etc., we set hq_token_only=False.

        image = cv2.imread('E:/AI/sam-hq-main/sam-hq2/demo/input_images/test' + str(i) + '.jpg')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)

        if i == 1:
            input_box = None
            input_point = np.array([[400, 550]])
            input_label = np.ones(input_point.shape[0])
        elif i == 2:
            input_box = None
            input_point = np.array([[300, 350]])
            input_label = np.ones(input_point.shape[0])
        elif i == 3:
            input_box = None
            input_point = np.array([[400, 390]])
            input_label = np.ones(input_point.shape[0])

        batch_box = False if input_box is None else len(input_box) > 1
        result_path = 'demo/hq_sam_result_vis/'
        os.makedirs(result_path, exist_ok=True)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            masks, scores, logits = predictor.predict(point_coords=input_point,
                                                      point_labels=input_label,
                                                      box=input_box,
                                                      multimask_output=True, hq_token_only=hq_token_only)

            if not batch_box:
                show_res(masks, scores, input_point, input_label, input_box, result_path + 'example' + str(i), image)
            else:
                masks = masks.squeeze(1)
                scores = scores.squeeze(1)
                input_box = input_box.cpu().numpy()
                show_res_multi(masks, scores, input_point, input_label, input_box, result_path + 'example' + str(i), image)
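For comparison, below is a minimal sketch of a box-prompt run on the same image. It reuses the same `predictor.predict` arguments shown above; the box coordinates are placeholders (not tuned to the test image), and it simply keeps the highest-scoring mask when `multimask_output=True`.

```python
# Minimal box-prompt sketch for comparison, assuming the same `predictor` and
# `image` already set above. The box coordinates are placeholders, not tuned values.
import numpy as np
import torch

input_box = np.array([100, 100, 700, 800])  # hypothetical [x0, y0, x1, y1] box

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    masks, scores, logits = predictor.predict(point_coords=None,
                                              point_labels=None,
                                              box=input_box,
                                              multimask_output=True,
                                              hq_token_only=True)

# With multimask_output=True, compare only the highest-scoring mask.
best = int(np.argmax(scores))
print(f"Best mask score: {scores[best]:.3f}")
```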
