Skip to content

Commit 9cdd4a4

Browse files
authored
Bump version to v1.2.1 (#46)
1. Support dynamic sequence length during training 2. Update README.md 3. Update evaluation code
1 parent 26601cc commit 9cdd4a4

File tree

12 files changed

+167
-21
lines changed

12 files changed

+167
-21
lines changed

BLOG.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ We released [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-
5858

5959
<img width="650" alt="image" src="https://github.com/OpenGVLab/InternVL/assets/8529570/0e60912e-c52b-46fa-bd61-5f94a221d1fc">
6060

61-
6261
## InternVL
6362

6463
> Date: 2023/12/12<br>

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
332332
<summary>Multimodal Dialogue (click to expand)</summary>
333333

334334
- Compared with SOTA VLLMs
335-
335+
336336
| name | image size | MMMU<br>(val) | MMMU<br>(test) | MathVista<br>(testmini) | MMB<br>(test) | MMB−CN<br>(test) | MMVP | MME | ScienceQA<br>(image) | POPE | TextVQA | SEEDv1<br>(image) | VizWiz<br>(test) | GQA<br>(test) |
337337
| ------------------ | ---------- | ------------- | -------------- | ----------------------- | ------------- | ---------------- | ---- | -------- | -------------------- | ---- | ------- | ----------------- | ---------------- | ------------- |
338338
| GPT-4V\* | unknown | 56.8 | 55.7 | 49.9 | 77.0 | 74.4 | 38.7 | 1409/517 | - | - | 78.0 | 71.6 | - | - |
@@ -343,7 +343,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
343343
| | | | | | | | | | | | | | | |
344344
| LLaVA-NEXT-34B | 672x672 | 51.1 | 44.7 | 46.5 | 79.3 | 79.0 | - | 1631/397 | 81.8 | 87.7 | 69.5 | 75.9 | 63.8 | 67.1 |
345345
| InternVL-Chat-V1.2 | 448x448 | 51.6 | 46.2 | 47.7 | 82.2 | 81.2 | 56.7 | 1672/509 | 83.3 | 88.0 | 69.7 | 75.6 | 60.0 | 64.0 |
346-
346+
347347
\* denotes proprietary models. MMBench results are collected from the [leaderboard](https://mmbench.opencompass.org.cn/leaderboard). In most benchmarks, InternVL-Chat-V1.2 achieves better performance than LLaVA-NeXT-34B.
348348

349349
- Zero-Shot Image Captioning [\[see details\]](./internvl_g#zero-shot-image-captioning)

internvl_chat/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ The hyperparameters used for finetuning are listed in the following table.
141141

142142
## 📊 Evaluation
143143

144+
\* Training set observed.
145+
144146
**MultiModal Benchmark**
145147

146148
| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP | MathVista |
@@ -151,14 +153,14 @@ The hyperparameters used for finetuning are listed in the following table.
151153
| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | Tiny<sub>LVLM</sub> | LLaVA<sub>bench</sub> | MM-Vet |
152154
| --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ------------------------ | ------------------- | --------------------- | ------ |
153155
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 39.1 / 35.3 | 34.8 / 34.0 | 344.5 | 76.3 | 45.0 |
154-
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | TODO | 350.3 | - | 48.9 |
156+
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | - | 350.3 | - | 48.9 |
155157

156158
**Visual Question Answering**
157159

158160
| model | VQAv2<sub>test</sub> | OKVQA<sub>val</sub> | TextVQA<sub>val</sub> | VizWiz<sub>val/test</sub> | AI2D<sub>test</sub> | GQA<sub>test</sub> | SQA<sub>test</sub> |
159161
| --------------------------------------------------------------------------------- | -------------------- | ------------------- | --------------------- | ------------------------- | ------------------- | ------------------ | ------------------ |
160-
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9 | 64.2 | 65.8 | 58.3 / 57.3 | 70.2 | 62.4 | 91.2 |
161-
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5 | 69.7 | 61.9 / 60.0 | 71.6 | 64.0 | 83.3 |
162+
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9\* | 64.2\* | 65.8 | 58.3 / 57.3 | 70.2\* | 62.4\* | 91.2\* |
163+
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5\* | 69.7 | 61.9 / 60.0 | 71.6\* | 64.0\* | 83.3 |
162164

163165
**Image Captioning**
164166

internvl_chat/eval/scienceqa/evaluate_scienceqa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def post_process(pred, option):
114114
if v in pred:
115115
return k
116116

117-
return random.choice(option_candidate)
117+
return pred
118118

119119

120120
def evaluate_chat_model():

internvl_chat/internvl/patch/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
33
from .llama_rmsnorm_monkey_patch import \
44
replace_llama_rmsnorm_with_fused_rmsnorm
5+
from .pad_data_collator import pad_data_collator
6+
from .train_sampler_patch import replace_train_sampler
57

68
__all__ = ['replace_llama_attn_with_flash_attn',
79
'replace_llama_rmsnorm_with_fused_rmsnorm',
8-
'replace_llama2_attn_with_flash_attn']
10+
'replace_llama2_attn_with_flash_attn',
11+
'replace_train_sampler',
12+
'pad_data_collator']
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
import torch
3+
4+
IGNORE_INDEX = -100
5+
6+
7+
def pad_data_collator(features, pad_id=0):
8+
9+
first = features[0]
10+
batch = {}
11+
12+
batch_lens = [feat['input_ids'].shape for feat in features]
13+
max_item_length = max(batch_lens)[0]
14+
for idx in range(len(features)):
15+
feat = features[idx]
16+
temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
17+
temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
18+
feat['input_ids'] = temp_input_ids
19+
temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
20+
temp_labels[:feat['labels'].shape[0]] = feat['labels']
21+
feat['labels'] = temp_labels
22+
feat['attention_mask'] = feat['input_ids'].ne(pad_id)
23+
24+
# Special handling for labels.
25+
# Ensure that tensor is created with the correct type
26+
# (it should be automatically the case, but let's make sure of it.)
27+
if 'label' in first and first['label'] is not None:
28+
label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
29+
dtype = torch.long if isinstance(label, int) else torch.float
30+
batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
31+
elif 'label_ids' in first and first['label_ids'] is not None:
32+
if isinstance(first['label_ids'], torch.Tensor):
33+
batch['labels'] = torch.stack([f['label_ids'] for f in features])
34+
else:
35+
dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
36+
batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
37+
38+
# Handling of all other possible keys.
39+
# Again, we will use the first element to figure out which key/values are not None for this model.
40+
for k, v in first.items():
41+
if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
42+
if isinstance(v, torch.Tensor):
43+
batch[k] = torch.stack([f[k] for f in features])
44+
elif isinstance(v, np.ndarray):
45+
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
46+
else:
47+
batch[k] = torch.tensor([f[k] for f in features])
48+
49+
return batch
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Optional
2+
3+
import torch
4+
import transformers
5+
from transformers.trainer import (LengthGroupedSampler, RandomSampler,
6+
has_length)
7+
8+
9+
# patch trainer
10+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
11+
if self.train_dataset is None or not has_length(self.train_dataset):
12+
return None
13+
# Build the sampler.
14+
if self.args.group_by_length:
15+
lengths = []
16+
for dataset in self.train_dataset.datasets:
17+
lengths = lengths + dataset.length
18+
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
19+
return LengthGroupedSampler(
20+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
21+
dataset=self.train_dataset,
22+
lengths=lengths,
23+
model_input_name=model_input_name,
24+
)
25+
else:
26+
return RandomSampler(self.train_dataset)
27+
28+
29+
def replace_train_sampler():
30+
transformers.Trainer._get_train_sampler = _get_train_sampler
31+
print('Replace train sampler!!')

internvl_chat/internvl/train/internvl_chat_finetune.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
InternVisionModel,
2424
InternVLChatConfig,
2525
InternVLChatModel)
26-
from internvl.patch import (replace_llama2_attn_with_flash_attn,
27-
replace_llama_rmsnorm_with_fused_rmsnorm)
26+
from internvl.patch import (pad_data_collator,
27+
replace_llama2_attn_with_flash_attn,
28+
replace_llama_rmsnorm_with_fused_rmsnorm,
29+
replace_train_sampler)
2830
from internvl.train.dataset import (TCSLoader, WeightedConcatDataset,
2931
build_transform)
3032
from PIL import Image, ImageFile, PngImagePlugin
@@ -39,6 +41,7 @@
3941
# Upgrade transformers to v4.36.2, we don't need it anymore
4042
# replace_llama2_attn_with_flash_attn()
4143
replace_llama_rmsnorm_with_fused_rmsnorm()
44+
replace_train_sampler()
4245

4346
try:
4447
from petrel_client.client import Client
@@ -182,6 +185,7 @@ def preprocess(
182185
tokenizer: transformers.PreTrainedTokenizer,
183186
num_image_token: int,
184187
text_only: bool = False,
188+
group_by_length: bool = False,
185189
) -> Dict:
186190
conv = get_conv_template(template_name)
187191
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -213,7 +217,7 @@ def preprocess(
213217
input_ids = tokenizer(
214218
conversations,
215219
return_tensors='pt',
216-
padding='max_length',
220+
padding=False if group_by_length else 'max_length',
217221
max_length=tokenizer.model_max_length,
218222
truncation=True,
219223
).input_ids
@@ -283,6 +287,7 @@ def preprocess_mpt(
283287
tokenizer: transformers.PreTrainedTokenizer,
284288
num_image_token: int,
285289
text_only: bool = False,
290+
group_by_length: bool = False,
286291
) -> Dict:
287292
conv = get_conv_template(template_name)
288293
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
@@ -314,7 +319,7 @@ def preprocess_mpt(
314319
input_ids = tokenizer(
315320
conversations,
316321
return_tensors='pt',
317-
padding='max_length',
322+
padding=False if group_by_length else 'max_length',
318323
max_length=tokenizer.model_max_length,
319324
truncation=True,
320325
).input_ids
@@ -368,7 +373,7 @@ class LazySupervisedDataset(Dataset):
368373
"""Dataset for supervised fine-tuning."""
369374

370375
def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
371-
image_size=224, is_train=True, pad2square=False):
376+
image_size=224, is_train=True, pad2square=False, group_by_length=False):
372377
super(LazySupervisedDataset, self).__init__()
373378
self.tokenizer = tokenizer
374379
self.template_name = template_name
@@ -384,6 +389,21 @@ def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
384389
self.root = meta['root']
385390
self.cached_data_dict = {}
386391
self.tcs_loader = tcs_loader
392+
self.group_by_length = group_by_length
393+
if self.group_by_length:
394+
self.conv2length = {}
395+
self.length = []
396+
for data_item in self.raw_data:
397+
conversations = ''.join(data_item.split('conversations')[1:])
398+
str_length = len(conversations)
399+
if str_length not in self.conv2length:
400+
token_length = tokenizer(
401+
conversations, return_tensors='pt', padding=False, truncation=False,
402+
).input_ids.size(1)
403+
self.conv2length[str_length] = token_length
404+
else:
405+
token_length = self.conv2length[str_length]
406+
self.length.append(token_length)
387407

388408
def __len__(self):
389409
return len(self.raw_data)
@@ -405,7 +425,7 @@ def multi_modal_get_item(self, data_item):
405425
else:
406426
preprocess_function = preprocess
407427
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
408-
self.tokenizer, self.num_image_token)
428+
self.tokenizer, self.num_image_token, group_by_length=self.group_by_length)
409429
ret = dict(
410430
input_ids=ret['input_ids'][0],
411431
labels=ret['labels'][0],
@@ -425,7 +445,8 @@ def pure_text_get_item(self, data_item):
425445
else:
426446
preprocess_function = preprocess
427447
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
428-
self.tokenizer, self.num_image_token, text_only=True)
448+
self.tokenizer, self.num_image_token, text_only=True,
449+
group_by_length=self.group_by_length)
429450
ret = dict(
430451
input_ids=ret['input_ids'][0],
431452
labels=ret['labels'][0],
@@ -455,7 +476,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
455476
return ret
456477

457478

458-
def build_datasets(data_args, tokenizer, tcs_loader, model):
479+
def build_datasets(data_args, tokenizer, tcs_loader, model, group_by_length=False):
459480
datasets = []
460481
lengths = []
461482
ds_collections = json.loads(open(data_args.meta_path).read())
@@ -469,7 +490,8 @@ def build_datasets(data_args, tokenizer, tcs_loader, model):
469490
num_image_token=model.num_image_token,
470491
image_size=data_args.force_image_size,
471492
is_train=ds_collections[ds_name]['data_augment'],
472-
pad2square=data_args.pad2square
493+
pad2square=data_args.pad2square,
494+
group_by_length=group_by_length
473495
)
474496
except Exception:
475497
logger.info(f'Error in loading dataset: {ds_name}')
@@ -623,7 +645,8 @@ def main():
623645
if model_args.grad_checkpoint:
624646
model.language_model._set_gradient_checkpointing()
625647

