-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy path__init__.py
192 lines (145 loc) · 7.24 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import numpy as np
import torch
#from timm.models import create_model
from .protonet import ProtoNet
from .deploy import ProtoNet_Finetune, ProtoNet_Auto_Finetune, ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin
def _strip_classifier_head(state_dict):
    """Delete the ``head.weight`` / ``head.bias`` entries from *state_dict* in place.

    The ViT models below are built with ``num_classes=0``, so these keys would
    make a ``strict=True`` load fail.
    """
    for key in ('head.weight', 'head.bias'):
        if key in state_dict:
            print(f"removing key {key} from pretrained checkpoint")
            del state_dict[key]
    return state_dict


def _load_deit_backbone(vit_name, url):
    """Build ``vision_transformer.<vit_name>`` and load the DeiT checkpoint at *url*."""
    from . import vision_transformer as vit

    model = vit.__dict__[vit_name](patch_size=16, num_classes=0)

    # map_location keeps the load working on CPU-only hosts.
    checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
    state_dict = _strip_classifier_head(checkpoint["model"])

    model.load_state_dict(state_dict, strict=True)
    print('Pretrained weights found at {}'.format(url))
    return model


def _load_dino_vit_backbone(vit_name, url, pretrained=True):
    """Build ``vision_transformer.<vit_name>``; optionally load DINO weights.

    *url* is the path relative to ``https://dl.fbaipublicfiles.com/dino/``.
    """
    from . import vision_transformer as vit

    model = vit.__dict__[vit_name](patch_size=16, num_classes=0)

    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url,
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))
    return model


