Skip to content

Commit fa9152f

Browse files
committed
Add CLIPS to open_clip
1 parent aeaf2a0 commit fa9152f

File tree

7 files changed

+239
-9
lines changed

7 files changed

+239
-9
lines changed

src/open_clip/factory.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
1919
list_pretrained_tags_by_model, download_pretrained_from_hf
2020
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
21-
from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
21+
from .tokenizer import HFTokenizer, SimpleTokenizer, CLIPS_Tokenizer, DEFAULT_CONTEXT_LENGTH
2222

2323
HF_HUB_PREFIX = 'hf-hub:'
2424
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
@@ -123,12 +123,17 @@ def get_tokenizer(
123123
context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
124124

125125
if 'hf_tokenizer_name' in text_config:
126-
tokenizer = HFTokenizer(
127-
text_config['hf_tokenizer_name'],
126+
if 'CLIPS' in model_name:
127+
tokenizer = CLIPS_Tokenizer(
128128
context_length=context_length,
129-
cache_dir=cache_dir,
130129
**tokenizer_kwargs,
131130
)
131+
else:
132+
tokenizer = HFTokenizer(
133+
text_config['hf_tokenizer_name'],
134+
context_length=context_length,
135+
**tokenizer_kwargs,
136+
)
132137
else:
133138
tokenizer = SimpleTokenizer(
134139
context_length=context_length,
@@ -341,6 +346,9 @@ def create_model(
341346
else:
342347
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
343348
else:
349+
if 'CLIPS' in model_name:
350+
model_cfg['vision_cfg']['eps'] = 1e-6
351+
model_cfg['text_cfg']['eps'] = 1e-6
344352
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
345353

346354
if precision in ("fp16", "bf16"):

src/open_clip/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class CLIPVisionCfg:
3131
mlp_ratio: float = 4.0
3232
patch_size: int = 16
3333
image_size: Union[Tuple[int, int], int] = 224
34+
eps: float = 1e-5
3435

3536
ls_init_value: Optional[float] = None # layer scale initial value
3637
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
@@ -76,6 +77,7 @@ class CLIPTextCfg:
7677
output_tokens: bool = False
7778
act_kwargs: dict = None
7879
norm_kwargs: dict = None
80+
eps: float = 1e-5
7981

8082
# HuggingFace specific text tower config
8183
hf_model_name: Optional[str] = None
@@ -166,6 +168,7 @@ def _build_vision_tower(
166168
output_dim=embed_dim,
167169
act_layer=act_layer,
168170
norm_layer=norm_layer,
171+
eps=vision_cfg.eps,
169172
)
170173

171174
return visual
@@ -215,6 +218,7 @@ def _build_text_tower(
215218
output_tokens=text_cfg.output_tokens,
216219
act_layer=act_layer,
217220
norm_layer=norm_layer,
221+
eps=text_cfg.eps,
218222
)
219223
return text
220224

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
{
2+
"model_cfg": {
3+
"embed_dim": 1024,
4+
"vision_cfg": {
5+
"image_size": 224,
6+
"layers": 32,
7+
"width": 1280,
8+
"head_width": 80,
9+
"patch_size": 14,
10+
"no_ln_pre": true,
11+
"pool_type": "avg",
12+
"final_ln_after_pool": true
13+
},
14+
"text_cfg": {
15+
"context_length": 80,
16+
"vocab_size": 32000,
17+
"hf_tokenizer_name": "bert-base-uncased",
18+
"tokenizer_kwargs": {
19+
"strip_sep_token": true
20+
},
21+
"width": 1024,
22+
"heads": 16,
23+
"layers": 24,
24+
"pool_type": "last",
25+
"no_causal_mask": true,
26+
"act_kwargs": {
27+
"approximate": "tanh"
28+
}
29+
}
30+
},
31+
"preprocess_cfg": {
32+
"mean": [
33+
0.485,
34+
0.456,
35+
0.406
36+
],
37+
"std": [
38+
0.229,
39+
0.224,
40+
0.225
41+
],
42+
"interpolation": "bilinear",
43+
"resize_mode": "squash"
44+
}
45+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"model_cfg": {
3+
"embed_dim": 768,
4+
"vision_cfg": {
5+
"image_size": 224,
6+
"layers": 24,
7+
"width": 1024,
8+
"patch_size": 14,
9+
"no_ln_pre": true,
10+
"pool_type": "avg",
11+
"final_ln_after_pool": true
12+
},
13+
"text_cfg": {
14+
"context_length": 80,
15+
"vocab_size": 32000,
16+
"hf_tokenizer_name": "bert-base-uncased",
17+
"tokenizer_kwargs": {
18+
"strip_sep_token": true
19+
},
20+
"width": 768,
21+
"heads": 12,
22+
"layers": 12,
23+
"pool_type": "last",
24+
"no_causal_mask": true,
25+
"act_kwargs": {
26+
"approximate": "tanh"
27+
}
28+
}
29+
},
30+
"preprocess_cfg": {
31+
"mean": [
32+
0.485,
33+
0.456,
34+
0.406
35+
],
36+
"std": [
37+
0.229,
38+
0.224,
39+
0.225
40+
],
41+
"interpolation": "bilinear",
42+
"resize_mode": "squash"
43+
}
44+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"model_cfg": {
3+
"embed_dim": 768,
4+
"vision_cfg": {
5+
"image_size": 336,
6+
"layers": 24,
7+
"width": 1024,
8+
"patch_size": 14,
9+
"no_ln_pre": true,
10+
"pool_type": "avg",
11+
"final_ln_after_pool": true
12+
},
13+
"text_cfg": {
14+
"context_length": 80,
15+
"vocab_size": 32000,
16+
"hf_tokenizer_name": "bert-base-uncased",
17+
"tokenizer_kwargs": {
18+
"strip_sep_token": true
19+
},
20+
"width": 768,
21+
"heads": 12,
22+
"layers": 12,
23+
"pool_type": "last",
24+
"no_causal_mask": true,
25+
"act_kwargs": {
26+
"approximate": "tanh"
27+
}
28+
}
29+
},
30+
"preprocess_cfg": {
31+
"mean": [
32+
0.485,
33+
0.456,
34+
0.406
35+
],
36+
"std": [
37+
0.229,
38+
0.224,
39+
0.225
40+
],
41+
"interpolation": "bilinear",
42+
"resize_mode": "squash"
43+
}
44+
}

src/open_clip/tokenizer.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,89 @@ def get_reduction_mask_fn(type: str):
400400
return syntax_mask_tokenize # randomly drop prioritized by syntax
401401

402402

403+
from tokenizers import BertWordPieceTokenizer
404+
405+
class CustomTokenizer:
    """WordPiece tokenizer that frames each text with BOS/EOS and a trailing class token.

    Every row produced by ``batch_encode_plus`` has the layout
    ``[bos] + wordpiece_ids + [eos] + [pad] * k + [class]`` and is exactly
    ``max_length`` ids long.

    Args:
        vocab_file: Path handed to ``BertWordPieceTokenizer.from_file``.
        context_length: Default row length, used when no ``max_length`` is given.
        bos_token: Id placed before the WordPiece ids.
        eos_token: Id placed after the WordPiece ids.
        class_token: Id written into the last position of every row.
        pad_token: Id used to fill the gap between EOS and the class token.
    """

    def __init__(self, vocab_file, context_length=512, bos_token=1, eos_token=2, class_token=101, pad_token=0):
        # from_file builds the tokenizer itself, so no throwaway instance is
        # created first; lowercase=True is kept explicit as in the original.
        self.tokenizer = BertWordPieceTokenizer.from_file(vocab_file, lowercase=True)
        self.context_length = context_length
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.class_token = class_token
        self.pad_token = pad_token

    def tokenize(self, text, max_length=None):
        """Return ``[bos] + ids + [eos]`` for ``text``, truncated to fit a row.

        Args:
            text: The string to encode.
            max_length: Row length to truncate for; defaults to
                ``self.context_length``.

        Returns:
            A list of at most ``max_length - 1`` ids, leaving one slot for the
            class token.
        """
        row_length = max_length or self.context_length
        # Three slots are reserved: BOS, EOS and the trailing class token.
        # Clamp at zero so a tiny row length never turns into a negative slice.
        budget = max(row_length - 3, 0)
        encoding = self.tokenizer.encode(text, add_special_tokens=False)
        return [self.bos_token, *encoding.ids[:budget], self.eos_token]

    def batch_encode_plus(self, texts, max_length=None):
        """Encode ``texts`` into a dict holding a tensor under ``'input_ids'``.

        Args:
            texts: Iterable of strings.
            max_length: Length of every row; defaults to ``self.context_length``.

        Returns:
            ``{'input_ids': tensor}`` with one row of ``max_length`` ids per text.
        """
        import torch

        max_length = max_length or self.context_length
        # Truncate against the effective max_length (not self.context_length),
        # otherwise long texts yield rows of differing length and the tensor
        # construction below fails.
        rows = [
            self.pad_and_add_class_token(self.tokenize(text, max_length), max_length)
            for text in texts
        ]
        return {'input_ids': torch.tensor(rows, dtype=torch.long)}

    def pad_and_add_class_token(self, encoded_text, max_length):
        """Return a new list of ``max_length`` ids ending in the class token.

        The input is truncated or padded with ``self.pad_token`` to
        ``max_length - 1`` ids before the class token is appended. The
        caller's list is left untouched.
        """
        body_length = max(max_length - 1, 0)
        body = list(encoded_text[:body_length])
        body.extend([self.pad_token] * (body_length - len(body)))
        body.append(self.class_token)
        return body
434+
435+
class CLIPS_Tokenizer:
    """Tokenizer for CLIPS models, wrapping the WordPiece-based ``CustomTokenizer``.

    Args:
        context_length: Number of token ids per encoded text; also handed to
            the wrapped tokenizer as its truncation length.
        clean: Name of the text-cleaning function, resolved with ``get_clean_fn``.
        strip_sep_token: Stored on the instance; not otherwise used by this class.
        language: Optional source language, applied only when the wrapped
            tokenizer exposes ``set_src_lang_special_tokens``.
        vocab_file: Path to the WordPiece vocabulary file.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
            clean: str = 'whitespace',
            strip_sep_token: bool = False,
            language: Optional[str] = None,
            vocab_file: str = './vocab.txt',
            **kwargs
    ):
        import logging

        # Honour the requested context length; 80 (the previously hard-coded
        # value) is only the fallback when none is given.
        self.tokenizer = CustomTokenizer(
            vocab_file,
            context_length=context_length or 80,
            bos_token=1,
            eos_token=2,
            class_token=101,
            pad_token=0,
        )
        logging.getLogger(__name__).info("Load CLIPS Tokenizer.")

        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

    def save_pretrained(self, dest):
        """Save the tokenizer into directory ``dest``.

        Delegates to the wrapped tokenizer's ``save_pretrained`` when it has
        one; otherwise writes the vocabulary through the inner WordPiece
        tokenizer's ``save_model``.

        Raises:
            AttributeError: If neither saving method is available.
        """
        save_fn = getattr(self.tokenizer, 'save_pretrained', None)
        if callable(save_fn):
            save_fn(dest)
            return

        inner = getattr(self.tokenizer, 'tokenizer', None)
        save_model = getattr(inner, 'save_model', None)
        if not callable(save_model):
            raise AttributeError(
                'The wrapped tokenizer provides neither save_pretrained nor save_model.'
            )

        import os

        os.makedirs(dest, exist_ok=True)
        save_model(str(dest))

    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
        """Clean and encode ``texts``, returning the ``input_ids`` tensor.

        Args:
            texts: A single string or a list of strings.
            context_length: Per-call override of the row length.

        Returns:
            The ``'input_ids'`` entry produced by the wrapped tokenizer's
            ``batch_encode_plus``.
        """
        if isinstance(texts, str):
            texts = [texts]
        context_length = context_length or self.context_length
        assert context_length, 'Please set a valid context length in class init or call.'

        # The wrapped tokenizer truncates against its own context_length, so
        # keep it in step with the length requested for this call; otherwise
        # long texts can come back longer than max_length.
        self.tokenizer.context_length = context_length

        cleaned = [self.clean_fn(text) for text in texts]
        encoded_outputs = self.tokenizer.batch_encode_plus(cleaned, max_length=context_length)
        return encoded_outputs['input_ids']

    def set_language(self, src_lang):
        """Forward ``src_lang`` to the wrapped tokenizer, or warn if unsupported."""
        if hasattr(self, 'set_lang_fn'):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn('Cannot set language for the tokenizer.')
484+
485+
403486
class HFTokenizer:
404487
"""HuggingFace tokenizer wrapper"""
405488

src/open_clip/transformer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ def __init__(
455455
act_layer: Callable = nn.GELU,
456456
norm_layer: Callable = LayerNorm,
457457
output_tokens: bool = False,
458+
eps: float = 1e-5 # Add eps as a parameter
458459
):
459460
super().__init__()
460461
assert pool_type in ('tok', 'avg', 'none')
@@ -487,15 +488,15 @@ def __init__(
487488
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
488489
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
489490

490-
self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
491+
self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width, eps=eps)
491492
self.transformer = Transformer(
492493
width,
493494
layers,
494495
heads,
495496
mlp_ratio,
496497
ls_init_value=ls_init_value,
497498
act_layer=act_layer,
498-
norm_layer=norm_layer,
499+
norm_layer=lambda x: norm_layer(x, eps=eps),
499500
)
500501

501502
if attentional_pool:
@@ -533,7 +534,7 @@ def __init__(
533534
pool_dim = width
534535
self.pool_type = pool_type
535536

536-
self.ln_post = norm_layer(pool_dim)
537+
self.ln_post = norm_layer(pool_dim, eps=eps)
537538
self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
538539

539540
self.init_parameters()
@@ -693,6 +694,7 @@ def __init__(
693694
act_layer: Callable = nn.GELU,
694695
norm_layer: Callable = LayerNorm,
695696
output_tokens: bool = False,
697+
eps: float = 1e-5 # Add eps as a parameter
696698
):
697699
super().__init__()
698700
assert pool_type in ('first', 'last', 'argmax', 'none')
@@ -719,9 +721,9 @@ def __init__(
719721
mlp_ratio=mlp_ratio,
720722
ls_init_value=ls_init_value,
721723
act_layer=act_layer,
722-
norm_layer=norm_layer,
724+
norm_layer=lambda x: norm_layer(x, eps=eps),
723725
)
724-
self.ln_final = norm_layer(width)
726+
self.ln_final = norm_layer(width, eps=eps)
725727

726728
if no_causal_mask:
727729
self.attn_mask = None

0 commit comments

Comments
 (0)