Description
python infer_style.py
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.03s/it]
Inputs shape: torch.Size([1, 3, 224, 224])
Inputs: {'pixel_values': tensor([[[[ 1.4336, 1.4629, 1.4629, ..., 1.6533, 1.6533, 1.6533],
[ 1.4482, 1.4629, 1.4629, ..., 1.6533, 1.6533, 1.6533],
[ 1.4629, 1.4629, 1.4775, ..., 1.6533, 1.6533, 1.6533],
...,
[ 0.4121, 0.4558, 0.4705, ..., -0.1427, -0.2449, -0.2886],
[ 0.4851, 0.4558, 0.4414, ..., -0.2156, -0.2449, -0.1864],
[ 0.2515, 0.3098, 0.1785, ..., -0.2156, -0.2595, -0.2156]],
[[ 1.6992, 1.7148, 1.7148, ..., 1.8945, 1.8945, 1.8945],
[ 1.7148, 1.7148, 1.7148, ..., 1.8945, 1.8945, 1.8945],
[ 1.7148, 1.7148, 1.7295, ..., 1.8945, 1.8945, 1.8945],
...,
[ 0.3040, 0.3340, 0.3640, ..., 0.0488, -0.0112, -0.0712],
[ 0.3340, 0.3040, 0.2891, ..., 0.0038, -0.0262, 0.0038],
[ 0.0939, 0.1840, 0.0638, ..., 0.0038, -0.0412, 0.0038]],
[[ 2.0469, 2.0605, 2.0605, ..., 2.1309, 2.1309, 2.1309],
[ 2.0605, 2.0605, 2.0605, ..., 2.1309, 2.1309, 2.1309],
[ 2.0605, 2.0605, 2.0742, ..., 2.1309, 2.1309, 2.1309],
...,
[ 0.5532, 0.5815, 0.5674, ..., 0.4253, 0.3257, 0.2688],
[ 0.5107, 0.4536, 0.4395, ..., 0.3542, 0.3398, 0.3826],
[ 0.2974, 0.3684, 0.2830, ..., 0.3542, 0.3257, 0.3542]]]],
device='cuda:0', dtype=torch.float16)}
Traceback (most recent call last):
  File "infer_style.py", line 118, in <module>
    content_image_prompt = generate_caption(content_image)  # Generate a text description of the image
  File "infer_style.py", line 49, in generate_caption
    generated_ids = model.generate(
  File "/home/ad/anaconda3/envs/cclappt38/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ad/anaconda3/envs/cclappt38/lib/python3.8/site-packages/transformers/models/blip_2/modeling_blip_2.py", line 2316, in generate
    inputs_embeds[special_image_mask] = language_model_inputs.flatten()
RuntimeError: shape mismatch: value tensor of shape [65536] cannot be broadcast to indexing result of shape [0]
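The two sizes in the error can be read as a simple count check. 65536 = 32 × 2048, which would fit BLIP-2's default of 32 query tokens projected into a language model with a 2048-dimensional hidden size. This is an assumption, because the checkpoint is not named here. The indexing result of shape [0] means the boolean mask selected nothing, which suggests the `input_ids` reaching the language model contained no image-placeholder tokens. The sketch below reproduces that bookkeeping in plain Python. It assumes that, as in recent transformers releases, the image-token mask is expanded across the hidden dimension before the assignment. `masked_slots`, `check_image_tokens`, and the token ids are hypothetical, and `IMAGE_TOKEN_ID` is a placeholder rather than the real vocabulary id.

```python
# Hypothetical values: 32 query tokens x 2048 hidden units gives the 65536
# elements reported in the traceback. IMAGE_TOKEN_ID is a placeholder id.
NUM_QUERY_TOKENS = 32
HIDDEN_SIZE = 2048
IMAGE_TOKEN_ID = 50265
BOS_TOKEN_ID = 2


def masked_slots(input_ids, image_token_id, hidden_size):
    """Elements selected when the image-token mask is expanded over the hidden dim."""
    n_image_tokens = sum(1 for tok in input_ids if tok == image_token_id)
    return n_image_tokens * hidden_size


def check_image_tokens(input_ids, image_token_id=IMAGE_TOKEN_ID,
                       num_query_tokens=NUM_QUERY_TOKENS, hidden_size=HIDDEN_SIZE):
    """Raise early if the prompt lacks the image tokens generate() expects to fill."""
    values = num_query_tokens * hidden_size  # size of language_model_inputs.flatten()
    slots = masked_slots(input_ids, image_token_id, hidden_size)
    if slots != values:
        raise ValueError(
            f"value tensor of shape [{values}] cannot be broadcast "
            f"to indexing result of shape [{slots}]"
        )
    return values


# Only a BOS token and no image tokens: mirrors the [65536] vs [0] mismatch.
try:
    check_image_tokens([BOS_TOKEN_ID])
except ValueError as err:
    print(err)

# 32 image tokens followed by BOS: the mask selects exactly as many slots as values.
print(check_image_tokens([IMAGE_TOKEN_ID] * NUM_QUERY_TOKENS + [BOS_TOKEN_ID]))
```

If the real `input_ids` passed to `generate` look like the first case, the mismatch is in how the prompt is tokenized (processor vs. model config), not in the image tensor itself, whose shape `[1, 3, 224, 224]` matches the usual BLIP-2 input size.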