export model to fp16 #347
base: master
Changes from all commits
@@ -31,16 +31,20 @@
 # -----------------------------------------------------------------------------
 # common utilities

-def serialize_fp32(file, tensor):
-    """ writes one fp32 tensor to file that is open in wb mode """
-    d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
-    b = struct.pack(f'{len(d)}f', *d)
-    file.write(b)
-
-def serialize_int8(file, tensor):
-    """ writes one int8 tensor to file that is open in wb mode """
-    d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
-    b = struct.pack(f'{len(d)}b', *d)
+def serialize(file, tensor, type):
+    """ writes one tensor to file that is open in wb mode """
+    if type == 'fp32':
+        d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
+        b = struct.pack(f'{len(d)}f', *d)
+    elif type == 'fp16':
+        d = tensor.detach().cpu().view(-1).to(torch.float16).numpy()
+        b = struct.pack(f'{len(d)}e', *d)
+    elif type == 'bfloat16':
+        d = tensor.detach().cpu().view(-1).to(torch.bfloat16).numpy()
+        b = struct.pack(f'{len(d)}e', *d)
+    elif type == 'int8':
+        d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
+        b = struct.pack(f'{len(d)}b', *d)
     file.write(b)

 def quantize_q80(w, group_size):
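A note on the bfloat16 branch above (not part of the diff): NumPy has no bfloat16 dtype, so tensor.to(torch.bfloat16).numpy() raises, and struct's 'e' format is IEEE binary16 (fp16), which is a different bit layout from bfloat16. A minimal sketch of one way that branch could be made to work, by reinterpreting the bf16 bits as int16 and writing them out raw; the helper name serialize_bf16_bits is hypothetical, not from the PR:

import struct
import torch

def serialize_bf16_bits(file, tensor):
    """ writes one bfloat16 tensor (as raw 16-bit values) to a file open in wb mode """
    t = tensor.detach().cpu().view(-1).to(torch.bfloat16)
    d = t.view(torch.int16).numpy()      # reinterpret the bf16 bits, no numeric conversion
    b = struct.pack(f'{len(d)}h', *d)    # 'h' = 16-bit int, same width as bfloat16
    file.write(b)

The C reader would then have to interpret those 16 bits as bfloat16 rather than as IEEE fp16.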
@@ -72,7 +76,7 @@ def quantize_q80(w, group_size):
 # -----------------------------------------------------------------------------
 # legacy

-def legacy_export(model, filepath):
+def legacy_export(model, filepath, type):
     """ Original export of llama2.c bin files, i.e. version v0 """
     out_file = open(filepath, 'wb')

@@ -89,38 +93,38 @@ def legacy_export(model, filepath):
     out_file.write(header)

     # next write out the embedding weights
-    serialize_fp32(out_file, model.tok_embeddings.weight)
+    serialize(out_file, model.tok_embeddings.weight, type)

     # now all the layers
     # attention weights
     for layer in model.layers:
-        serialize_fp32(out_file, layer.attention_norm.weight)
+        serialize(out_file, layer.attention_norm.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.attention.wq.weight)
+        serialize(out_file, layer.attention.wq.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.attention.wk.weight)
+        serialize(out_file, layer.attention.wk.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.attention.wv.weight)
+        serialize(out_file, layer.attention.wv.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.attention.wo.weight)
+        serialize(out_file, layer.attention.wo.weight, type)
     # ffn weights
     for layer in model.layers:
-        serialize_fp32(out_file, layer.ffn_norm.weight)
+        serialize(out_file, layer.ffn_norm.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.feed_forward.w1.weight)
+        serialize(out_file, layer.feed_forward.w1.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.feed_forward.w2.weight)
+        serialize(out_file, layer.feed_forward.w2.weight, type)
     for layer in model.layers:
-        serialize_fp32(out_file, layer.feed_forward.w3.weight)
+        serialize(out_file, layer.feed_forward.w3.weight, type)
     # final rmsnorm
-    serialize_fp32(out_file, model.norm.weight)
+    serialize(out_file, model.norm.weight, type)
     # freqs_cis
-    serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
-    serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
+    serialize(out_file, model.freqs_cos[:p.max_seq_len], type)
+    serialize(out_file, model.freqs_sin[:p.max_seq_len], type)

     # final classifier weights
     if not shared_classifier:
-        serialize_fp32(out_file, model.output.weight)
+        serialize(out_file, model.output.weight, type)

     # write to binary file
     out_file.close()
@@ -129,7 +133,7 @@ def legacy_export(model, filepath):
 # -----------------------------------------------------------------------------
 # new version

-def version1_export(model, filepath):
+def version1_export(model, filepath, type):
     """
     Export the model weights in full float32 .bin file to be read from C.
     This is same as legacy_export, but with a proper header.

Review comment (on the new version1_export signature): we'd have to serialize the type to the header too.
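That comment points at a real gap: nothing in the v1 header currently records which dtype the weight blob uses. A hypothetical sketch of how the type could be written into the header, assuming a small integer code is reserved for it in the fixed-size header before the zero padding (the field position, code values, and helper name are my invention, not the PR's):

import struct

DTYPE_CODES = {'fp32': 0, 'fp16': 1, 'bfloat16': 2}

def write_dtype_code(out_file, type):
    """ writes a one-byte code identifying how the weights that follow are encoded """
    out_file.write(struct.pack('B', DTYPE_CODES[type]))

The C reader (run.c in llama2.c) would then branch on that byte when mapping the weights, so the exact layout would have to be agreed on with the reader before merging.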
@@ -173,7 +177,7 @@ def version1_export(model, filepath):
     if not shared_classifier:
         weights.append(model.output.weight)
     for w in weights:
-        serialize_fp32(out_file, w)
+        serialize(out_file, w, type)

     # write to binary file
     out_file.close()
@@ -233,10 +237,10 @@ def version2_export(model, filepath, group_size=64):

     # first let's write out all the params that we are keeping in fp32: the norms
     for layer in model.layers: # attention norms
-        serialize_fp32(out_file, layer.attention_norm.weight)
+        serialize(out_file, layer.attention_norm.weight, 'fp32')
     for layer in model.layers: # MLP norms
-        serialize_fp32(out_file, layer.ffn_norm.weight)
-    serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
+        serialize(out_file, layer.ffn_norm.weight, 'fp32')
+    serialize(out_file, model.norm.weight, 'fp32') # final pre-classifier norm

     # now let's write out all the params that we are quantizing to Q8_0
     # note we skip classifier weights, which are shared with the embedding
@@ -246,7 +250,7 @@ def version2_export(model, filepath, group_size=64):
         # quantize this weight
         q, s, err = quantize_q80(w, group_size)
         # save the int8 weights to file
-        serialize_int8(out_file, q) # save the tensor in int8
+        serialize(out_file, q, 'int8') # save the tensor in int8
         scales.append(s) # we'll do all the scales after all the qs
         # logging
         ew.append((err, w.shape))
@@ -255,7 +259,7 @@ def version2_export(model, filepath, group_size=64):
     # save the scaling factors in fp32 here
     # this is done to keep all the weights contiguous, making pointer arithmetic easier in C
     for s in scales:
-        serialize_fp32(out_file, s)
+        serialize(out_file, s, 'fp32')

     # print the highest error across all weights, should be very small, e.g. O(~0.001)
     ew.sort(reverse=True)
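For orientation (this behavior is unchanged by the PR): the Q8_0 path stores int8 values plus one fp32 scale per group, and the reader recovers an approximation of each weight as q * s. A toy illustration of that dequantization, with a hypothetical helper name:

def dequantize_group(q, s):
    """ recover approximate fp32 weights from one int8 group and its fp32 scale """
    return [float(qi) * s for qi in q]

Keeping the scales themselves in fp32 is why the serialize calls above pass 'fp32' explicitly regardless of --type.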
@@ -404,11 +408,11 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
 # -----------------------------------------------------------------------------
 # API entrypoint

-def model_export(model, filepath, version):
+def model_export(model, filepath, version, type):
     if version == 0:
-        legacy_export(model, filepath)
+        legacy_export(model, filepath, type)
     elif version == 1:
-        version1_export(model, filepath)
+        version1_export(model, filepath, type)
     elif version == 2:
         version2_export(model, filepath)
     else:
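Note that version2_export still takes no type argument, so --type is silently ignored for the Q8_0 export. A small guard along these lines (my suggestion, not part of the PR) would make that explicit:

def check_export_args(version, type):
    """ fail loudly on option combinations that would otherwise be silently ignored """
    if version == 2 and type != 'fp32':
        raise ValueError("--type has no effect with --version 2 (Q8_0 export)")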
@@ -450,6 +454,7 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
     parser = argparse.ArgumentParser()
     parser.add_argument("filepath", type=str, help="the output filepath")
     parser.add_argument("--version", default=0, type=int, help="the version to export with")
+    parser.add_argument("--type", default='fp32', type=str, help="the data type to export to (fp32, fp16, bfloat16)")
     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
     group.add_argument("--meta-llama", type=str, help="meta llama model path")

Review comment (on the new --type argument): I'm not 100% decided if type should be a separate variable that is written into the header, or if it should just be absorbed into version, e.g. version 0 = original float32, etc., and just go that way.
Review comment: There's another PR that just uses "--version" for this.
Review comment: Don't forget the support to […]. If using a version number for each, it is both unintuitive and would end up with a lot of "versions".
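For reference, the intended invocation (in what is presumably export.py; the output name and checkpoint path are placeholders) would look something like:

python export.py model_fp16.bin --version 0 --type fp16 --checkpoint path/to/ckpt.pt

Absorbing the dtype into --version instead would avoid a second header field, at the cost of one version number per (format, dtype) combination, which is the combinatorial growth the last comment objects to.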
@@ -467,4 +472,4 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
         parser.error("Can't load input model!")

     # export
-    model_export(model, args.filepath, args.version)
+    model_export(model, args.filepath, args.version, args.type)

Review comment: feels simplifiable.
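One reading of "simplifiable" is the repeated detach/view/pack pattern inside serialize. A hypothetical table-driven refactor (my sketch, not the author's code; bfloat16 is left out because of the NumPy limitation noted earlier):

import struct
import numpy as np
import torch

FORMATS = {
    'fp32': (torch.float32, 'f'),
    'fp16': (torch.float16, 'e'),
    'int8': (None, 'b'),
}

def serialize(file, tensor, type):
    """ writes one tensor to a file open in wb mode, using a lookup table for conversion and struct format """
    dtype, fmt = FORMATS[type]
    t = tensor.detach().cpu().view(-1)
    d = t.numpy().astype(np.int8) if dtype is None else t.to(dtype).numpy()
    file.write(struct.pack(f'{len(d)}{fmt}', *d))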