grit2odvg.py
"""Convert GRIT annotation JSON files into the ODVG grounding jsonl format."""
import argparse
import json
import os
import random
from functools import partial
from multiprocessing import Pool

import emoji
import jsonlines
from tqdm import tqdm

def clean_span(span):
    # Normalize quotes and dashes, then strip trailing whitespace and punctuation.
    span = span.rstrip()
    span = span.replace('"', "'").replace('“', "'").replace('”', "'")
    span = span.replace('‘', "'").replace('’', "'").replace('–', "—")
    if span.endswith('/') or span.endswith('.'):
        span = span[:-1]
    return span

def check_caption(cap):
    # Drop the trailing punctuation character before checking the caption.
    check_anno = cap["caption"].rstrip()[:-1]
    if not str.isascii(check_anno):
        return False
    # Reject captions containing characters/tokens that break grounding spans,
    # e.g. "The view is better from here 🦅 (Chouf".
    check_list = {"↙️", "-", ",", " ", "*", "/", "$", "[CLS]", "[SEP]", "?"}
    for ch in check_list:
        if ch in check_anno:
            return False
    # Reject captions with an interior period (multi-sentence or abbreviations).
    if '.' in check_anno[:-1]:
        return False
    if emoji.emoji_count(check_anno):
        print(check_anno)
        return False
    return True

def get_regions(nc, anno):
    # nc is a noun_chunk / ref_exp entry: indices 0-1 are character offsets into
    # the caption, indices 2-5 are a normalized [x1, y1, x2, y2] box.
    h = anno["height"]
    w = anno["width"]
    phrase = clean_span(anno["caption"][int(nc[0]):int(nc[1])])
    # Scale the normalized box to absolute pixel coordinates.
    bbox = [round(nc[2] * w, 2), round(nc[3] * h, 2), round(nc[4] * w, 2), round(nc[5] * h, 2)]
    return {
        "bbox": bbox,
        "phrase": phrase
    }

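# Illustrative example (hypothetical values): a noun_chunk / ref_exp entry is
# expected to hold at least [start_char, end_char, x1_norm, y1_norm, x2_norm, y2_norm, ...],
# so for anno = {"caption": "a dog on the grass", "height": 480, "width": 640}
# and nc = [0, 5, 0.1, 0.2, 0.5, 0.6], get_regions(nc, anno) would return
#   {"bbox": [64.0, 96.0, 320.0, 288.0], "phrase": "a dog"}
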
def prepare_list(file_name: str, random_samples):
    # Read the list of annotation file names and keep a random subset
    # (capped at the number of available files).
    with open(file_name, "r") as f:
        metas = [line.strip() for line in f]
    num_of_files = len(metas)
    print(num_of_files)
    metas = random.sample(metas, min(random_samples, num_of_files))
    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files

def process_item(file, args):
    with open(os.path.join(args.root, file)) as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None
    noun_chunks = anno['noun_chunks']
    ref_exps = anno['ref_exps']
    regions = []
    # Randomly take regions from either noun_chunks or ref_exps, keeping only
    # ASCII phrases; args.chunk_or_ref controls the split.
    if random.random() > args.chunk_or_ref:
        for nc in noun_chunks:
            region = get_regions(nc, anno)
            if str.isascii(region["phrase"]):
                regions.append(region)
    else:
        for re in ref_exps:
            region = get_regions(re, anno)
            if str.isascii(region["phrase"]):
                regions.append(region)
    # Skip images with too few grounded phrases.
    if len(regions) < args.min_phrase:
        return None
    odvg_anno = {
        "filename": f'{file.split(".")[0]}.jpg',
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions
        }
    }
    return odvg_anno

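# Illustrative example of one emitted ODVG record (hypothetical values); each
# jsonl line pairs an image with its caption and grounded regions:
# {"filename": "00000001.jpg", "height": 480, "width": 640,
#  "grounding": {"caption": "a dog on the grass",
#                "regions": [{"bbox": [64.0, 96.0, 320.0, 288.0], "phrase": "a dog"}]}}
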
if __name__ == "__main__":
    # jsons = "/share_data/mllm/kosmos-2/GRIT-20M/anno/14m_anno.list"
    # root = "/share_data/mllm/kosmos-2/GRIT-20M/data"
    # output_name = "./girt_14m_odvg.jsonl"
    parser = argparse.ArgumentParser(description="GRIT2ODVG: convert GRIT annotations to the ODVG jsonl format.")
    parser.add_argument("--input_file", type=str, required=True, help="List file with one annotation JSON name per line")
    parser.add_argument("--root", type=str, default="", help="Root directory of the annotation JSON files")
    parser.add_argument("--output_file", type=str, default="girt_14m_odvg.jsonl")
    parser.add_argument("--random_samples", type=int, default=200000, help="Number of annotations to sample from the list")
    parser.add_argument("--chunk_or_ref", type=float, default=0.5, help="Probability of using ref_exps instead of noun_chunks")
    parser.add_argument("--min_phrase", type=int, default=6, help="Minimum number of regions per image")
    parser.add_argument("--process_num", type=int, default=10, help="Number of worker processes")
    args = parser.parse_args()
    print(args)
    metas, metas_len = prepare_list(args.input_file, args.random_samples)
    odvg_anno = []
    func = partial(process_item, args=args)
    with Pool(processes=args.process_num) as pool:
        for result in tqdm(pool.imap(func=func, iterable=metas), total=len(metas)):
            odvg_anno.append(result)
    # Drop entries rejected by process_item and write the ODVG jsonl file.
    odvg_anno = list(filter(None, odvg_anno))
    with jsonlines.open(args.output_file, mode="w") as fwriter:
        fwriter.write_all(odvg_anno)
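
# Example invocation (paths and file names are hypothetical); the list file
# contains one annotation JSON filename, relative to --root, per line:
#   python grit2odvg.py --input_file anno/14m_anno.list \
#       --root /path/to/GRIT-20M/data \
#       --output_file grit_14m_odvg.jsonl \
#       --random_samples 200000 --process_num 10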