export model to fp16 #347

Open · wants to merge 3 commits into base: master
export.py (77 changes: 41 additions, 36 deletions)
@@ -31,16 +31,20 @@
# -----------------------------------------------------------------------------
# common utilities

def serialize_fp32(file, tensor):
""" writes one fp32 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d)
file.write(b)

def serialize_int8(file, tensor):
""" writes one int8 tensor to file that is open in wb mode """
d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
b = struct.pack(f'{len(d)}b', *d)
def serialize(file, tensor, type):
""" writes one tensor to file that is open in wb mode """
if type == 'fp32':
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d)
elif type == 'fp16':
d = tensor.detach().cpu().view(-1).to(torch.float16).numpy()
b = struct.pack(f'{len(d)}e', *d)
elif type == 'bfloat16':
# numpy has no native bfloat16, so reinterpret the raw 16-bit pattern as int16 and write it verbatim
d = tensor.detach().cpu().view(-1).to(torch.bfloat16).view(torch.int16).numpy()
b = struct.pack(f'{len(d)}h', *d)
elif type == 'int8':
d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
b = struct.pack(f'{len(d)}b', *d)
file.write(b)

Owner (inline): feels simplifiable
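As a quick aid for review, here is a minimal round-trip sketch of the new serialize for three of the types (the struct codes 'f', 'e' and 'b' are 4-, 2- and 1-byte elements); the temp-file path is arbitrary and the import assumes this branch's export.py is on the path:

import struct
import torch
from export import serialize  # assumes this PR's export.py is importable

t = torch.randn(4)
for kind, fmt, nbytes in [('fp32', 'f', 4), ('fp16', 'e', 2), ('int8', 'b', 1)]:
    x = (t * 10).to(torch.int8) if kind == 'int8' else t
    with open(f'/tmp/serialize_{kind}.bin', 'wb') as f:
        serialize(f, x, kind)
    raw = open(f'/tmp/serialize_{kind}.bin', 'rb').read()
    assert len(raw) == 4 * nbytes  # element count times bytes per element
    print(kind, struct.unpack(f'4{fmt}', raw))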

def quantize_q80(w, group_size):
@@ -72,7 +76,7 @@ def quantize_q80(w, group_size):
# -----------------------------------------------------------------------------
# legacy

def legacy_export(model, filepath):
def legacy_export(model, filepath, type):
""" Original export of llama2.c bin files, i.e. version v0 """
out_file = open(filepath, 'wb')

@@ -89,38 +93,38 @@ def legacy_export(model, filepath):
out_file.write(header)

# next write out the embedding weights
serialize_fp32(out_file, model.tok_embeddings.weight)
serialize(out_file, model.tok_embeddings.weight, type)

# now all the layers
# attention weights
for layer in model.layers:
serialize_fp32(out_file, layer.attention_norm.weight)
serialize(out_file, layer.attention_norm.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wq.weight)
serialize(out_file, layer.attention.wq.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wk.weight)
serialize(out_file, layer.attention.wk.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wv.weight)
serialize(out_file, layer.attention.wv.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.attention.wo.weight)
serialize(out_file, layer.attention.wo.weight, type)
# ffn weights
for layer in model.layers:
serialize_fp32(out_file, layer.ffn_norm.weight)
serialize(out_file, layer.ffn_norm.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w1.weight)
serialize(out_file, layer.feed_forward.w1.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w2.weight)
serialize(out_file, layer.feed_forward.w2.weight, type)
for layer in model.layers:
serialize_fp32(out_file, layer.feed_forward.w3.weight)
serialize(out_file, layer.feed_forward.w3.weight, type)
# final rmsnorm
serialize_fp32(out_file, model.norm.weight)
serialize(out_file, model.norm.weight, type)
# freqs_cis
serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
serialize(out_file, model.freqs_cos[:p.max_seq_len], type)
serialize(out_file, model.freqs_sin[:p.max_seq_len], type)

# final classifier weights
if not shared_classifier:
serialize_fp32(out_file, model.output.weight)
serialize(out_file, model.output.weight, type)

# write to binary file
out_file.close()
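The practical payoff of threading type through legacy_export is file size: every tensor above is now written at the chosen width. A rough back-of-the-envelope estimate, offered only as a sketch, assuming the v0 header is the usual 7 int32 fields and that the model exposes freqs_cos/freqs_sin buffers and a params config as in llama2.c's model.py (the helper name and attribute access are assumptions, not part of this diff):

def estimated_v0_bytes(model, p, type):
    # 7 int32 header fields, then every serialized tensor at the chosen width;
    # tied embedding/classifier weights are counted once by parameters()
    bytes_per = {'fp32': 4, 'fp16': 2, 'bfloat16': 2}[type]
    n = sum(t.numel() for t in model.parameters())
    n += model.freqs_cos[:p.max_seq_len].numel() + model.freqs_sin[:p.max_seq_len].numel()
    return 7 * 4 + n * bytes_per

# estimated_v0_bytes(model, model.params, 'fp16') comes out to roughly half of the
# 'fp32' figure, which is the point of this PR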
Expand All @@ -129,7 +133,7 @@ def legacy_export(model, filepath):
# -----------------------------------------------------------------------------
# new version

def version1_export(model, filepath):
def version1_export(model, filepath, type):
Owner (inline): we'd have to serialize the type to the header too
"""
Export the model weights in a full float32 .bin file to be read from C.
This is the same as legacy_export, but with a proper header.
@@ -173,7 +177,7 @@ def version1_export(model, filepath):
if not shared_classifier:
weights.append(model.output.weight)
for w in weights:
serialize_fp32(out_file, w)
serialize(out_file, w, type)

# write to binary file
out_file.close()
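On the Owner's note above ("we'd have to serialize the type to the header too"): one way to do that, sketched purely as an illustration and not part of this PR, is a single extra int32 dtype code written alongside the existing version-1 header fields, which the C reader could branch on:

import struct

DTYPE_CODE = {'fp32': 0, 'fp16': 1, 'bfloat16': 2}  # hypothetical mapping

def write_dtype_field(out_file, type):
    # one extra int32 after the existing v1 header fields, before any padding
    out_file.write(struct.pack('i', DTYPE_CODE[type]))

def read_dtype_field(in_file):
    (code,) = struct.unpack('i', in_file.read(4))
    return {v: k for k, v in DTYPE_CODE.items()}[code]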
@@ -233,10 +237,10 @@ def version2_export(model, filepath, group_size=64):

# first let's write out all the params that we are keeping in fp32: the norms
for layer in model.layers: # attention norms
serialize_fp32(out_file, layer.attention_norm.weight)
serialize(out_file, layer.attention_norm.weight, 'fp32')
for layer in model.layers: # MLP norms
serialize_fp32(out_file, layer.ffn_norm.weight)
serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
serialize(out_file, layer.ffn_norm.weight, 'fp32')
serialize(out_file, model.norm.weight, 'fp32') # final pre-classifier norm

# now let's write out all the params that we are quantizing to Q8_0
# note we skip classifier weights, which are shared with the embedding
@@ -246,7 +250,7 @@ def version2_export(model, filepath, group_size=64):
# quantize this weight
q, s, err = quantize_q80(w, group_size)
# save the int8 weights to file
serialize_int8(out_file, q) # save the tensor in int8
serialize(out_file, q, 'int8') # save the tensor in int8
scales.append(s) # we'll do all the scales after all the qs
# logging
ew.append((err, w.shape))
@@ -255,7 +259,7 @@ def version2_export(model, filepath, group_size=64):
# save the scaling factors in fp32 here
# this is done to keep all the weights contiguous, making pointer arithmetic easier in C
for s in scales:
serialize_fp32(out_file, s)
serialize(out_file, s, 'fp32')

# print the highest error across all weights, should be very small, e.g. O(~0.001)
ew.sort(reverse=True)
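Since the Q8_0 path above is unchanged by this PR (norms and scales stay fp32, quantized weights stay int8), a tiny illustrative call shows what the loop consumes and produces; the matrix size is arbitrary:

import torch
from export import quantize_q80  # assumes export.py is importable

w = torch.randn(288, 288)
q, s, err = quantize_q80(w, group_size=64)  # int8 values, per-group scales, max error
print(q.shape, s.shape, err)                # err should be small, O(~0.001)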
@@ -404,11 +408,11 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
def model_export(model, filepath, version, type):
if version == 0:
legacy_export(model, filepath)
legacy_export(model, filepath, type)
elif version == 1:
version1_export(model, filepath)
version1_export(model, filepath, type)
elif version == 2:
version2_export(model, filepath)
else:
@@ -450,6 +454,7 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
parser.add_argument("--type", default='fp32', type=str, help="the data type to export to (fp32, fp16, bfloat16)")
Owner: I'm not 100% decided if type should be a separate variable that is written into the header, or if it should just be absorbed into version, e.g.:

version 0: original float32
version 1: original float16
version 2: new header, int8

etc., and just go that way.

Contributor: There's another PR that just uses --version for this.

Contributor (Author): Don't forget the support for bf16 (and maybe others to come). If we used a version number for each, it would be both unintuitive and end up with a lot of "versions".
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
@@ -467,4 +472,4 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser.error("Can't load input model!")

# export
model_export(model, args.filepath, args.version)
model_export(model, args.filepath, args.version, args.type)
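Finally, a hedged usage sketch of the new flag plus a quick read-back check of the result; the checkpoint path and output name are placeholders, the header is assumed to be the 7-int32 v0 layout, and a matching change on the C side is still needed before run.c can consume anything other than fp32:

# python export.py model_fp16.bin --version 0 --type fp16 --checkpoint out/ckpt.pt
import numpy as np

with open('model_fp16.bin', 'rb') as f:
    header = np.fromfile(f, dtype=np.int32, count=7)  # the 7 int32 config fields
    weights = np.fromfile(f, dtype=np.float16)        # everything after the header
print(header)
print(weights.dtype, weights[:5])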