The generate.py script won't run because the weights on Hugging Face are incompatible with the model architecture in the repository.
Here's a greatly simplified part of the file generate.py:
from utils import *
from config import *
from transformers import GPT2Config
import requests
import torch
from tqdm import tqdm

# Download the pretrained weights from Hugging Face
filename = "weights.pth"
url = "https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth"
response = requests.get(url, stream=True)
total_size = int(response.headers.get("content-length", 0))
chunk_size = 10

with open(filename, "wb") as file, tqdm(
    desc=filename,
    total=total_size,
    unit="B",
    unit_scale=True,
    unit_divisor=1024,
) as bar:
    for data in response.iter_content(chunk_size=chunk_size):
        size = file.write(data)
        bar.update(size)

# Build the model as in generate.py and load the checkpoint
patchilizer = Patchilizer()
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    vocab_size=1,
)
char_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE,
    max_position_embeddings=PATCH_SIZE,
    vocab_size=128,
)
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)

checkpoint = torch.load("weights.pth")
model.load_state_dict(checkpoint["model"])
The result of running this is:
weights.pth: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 0.99G/0.99G [06:52<00:00, 2.57MB/s]
Traceback (most recent call last):
File "/home/jet08013/GitHub/tunesformer/jeremy.py", line 41, in <module>
model.load_state_dict(checkpoint["model"])
File "/home/jet08013/anaconda3/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TunesFormer:
Unexpected key(s) in state_dict: "patch_level_decoder.base.h.0.attn.bias",
"patch_level_decoder.base.h.0.attn.masked_bias",
"patch_level_decoder.base.h.1.attn.bias", "patch_level_decoder.base.h.1.attn.masked_bias",
"patch_level_decoder.base.h.2.attn.bias", "patch_level_decoder.base.h.2.attn.masked_bias",
"patch_level_decoder.base.h.3.attn.bias", "patch_level_decoder.base.h.3.attn.masked_bias",
"patch_level_decoder.base.h.4.attn.bias", "patch_level_decoder.base.h.4.attn.masked_bias",
"patch_level_decoder.base.h.5.attn.bias", "patch_level_decoder.base.h.5.attn.masked_bias",
"patch_level_decoder.base.h.6.attn.bias", "patch_level_decoder.base.h.6.attn.masked_bias",
"patch_level_decoder.base.h.7.attn.bias", "patch_level_decoder.base.h.7.attn.masked_bias",
"patch_level_decoder.base.h.8.attn.bias", "patch_level_decoder.base.h.8.attn.masked_bias",
"char_level_decoder.base.transformer.h.0.attn.bias", "char_level_decoder.base.transformer.h.0.attn.masked_bias",
"char_level_decoder.base.transformer.h.1.attn.bias", "char_level_decoder.base.transformer.h.1.attn.masked_bias",
"char_level_decoder.base.transformer.h.2.attn.bias", "char_level_decoder.base.transformer.h.2.attn.masked_bias".
It looks like the saved weights include attention bias entries (attn.bias and attn.masked_bias) for the attention layers that aren't present in the model as currently defined in the repository.
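For what it's worth, these look like the causal-mask buffers that older versions of transformers registered persistently in GPT2Attention; newer versions no longer keep them in the state_dict, which would explain why a checkpoint saved with an older transformers version has extra keys now. Below is a minimal sketch of a workaround I tried, assuming the only mismatched keys really are these non-learnable attn.bias / attn.masked_bias buffers (model is the TunesFormer instance from the snippet above):

import torch

# Load the downloaded checkpoint onto the CPU
checkpoint = torch.load("weights.pth", map_location="cpu")
state_dict = checkpoint["model"]

# Drop the attention-mask buffers that the current model definition no longer expects.
# Assumption: every unexpected key ends in ".attn.bias" or ".attn.masked_bias";
# learnable parameters such as attn.c_attn.bias are unaffected by this filter.
filtered = {
    k: v
    for k, v in state_dict.items()
    if not k.endswith(".attn.bias") and not k.endswith(".attn.masked_bias")
}

# strict=False plus the printout below makes any remaining mismatch visible
result = model.load_state_dict(filtered, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)

If both lists print empty, the remaining weights line up with the architecture and the checkpoint is usable as-is; otherwise there is a deeper incompatibility than just the mask buffers.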