Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add verbose weights #8

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 130 additions & 24 deletions lora-inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import os
from collections import OrderedDict
import re
from datetime import datetime
from pathlib import Path
from typing import Callable, OrderedDict
Expand Down Expand Up @@ -210,12 +211,109 @@ def find_vectors_weights(vectors):
print(f"Text Encoder weight average magnitude: {avg_mag}")
print(f"Text Encoder weight average strength: {avg_str}")

_keys = [
unet_attn_weight_results.keys(),
unet_conv_weight_results.keys(),
text_encoder_weight_results.keys(),
]

return {
"unet": unet_attn_weight_results,
"text_encoder": text_encoder_weight_results,
}


def find_vectors_weight_blocks(vectors):
weight = ".weight"

results = {}

print(f"model key count: {len(vectors.keys())}")

for k in vectors.keys():
if k.endswith(".weight") is False:
continue

x = find_group(k)

if x is not None:
if x not in results.keys():
results[x] = []
results[x].append(torch.flatten(vectors.get_tensor(k)).tolist())

for key in results.keys():
sum_mag = 0 # average magnitude
sum_str = 0 # average strength
for vectors in results.get(key):
sum_mag += get_vector_data_magnitude(vectors)
sum_str += get_vector_data_strength(vectors)

avg_mag = sum_mag / len(results[key])
avg_str = sum_str / len(results[key])

print(f"{key} weight average magnitude: {avg_mag}")
print(f"{key} weight average strength: {avg_str}")

return results


def find_group(key):
"""
Find the group we want to put these keys into
TODO: describe this better
"""
r = r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"

matches = re.search(r, key)

if matches is not None and matches.group(3) == "attentions":
r2 = r"(transformer_blocks)_(\d+)_(attn\d+)_(to_[^\.]+).(lora_up|lora_down)"

matches2 = re.search(r2, key)

if matches2 is not None:
# print(
# matches2.group(1),
# matches2.group(2),
# matches2.group(3),
# matches2.group(4),
# matches2.group(5),
# )

block_name = matches.group(1) + "_block" # up|down
block_id = matches.group(2) # \d
transformer_block = matches2.group(2)
attn = matches.group(3)

attn_to = matches2.group(4)

return f"unet_{block_name}_{block_id}_transformer_block_{transformer_block}_{attn}_{attn_to}"

# atten
# atten2
# ff

else:
# lora_unet_up_blocks_3_attentions_0_proj_in.lora_down.weight
# lora_unet_up_blocks_3_attentions_0_proj_in.lora_up.weight
# lora_unet_up_blocks_3_attentions_0_proj_out.lora_down.weight
# lora_unet_up_blocks_3_attentions_0_proj_out.lora_up.weight
# proj
r3 = r"(proj_(in|out)).(lora_up|lora_down)"
# key2 = "transformer_blocks_0_attn1_to_k.lora_down.weight"

matches3 = re.search(r3, key)

if matches3 is not None:
block_name = matches.group(1) + "_block" # up|down
block_id = matches.group(2) # \d
in_out = matches3.group(1)

return f"{block_name}_{block_id}_{in_out}"

return None


def get_vector_data_strength(data: dict[int, Tensor]) -> float:
value = 0
for n in data:
Expand Down Expand Up @@ -253,21 +351,25 @@ def process_safetensor_file(file, args):
filename = os.path.basename(file)
print(file)

parsed = {}
meta = {}

if metadata is not None:
parsed = parse_metadata(metadata)
meta = parse_metadata(metadata)
else:
parsed = {}
meta = {}

parsed["file"] = file
parsed["filename"] = filename
meta["file"] = file
meta["filename"] = filename

if args.weights:
find_vectors_weights(f)
weights = find_vectors_weights(f)
elif args.verbose_weights:
weights = find_vectors_weight_blocks(f)
else:
weights = None

print("----------------------")
return parsed
return (weights, meta)


def print_list(list):
Expand Down Expand Up @@ -380,7 +482,7 @@ def process(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description="LoRA Inspector")

parser.add_argument(
"lora_file_or_dir", type=str, help="Directory containing the lora files"
Expand All @@ -407,12 +509,19 @@ def process(args):
help="Show the most common tags in the training set",
)

parser.add_argument(
"-v",
"--verbose_weights",
action="store_true",
help="Experimental. Average magnitude and strength, separated, for all blocks and attention",
)

args = parser.parse_args()
results = process(args)
(weights, meta) = process(args)

if args.save_meta:
if type(results) == list:
for result in results:
if type(meta) == list:
for result in meta:
# print("result", json.dumps(result, indent=4, sort_keys=True, default=str))
if "ss_session_id" in result:
newfile = (
Expand All @@ -426,24 +535,19 @@ def process(args):
save_metadata(newfile, result)
print(f"Metadata saved to {newfile}.json")
else:
if "ss_session_id" in results:
newfile = (
"meta/"
+ str(results["filename"])
+ "-"
+ results["ss_session_id"]
)
if "ss_session_id" in meta:
newfile = "meta/" + str(meta["filename"]) + "-" + meta["ss_session_id"]
else:
newfile = "meta/" + str(results["filename"])
save_metadata(newfile, results)
newfile = "meta/" + str(meta["filename"])
save_metadata(newfile, meta)
print(f"Metadata saved to {newfile}.json")

if args.tags:
print("-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=")
print("Tags")
print("-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=")
if type(results) == list:
for result in results:
if type(meta) == list:
for result in meta:
if "ss_tag_frequency" in result:
freq = result["ss_tag_frequency"]
tags = []
Expand All @@ -456,8 +560,8 @@ def process(args):
else:
print("No tag frequency found")
else:
if "ss_tag_frequency" in results:
freq = results["ss_tag_frequency"]
if "ss_tag_frequency" in meta:
freq = meta["ss_tag_frequency"]
tags = []
longest_tag = 0
for k in freq.keys():
Expand All @@ -477,3 +581,5 @@ def process(args):
else:
print("No tag frequency found")
# print(results)
# if weights is not None:
# print(weights.keys())