def get_backbone(args):
    """Construct the feature-extraction backbone selected by ``args.arch``.

    Args:
        args: namespace-like object. Attributes read:
            ``arch`` (always) -- name of the backbone to build;
            ``no_pretrain`` -- only consulted by the ``dino_small_patch16``,
            ``dino_resnet50``, ``resnet50`` and ``resnet18`` branches;
            ``pretrained_checkpoint_path`` -- only consulted by the
            ``simclrv2_resnet50``, ``mocov2_resnet50``, ``swav_resnet50`` and
            ``barlow_resnet50`` branches.
            The ``beit_base_patch16_224_pt22k`` branch forwards the whole
            ``args`` object to ``beit.default_pretrained_model``.

    Returns:
        The constructed model. For the torchvision ResNet branches the ``fc``
        layer is replaced by ``torch.nn.Identity``.

    Raises:
        ValueError: if ``args.arch`` is not one of the supported names.
    """
    if args.arch == 'vit_base_patch16_224_in21k':
        from .vit_google import VisionTransformer, CONFIGS

        config = CONFIGS['ViT-B_16']
        model = VisionTransformer(config, 224)

        url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
        pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'

        if not os.path.exists(pretrained_weights):
            try:
                import wget
                os.makedirs('pretrained_ckpts', exist_ok=True)
                wget.download(url, pretrained_weights)
            except Exception:
                # Best effort: report and fall through. `Exception` (rather
                # than a bare except) lets Ctrl-C / SystemExit propagate.
                print(f'Cannot download pretrained weights from {url}. Check if `pip install wget` works.')

        model.load_from(np.load(pretrained_weights))
        print('Pretrained weights found at {}'.format(pretrained_weights))

    elif args.arch == 'dino_base_patch16':
        # NOTE(review): unlike dino_small_patch16, this branch does not consult
        # args.no_pretrain -- weights are always loaded. Kept as-is.
        model = _load_dino_vit_backbone(
            'vit_base',
            "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
        )

    elif args.arch == 'deit_base_patch16':
        model = _load_deit_backbone(
            'vit_base',
            "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        )

    elif args.arch == 'deit_small_patch16':
        model = _load_deit_backbone(
            'vit_small',
            "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
        )

    elif args.arch == 'dino_small_patch16':
        model = _load_dino_vit_backbone(
            'vit_small',
            "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            pretrained=not args.no_pretrain,
        )

    elif args.arch == 'beit_base_patch16_224_pt22k':
        from .beit import default_pretrained_model

        model = default_pretrained_model(args)
        print('Pretrained BEiT loaded')

    elif args.arch == 'clip_base_patch16_224':
        from . import clip

        model, _ = clip.load('ViT-B/16', 'cpu')

    elif args.arch == 'clip_resnet50':
        from . import clip

        model, _ = clip.load('RN50', 'cpu')

    elif args.arch == 'dino_resnet50':
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        model.fc = torch.nn.Identity()

        if not args.no_pretrain:
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
                map_location="cpu",
            )
            model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'resnet50':
        from torchvision.models.resnet import resnet50

        pretrained = not args.no_pretrain
        model = resnet50(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'resnet18':
        from torchvision.models.resnet import resnet18

        pretrained = not args.no_pretrain
        model = resnet18(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'dino_xcit_medium_24_p16':
        model = torch.hub.load('facebookresearch/xcit:main', 'xcit_medium_24_p16')
        model.head = torch.nn.Identity()
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'dino_xcit_medium_24_p8':
        model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')

    elif args.arch == 'simclrv2_resnet50':
        import sys

        # `model_utils` is imported from the local `cog` directory, which must
        # therefore be on sys.path before the import below.
        sys.path.insert(
            0,
            'cog',
        )
        import model_utils

        model_utils.MODELS_ROOT_DIR = 'cog/models'
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts/simclrv2_resnet50.pth')
        resnet, _ = model_utils.load_pretrained_backbone(args.arch, ckpt_file)

        class Wrapper(torch.nn.Module):
            """Calls the wrapped network with ``apply_fc=False``."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                return self.model(x, apply_fc=False)

        model = Wrapper(resnet)

    elif args.arch in ['mocov2_resnet50', 'swav_resnet50', 'barlow_resnet50']:
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts_converted/{}.pth'.format(args.arch))
        # Load onto CPU so checkpoints saved from a GPU run work everywhere.
        ckpt = torch.load(ckpt_file, map_location="cpu")

        msg = model.load_state_dict(ckpt, strict=False)
        # Only the classifier weights are allowed to be absent from the checkpoint.
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, (
            f'unexpected missing keys when loading {ckpt_file}: {msg.missing_keys}'
        )

        # remove the fully-connected layer
        model.fc = torch.nn.Identity()

    else:
        raise ValueError(f'{args.arch} is not considered in the current code.')

    return model
def get_model(args):
    """Build the backbone for ``args`` and wrap it according to ``args.deploy``.

    The backbone is always created first via ``get_backbone(args)``; the
    deployment mode then selects which ProtoNet variant receives it.

    Supported ``args.deploy`` values and the extra attributes each one reads:
        'vanilla'           -- none
        'finetune'          -- ada_steps, ada_lr, aug_prob, aug_types
        'finetune_autolr'   -- ada_steps, aug_prob, aug_types
        'ada_tokens'        -- num_adapters, ada_steps, ada_lr
        'ada_tokens_entmin' -- num_adapters, ada_steps, ada_lr

    Raises:
        ValueError: if ``args.deploy`` is none of the values above.
    """
    encoder = get_backbone(args)
    mode = args.deploy

    if mode == 'vanilla':
        return ProtoNet(encoder)

    if mode == 'finetune':
        return ProtoNet_Finetune(
            encoder,
            args.ada_steps,
            args.ada_lr,
            args.aug_prob,
            args.aug_types,
        )

    if mode == 'finetune_autolr':
        return ProtoNet_Auto_Finetune(
            encoder,
            args.ada_steps,
            args.aug_prob,
            args.aug_types,
        )

    if mode == 'ada_tokens':
        return ProtoNet_AdaTok(
            encoder,
            args.num_adapters,
            args.ada_steps,
            args.ada_lr,
        )

    if mode == 'ada_tokens_entmin':
        return ProtoNet_AdaTok_EntMin(
            encoder,
            args.num_adapters,
            args.ada_steps,
            args.ada_lr,
        )

    raise ValueError(f'deploy method {args.deploy} is not supported.')