[Bug] InternVL3-38B yields lower-than-expected results on the image grounding task with sglang #1225

@antoinegg1

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

While trying to reproduce the image-grounding results for the InternVL3-38B model downloaded from ModelScope, I built an sglang-based inference pipeline following the official docs. Across all benchmarks (RefCOCO val/testA/testB, RefCOCO+ val/testA/testB, and RefCOCOg val/test), the scores are consistently lower than those reported by the InternVL team (approximately 75%).

This issue occurs only with InternVL3-38B. For the other InternVL3 and InternVL3.5 models, we obtained results consistent with the reported numbers. We therefore suspect that the ModelScope repository may not be the correct source for InternVL3-38B, or that there is a problem with the model itself.
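
To check the download-integrity suspicion directly, the two copies of the weights can be compared shard by shard. The sketch below is illustrative only: the directory paths are placeholders, and it simply prints SHA-256 digests of the *.safetensors shards from a ModelScope download and a Hugging Face download so they can be diffed by hand.

import glob
import hashlib
import os

def sha256_of(path: str, chunk: int = 1 << 20) -> str:
    # Stream the file so large weight shards do not need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            buf = f.read(chunk)
            if not buf:
                break
            h.update(buf)
    return h.hexdigest()

# Placeholder paths; point these at the two local copies of InternVL3-38B.
for root in ["/path/to/modelscope/InternVL3-38B", "/path/to/huggingface/InternVL3-38B"]:
    print(f"== {root}")
    for shard in sorted(glob.glob(os.path.join(root, "*.safetensors"))):
        print(os.path.basename(shard), sha256_of(shard))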

Reproduction

The Python code:

import os, base64, argparse, asyncio
from typing import Any, Dict, List, Tuple
from datasets import load_dataset, Dataset, concatenate_datasets
from openai import AsyncOpenAI
from tqdm import tqdm
import re
from PIL import Image
import numpy as np
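# NOTE: the prompt template below follows the InternVL-style grounding format (the
# referring expression wrapped in <ref>...</ref>); whether this exactly matches the
# template the InternVL team used for their reported numbers is an assumption here.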

PROMPT_TEMPLATE = (
    "Please provide the bounding box coordinate of the region this sentence describes: <ref>{sent}</ref> "
)

def to_data_url(image_path: str) -> str:
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/jpeg;base64,{b64}"

def summarize_ious(ious: List[float]) -> Dict[str, float]:
    valid = [x for x in ious if x is not None]
    if not valid:
        return {"mean_iou": 0.0, "pass_rate_05": 0.0}
    mean_iou = sum(valid) / len(valid)
    pass_rate_05 = sum(1 for x in valid if x > 0.5) / len(valid)
    return {"mean_iou": mean_iou, "pass_rate_05": pass_rate_05}

def scale_bbox(bbox, w, h):
    x1, y1, x2, y2 = bbox
    x1 = round(x1 * w / 1000.0)
    x2 = round(x2 * w / 1000.0)
    y1 = round(y1 * h / 1000.0)
    y2 = round(y2 * h / 1000.0)

    x1 = max(0, min(x1, w - 1))
    x2 = max(0, min(x2, w - 1))
    y1 = max(0, min(y1, h - 1))
    y2 = max(0, min(y2, h - 1))

    if x1 == x2 and w > 1: x2 = min(x1 + 1, w - 1)
    if y1 == y2 and h > 1: y2 = min(y1 + 1, h - 1)
    return [float(x1), float(y1), float(x2), float(y2)]

def compute_iou(boxA, boxB):
    if boxA is None or boxB is None:
        return 0.0
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    inter_w = max(0, xB - xA)
    inter_h = max(0, yB - yA)
    inter_area = inter_w * inter_h

    areaA = max(0, (boxA[2] - boxA[0])) * max(0, (boxA[3] - boxA[1]))
    areaB = max(0, (boxB[2] - boxB[0])) * max(0, (boxB[3] - boxB[1]))

    if areaA + areaB - inter_area == 0:
        return 0.0
    return inter_area / (areaA + areaB - inter_area)

async def call_one(client: AsyncOpenAI, model: str, rec: Dict[str, Any], max_tokens: int) -> Dict[str, Any]:
    rec["image"] = rec["image"].replace("lustre/fsw/portfolios/nvr/users/yunhaof/datasets", "data").lstrip("/")
    base_image_dir = "/storage/openpsi/"
    path = os.path.join(base_image_dir, rec["image"])
    with Image.open(path) as img:
        w, h = img.size
    assert w == rec["width"] and h == rec["height"], (
        f"width/height mismatch: image.size=({w},{h}), rec=({rec['width']},{rec['height']}), image={rec['image']}, sent={rec['sent']}"
    )
    gt_bbox = rec["bbox"]

    prompt = PROMPT_TEMPLATE.format(sent=rec["sent"])
    content = [
        {"type": "image_url", "image_url": {"url": to_data_url(path), "detail": "high"}},
        {"type": "text", "text": prompt},
    ]

    for attempt in range(3):
        try:
            resp = await client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": content}],
                temperature=0.0,
                max_tokens=max_tokens,
            )
            ch = resp.choices[0]

            print("generation", ch.message.content)
            print("gt_bbox", rec["bbox"], "width", w, "height", h)

            m = re.search(
                r'"bbox_2d"\s*:\s*\[\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*\]',
                ch.message.content, flags=re.S
            )
            m2 = re.search(
                r'\[\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*,\s*([+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?)\s*\]',
                ch.message.content, flags=re.S
            )
            nums = re.findall(r'[+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]?\d+)?', ch.message.content, flags=re.S)

            remap = True
            if m:
                bbox = [float(m.group(i)) for i in range(1, 5)]
                scaled_bbox = scale_bbox(bbox, w, h) if remap else bbox
            elif m2:
                bbox = [float(m2.group(i)) for i in range(1, 5)]
                scaled_bbox = scale_bbox(bbox, w, h) if remap else bbox
            elif len(nums) >= 4:
                bbox = list(map(float, nums[-4:]))
                scaled_bbox = scale_bbox(bbox, w, h) if remap else bbox
            else:
                print(f"bbox_2d not found: {ch.message.content}")
                scaled_bbox = None

            iou = compute_iou(scaled_bbox, gt_bbox) if scaled_bbox is not None else 0.0
            print("iou", iou)

            return {
                "generation": ch.message.content,
                "scaled_bbox": scaled_bbox,
                "iou": iou,
            }
        except Exception as e:
            err = f"{type(e).__name__}: {e}"
            print("err", err)
            if attempt == 2:
                return {"generation": None, "scaled_bbox": 0, "iou": 0}
            await asyncio.sleep(1.5 * (attempt + 1))

async def bounded_call(idx: int, rec: Dict[str, Any], sem: asyncio.Semaphore,
                       client: AsyncOpenAI, model: str, max_tokens: int) -> Tuple[int, Dict[str, Any]]:
    async with sem:
        out = await call_one(client, model, rec, max_tokens)
        return idx, out

async def _run_range(ds, start, end, client, model, max_tokens, sem, desc="async infer"):
    ds_chunk = ds.select(range(start, end))
    tasks = []
    for j in range(len(ds_chunk)):
        rec = {
            "image": ds_chunk[j]["image"],
            "sent": ds_chunk[j]["sent"],
            "bbox": ds_chunk[j]["bbox"],
            "height": ds_chunk[j]["height"],
            "width": ds_chunk[j]["width"],
        }
        tasks.append(asyncio.create_task(bounded_call(start + j, rec, sem, client, model, max_tokens)))

    generations = [None] * len(ds_chunk)
    scaled_bbox = [None] * len(ds_chunk)
    iou = [None] * len(ds_chunk)

    pbar = tqdm(total=len(tasks), desc=f"{desc} [{start}-{end-1}]")
    for coro in asyncio.as_completed(tasks):
        idx, out = await coro
        j = idx - start
        generations[j] = out.get("generation")
        scaled_bbox[j] = out.get("scaled_bbox")
        iou[j] = out.get("iou")
        pbar.update(1)
    pbar.close()

    return ds_chunk, generations, scaled_bbox, iou

async def main_async(args):
    ds = load_dataset("json", split='train', data_files=args.data_json)

    os.makedirs(args.output_dir, exist_ok=True)

    client = AsyncOpenAI(base_url=args.endpoint.rstrip("/") + "/v1", api_key="none")
    sem = asyncio.Semaphore(args.concurrency)

    if args.flush_every is None or args.flush_every <= 0:
        tasks = []
        for i in range(len(ds)):
            rec = {
                "image": ds[i]["image"],
                "sent": ds[i]["sent"],
                "bbox": ds[i]["bbox"],
                "height": ds[i]["height"],
                "width": ds[i]["width"],
            }
            tasks.append(asyncio.create_task(bounded_call(i, rec, sem, client, args.model_path, args.max_tokens)))

        generations = [None] * len(ds)
        scaled_bbox = [None] * len(ds)
        iou = [None] * len(ds)

        pbar = tqdm(total=len(tasks), desc="async infer (no chunk)")
        for coro in asyncio.as_completed(tasks):
            idx, out = await coro
            generations[idx] = out.get("generation")
            scaled_bbox[idx] = out.get("scaled_bbox")
            iou[idx] = out.get("iou")
            pbar.update(1)
        pbar.close()

        stats = summarize_ious(iou)
        mean_iou = stats["mean_iou"]
        pass_rate = stats["pass_rate_05"]
        print("mean_iou", mean_iou)
        print("[email protected]", pass_rate)

        iou_file = os.path.join(args.output_dir, "iou.txt")
        with open(iou_file, "w", encoding="utf-8") as f:
            f.write(f"# IOU scores (N={len(iou)})\n")
            f.write(f"mean_iou: {mean_iou:.6f}\n")
            f.write(f"[email protected]: {pass_rate:.6f}\n")

        ds_out = (ds.add_column("generated_result", generations)
                    .add_column("scaled_bbox", scaled_bbox)
                    .add_column("iou", iou))
        ds_out.save_to_disk(args.output_dir)
        return

    shards_dir = os.path.join(args.output_dir, "shards")
    os.makedirs(shards_dir, exist_ok=True)
    all_scores = []
    shard_paths = []

    N = len(ds)
    step = args.flush_every

    for start in range(0, N, step):
        end = min(start + step, N)
        shard_path = os.path.join(shards_dir, f"shard_{start}_{end}")

        if os.path.exists(shard_path):

            shard_paths.append(shard_path)
            try:
                with open(os.path.join(shard_path, "iou.txt"), "r", encoding="utf-8") as f:
                    for line in f:
                        if line.startswith("avg_score:"):
                            all_scores.append(float(line.strip().split(":")[1]))
            except Exception:
                pass
            continue

        ds_chunk, generations, scaled_bbox, iou = await _run_range(
            ds, start, end, client, args.model_path, args.max_tokens, sem, desc="async infer (chunk)"
        )
        stats = summarize_ious(iou)
        mean_iou = stats["mean_iou"]
        pass_rate = stats["pass_rate_05"]
        all_scores.append(pass_rate)

        ds_chunk_out = (ds_chunk
                        .add_column("generated_result", generations)
                        .add_column("scaled_bbox", scaled_bbox)
                        .add_column("iou", iou))
        ds_chunk_out.save_to_disk(shard_path)
        with open(os.path.join(shard_path, "iou.txt"), "w", encoding="utf-8") as f:
            f.write(f"# IOU scores for shard [{start}-{end-1}] (N={len(iou)})\n")
            f.write(f"mean_iou: {mean_iou:.6f}\n")
            f.write(f"[email protected]: {pass_rate:.6f}\n")
        shard_paths.append(shard_path)

    shard_datasets = [Dataset.load_from_disk(p) for p in shard_paths]
    ds_merged = concatenate_datasets(shard_datasets)

    ious = ds_merged["iou"]
    iou_above_05 = [x for x in ious if (x is not None and x > 0.5)]
    overall_score = (len(iou_above_05) / len(ious)) if len(ious) > 0 else 0.0

    final_dir = os.path.join(args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    ds_merged.save_to_disk(final_dir)
    with open(os.path.join(final_dir, "iou.txt"), "w", encoding="utf-8") as f:
        f.write(f"avg_score: {overall_score}\n")

    print(f"overall avg_score: {overall_score}")

def parse_args():
    parser = argparse.ArgumentParser(description="Async concurrent VLM inference via SGLang")
    parser.add_argument("--data_json", required=True)
    parser.add_argument("--endpoint", default="http://127.0.0.1:30000")
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--max_tokens", type=int, default=512)
    parser.add_argument("--concurrency", type=int, default=128)
    parser.add_argument("--flush_every", type=int, default=0)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    asyncio.run(main_async(args))
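
To rule out the parsing and coordinate-remapping path as the cause of the gap, a quick offline check can be run against a canned response (the module name assumes the script above is saved as scripts/sglang_infer.py and is importable; the sample generation, image size, and ground-truth box below are made up for illustration):

import re
from sglang_infer import compute_iou, scale_bbox  # hypothetical import of the script above

fake_generation = "<box>[[120, 250, 480, 910]]</box>"  # made-up model output on the 0-1000 grid
w, h = 640, 480                                        # made-up image size
gt_bbox = [76.8, 120.0, 307.2, 436.8]                  # made-up ground truth in pixels

# Same fallback parsing as the evaluation script: grab the last four numbers.
nums = re.findall(r"[+-]?(?:\d*\.\d+|\d+)", fake_generation)
bbox_1000 = list(map(float, nums[-4:]))
pred = scale_bbox(bbox_1000, w, h)  # remap from the 0-1000 grid to pixel coordinates
print("pred", pred, "iou", compute_iou(pred, gt_bbox))

An IoU close to 1.0 here suggests the post-processing is sound and the discrepancy sits with the model outputs themselves.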

Run the Python code with:

#!/usr/bin/env bash
set -euo pipefail

export BASE_IMAGE_DIR="${BASE_IMAGE_DIR:-./data}"      
export DATA_DIR="${DATA_DIR:-./datasets/}"
export RUNS_DIR="${RUNS_DIR:-./runs}"
export ENDPOINT="${ENDPOINT:-http://127.0.0.1:30000}"

MODEL_PATH="${1:-${MODEL_PATH:-path/to/your-model-or-hf-id}}"

echo "Using model path: ${MODEL_PATH}"
model_name="$(basename "${MODEL_PATH}")"
result_dir="${RUNS_DIR}/result/${model_name}"
mkdir -p "${result_dir}"

data_name_list=(
  "refcoco_testA" "refcoco_testB"
  "refcoco+_testA" "refcoco+_testB" "refcoco+_val"
  "refcocog_val" "refcocog_test" "refcoco_val"
)

for data_name in "${data_name_list[@]}"; do
  echo "Processing dataset: ${data_name}"
  data_json="${DATA_DIR}/${data_name}.jsonl"
  output_dir="${result_dir}/${data_name}_result"
  mkdir -p "${output_dir}"

  python scripts/sglang_infer.py \
    --model_path "${MODEL_PATH}" \
    --data_json "${data_json}" \
    --output_dir "${output_dir}" \
    --endpoint "${ENDPOINT}"
done
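
The evaluation script only talks to an OpenAI-compatible endpoint; it assumes an sglang server is already listening at ${ENDPOINT}. For completeness, a launch command along the following lines should provide one (the tensor-parallel size and other flags below are assumptions for illustration, not copied from the original run):

# Hypothetical server launch; adjust --tp and other flags to your setup.
python3 -m sglang.launch_server \
  --model-path "${MODEL_PATH}" \
  --port 30000 \
  --tp 4 \
  --trust-remote-code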

Environment

sys.platform: linux                                                                                                                                                                                                                                                                                                                                                            
Python: 3.12.3 (main, Feb  4 2025, 14:48:35) [GCC 13.3.0]                                                                                                                                                                                                                                                                                                                      
CUDA available: True                                                                                                                                                                                                                                                                                                                                                           
MUSA available: False                                                                                                                                                                                                                                                                                                                                                          
numpy_random_seed: 2147483648                                                                                                                                                                                                                                                                                                                                                  
GPU 0,1,2,3,4,5,6,7: NVIDIA L20X                                                                                                                                                                                                                                                                                                                                               
CUDA_HOME: /usr/local/cuda                                                                                                                                                                                                                                                                                                                                                     
NVCC: Cuda compilation tools, release 12.8, V12.8.93                                                                                                                                                                                                                                                                                                                           
GCC: x86_64-linux-gnu-gcc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0                                                                                                                                                                                                                                                                                                                
PyTorch: 2.8.0+cu128                                                                                                                                                                                                                                                                                                                                                           
PyTorch compiling details: PyTorch built with:                                                                                                                                                                                                                                                                                                                                 
  - GCC 13.3                                                                                                                                                                                                                                                                                                                                                                   
  - C++ Version: 201703                                                                                                                                                                                                                                                                                                                                                        
  - Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications                                                                                                                                                                                                                                                        
  - Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)                                                                                                                                                                                                                                                                                                
  - OpenMP 201511 (a.k.a. OpenMP 4.5)                                                                                                                                                                                                                                                                                                                                          
  - LAPACK is enabled (usually provided by MKL)                                                                                                                                                                                                                                                                                                                                
  - NNPACK is enabled                                                                                                                                                                                                                                                                                                                                                          
  - CPU capability usage: AVX512                                                                                                                                                                                                                                                                                                                                               
  - CUDA Runtime 12.8                                                                                                                                                                                                                                                                                                                                                          
  - NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120                                                                                   
  - CuDNN 91.0.2  (built against CUDA 12.9)                                                                                                                                                                                                                                                                                                                                    
    - Built with CuDNN 90.8                                                                                                                                                                                                                                                                                                                                                    
  - Magma 2.6.1                                                                                                                                                                                                                                                                                                                                                                
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=a1cb3cc05d46d198467bebbb6e8fba50a325d4e7, CUDA_VERSION=12.8, CUDNN_VERSION=9.8.0, CXX_COMPILER=/opt/rh/gcc-toolset-13/root/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF,
                                                                                                                                                                                                                                                                                                                                                                               
TorchVision: 0.23.0+cu128                                                                                                                                                                                                                                                                                                                                                      
LMDeploy: 0.10.2+                                                                                                                                                                                                                                                                                                                                                              
transformers: 4.57.0                                                                                                                                                                                                                                                                                                                                                           
fastapi: 0.119.0
pydantic: 2.12.2
triton: 3.4.0
NVIDIA Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    48-95,144-191   1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    48-95,144-191   1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    48-95,144-191   1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     48-95,144-191   1               N/A
NIC0    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC1    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE    SYS     SYS     SYS     SYS
NIC2    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE    SYS     SYS     SYS     SYS
NIC3    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE
NIC5    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE
NIC6    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE
NIC7    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1
  NIC2: mlx5_bond_2
  NIC3: mlx5_bond_3
  NIC4: mlx5_bond_4
  NIC5: mlx5_bond_5
  NIC6: mlx5_bond_6
  NIC7: mlx5_bond_7

Error traceback
