Skip to content

Commit f5c3e86

Browse files
feat(multimodal): add dalle
1 parent efc97eb commit f5c3e86

29 files changed

+1901
-21
lines changed

hubconf.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,18 @@
99
vit_b_16,
1010
vit_b_32,
1111
vit_l_14,
12-
vit_l_14_336px
12+
vit_l_14_336px,
13+
)
14+
from official.multimodal.dalle import (
15+
Generator,
16+
OpenAIDiscreteVAE,
17+
OpenAIDiscreteVAEDecoder,
18+
OpenAIDiscreteVAEEncoder,
19+
VQGanVAE,
20+
coco_512_16_16d_16h_80tsl,
21+
openai_discrete_VAE_decoder,
22+
openai_discrete_VAE_encoder,
23+
vqgan_vae_1024,
1324
)
1425
from official.multimodal.taming_transformer import (
1526
ConditionalSampler,
@@ -20,7 +31,7 @@
2031
s_flckr_transformer,
2132
vqgan_gumbel_f8,
2233
vqgan_imagenet_f16_1024,
23-
vqgan_imagenet_f16_16384
34+
vqgan_imagenet_f16_16384,
2435
)
2536
from official.nlp.bert.model import (
2637
cased_L_12_H_768_A_12,
@@ -30,7 +41,7 @@
3041
uncased_L_12_H_768_A_12,
3142
uncased_L_24_H_1024_A_16,
3243
wwm_cased_L_24_H_1024_A_16,
33-
wwm_uncased_L_24_H_1024_A_16
44+
wwm_uncased_L_24_H_1024_A_16,
3445
)
3546
from official.quantization.models import quantized_resnet18
3647
from official.vision.classification.resnet.model import (
@@ -43,13 +54,13 @@
4354
resnet101,
4455
resnet152,
4556
resnext50_32x4d,
46-
resnext101_32x8d
57+
resnext101_32x8d,
4758
)
4859
from official.vision.classification.shufflenet.model import (
4960
shufflenet_v2_x0_5,
5061
shufflenet_v2_x1_0,
5162
shufflenet_v2_x1_5,
52-
shufflenet_v2_x2_0
63+
shufflenet_v2_x2_0,
5364
)
5465
from official.vision.detection.configs import (
5566
atss_res18_coco_3x_800size,
@@ -76,30 +87,18 @@
7687
retinanet_res34_coco_3x_800size,
7788
retinanet_res50_coco_3x_800size,
7889
retinanet_res101_coco_3x_800size,
79-
retinanet_resx101_coco_2x_800size
90+
retinanet_resx101_coco_2x_800size,
8091
)
8192
from official.vision.detection.models import ATSS, FCOS, FasterRCNN, FreeAnchor, RetinaNet
8293
from official.vision.detection.tools.utils import DetEvaluator
8394
from official.vision.keypoints.inference import KeypointEvaluator
8495
from official.vision.keypoints.models import (
8596
simplebaseline_res50,
8697
simplebaseline_res101,
87-
simplebaseline_res152
98+
simplebaseline_res152,
8899
)
89100
from official.vision.segmentation.configs import (
90101
deeplabv3plus_res101_cityscapes_768size,
91-
deeplabv3plus_res101_voc_512size
102+
deeplabv3plus_res101_voc_512size,
92103
)
93104
from official.vision.segmentation.models import DeepLabV3Plus
94-
from official.multimodal.clip.models import (
95-
rn50,
96-
rn101,
97-
rn50x4,
98-
rn50x16,
99-
rn50x64,
100-
vit_b_32,
101-
vit_b_16,
102-
vit_l_14,
103-
vit_l_14_336px,
104-
)
105-
from official.multimodal.clip.inference_utils import ClipInferenceUtils
873 KB
Loading
10.4 KB
Loading

official/assets/test_000009.png

124 KB
Loading

official/assets/test_000010.png

113 KB
Loading

official/assets/test_depth.png

19.5 KB
Loading

official/assets/test_sample_255.png

135 KB
Loading

official/assets/total.png

1.94 MB
Loading

official/multimodal/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .dalle.dalle import DALLE

official/multimodal/clip/simple_tokenizer.py

+23
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,29 @@ def decode(self, tokens):
141141
'utf-8', errors="replace").replace('</w>', ' ')
142142
return text
143143

144+
def tokenize(
145+
self,
146+
texts: Union[str, List[str]],
147+
context_length: int = 77,
148+
truncate_text: bool = False
149+
):
150+
if isinstance(texts, str):
151+
texts = [texts]
152+
153+
all_tokens = [self.encode(text) for text in texts]
154+
result = np.zeros((len(all_tokens), context_length), dtype=np.int32)
155+
156+
for i, tokens in enumerate(all_tokens):
157+
if len(tokens) > context_length:
158+
if truncate_text:
159+
tokens = tokens[:context_length]
160+
else:
161+
raise RuntimeError(
162+
f"Input {texts[i]} is too long for context length {context_length}")
163+
result[i, :len(tokens)] = tokens
164+
165+
return mge.tensor(result)
166+
144167

145168
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False):
146169
"""

official/multimodal/dalle/README.md

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# DALLE
2+
3+
此目录包含基于 MegEngine 实现的多模态模型 DALLE 以及文生图代码,但不包含训练代码。
4+
5+
## 图像重建
6+
7+
对于给定的大小为 256x256 的归一化四维输入(即下方示例中的 `input`),可以使用如下方式进行重建:
8+
9+
```python
10+
from official.multimodal.dalle.vae import OpenAIDiscreteVAE
11+
from official.multimodal.big_sleep.big_sleep import save_images
12+
13+
14+
vae = OpenAIDiscreteVAE(True)
15+
16+
img_seq = vae.get_codebook_indices(input)
17+
18+
reconstructed_image = vae.decode(img_seq)
19+
20+
save_images(reconstructed_image, './image.png')
21+
22+
```
23+
24+
25+
26+
## 文生图
27+
28+
可以使用以下代码体验文生图的功能。运行前需要先下载 [dalle_new_variety.bpe](https://data.megengine.org.cn/research/multimodality/dalle_new_variety.bpe) 文件,并通过 `bpe_path` 参数指定其路径。
29+
30+
```python
31+
from official.multimodal.dalle import coco_512_16_16d_16h_80tsl
32+
from official.multimodal.dalle import Generator
33+
34+
dalle = coco_512_16_16d_16h_80tsl()
35+
36+
generator = Generator(
37+
dalle,
38+
texts = ['A tower has a clock on it on a day with a blue sky'],
39+
num_images=64,
40+
batch_size=4,
41+
bpe_path = './dalle_new_variety.bpe',
42+
root='./dalle'
43+
)
44+
45+
generator()
46+
```
47+
48+
生成结果如下所示:
49+
50+
![res](../../assets/total.png)
51+
52+
53+
## 参考
54+
55+
[DALLE-pytorch](https://github.com/lucidrains/DALLE-pytorch)
56+
57+
[DALLE-pytorch-discussions](https://github.com/lucidrains/DALLE-pytorch/discussions/335)

official/multimodal/dalle/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .dalle import DALLE
2+
from .generate import Generator
3+
from .pretrained import coco_512_16_16d_16h_80tsl
4+
from .vae import (
5+
OpenAIDiscreteVAE,
6+
OpenAIDiscreteVAEDecoder,
7+
OpenAIDiscreteVAEEncoder,
8+
VQGanVAE,
9+
openai_discrete_VAE_decoder,
10+
openai_discrete_VAE_encoder
11+
)
12+
from .vae.vqgan_vae import vqgan_vae_1024

0 commit comments

Comments
 (0)