626-
train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model)
648+
train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model,
649+
group_by_length=training_args.group_by_length)
627650

628651
def _freeze_params(module):
629652
for param in module.parameters():
@@ -672,7 +695,7 @@ def _freeze_params(module):
672695
train_dataset=train_dataset if training_args.do_train else None,
673696
eval_dataset=None,
674697
tokenizer=tokenizer,
675-
data_collator=default_data_collator,
698+
data_collator=default_data_collator if not training_args.group_by_length else pad_data_collator,
676699
)
677700

678701
# Training

internvl_chat/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "internvl_chat"
7-
version = "1.2.0"
7+
version = "1.2.1"
88
description = "Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks."
99
readme = "README.md"
1010
requires-python = ">=3.8"

internvl_chat/tools/json2jsonl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111

1212
data = json.load(open(args.path))
1313
writer = open(args.path.replace('.json', '.jsonl'), 'w')
14-
for item in data:
14+
for idx, item in enumerate(data):
1515
conversations = item['conversations']
1616
if conversations[0]['from'] == 'system':
1717
item['conversations'] = item['conversations'][1:]
18+
item['id'] = idx
1819
writer.write(json.dumps(item, ensure_ascii=False) + '\n')
1920

2021
writer.close()

0 commit comments

Comments
 (0)