Commit 526189f

Commit message: upload
1 parent 0be1e82 commit 526189f

File tree: 537 files changed (+175369 −0 lines)


Reward3D/BLIP/__init__.py (+1)

from .blip_pretrain import *
Ten binary files are not shown (192 Bytes, 185 Bytes, 2.31 KB, 2.28 KB, two of unlisted size, 27.2 KB, 27.1 KB, 11.9 KB, 11.8 KB).

Reward3D/BLIP/blip.py (+70)

'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def init_tokenizer():
    # BERT tokenizer extended with BLIP's decoder ([DEC]) and encoder ([ENC]) special tokens
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
            num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
            drop_path_rate=0 or drop_path_rate  # 0 is falsy, so this resolves to drop_path_rate
        )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
            num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
            drop_path_rate=0.1 or drop_path_rate  # 0.1 is truthy, so this is always 0.1 (kept as in upstream BLIP)
        )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    # Accepts either an http(s) URL (downloaded and cached) or a local file path.
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Resize positional embeddings to match the model's input resolution.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
                                                                   model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    # Drop checkpoint tensors whose shapes do not match the model.
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg

Reward3D/BLIP/blip_pretrain.py (+43)

'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer


class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config='med_config.json',
                 image_size=224,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 embed_dim=256,
                 queue_size=57600,
                 momentum=0.995,
                 ):
        """
        Args:
            med_config (str): path to the mixture-of-encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): size of the vision transformer ('base' or 'large')
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer()
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        # Project image and text features into a shared embed_dim-dimensional space.
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)
