Merge Models with Non-Standard Architectures (e.g., Multimodal Models) #450

Open · wants to merge 18 commits into main
mergekit/_data/architectures/gpt2.json: 3 additions & 0 deletions

@@ -23,6 +23,9 @@
     "num_layers_config_key": "n_layer",
     "layer_templates": {
         "weights": [
+            {
+                "name": "h.${layer_index}.attn.bias"
+            },
             {
                 "name": "h.${layer_index}.attn.c_attn.weight"
             },
mergekit/architecture.py: 248 additions & 2 deletions

@@ -14,15 +14,22 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

import importlib.resources
import re
import string
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from pydantic import BaseModel, Field
from transformers import PretrainedConfig
from typing_extensions import Literal

import mergekit._data.architectures
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex


class WeightInfo(BaseModel, frozen=True):
@@ -199,6 +206,112 @@ def _template_substitution(
    return TemplateWithArithmetic(template).substitute(substitutions)


def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]:
    hierarchy = defaultdict(list)

    # Regular expression to match layers (denoted by '.{integer}.' by default)
    layer_pattern = re.compile(layer_prefix)

    if names:
        for name in names:
            # Find the layer part of the string (e.g., 'model.layers.0.')
            match = layer_pattern.search(name)
            if match:
                # Extract everything up to the layer identifier; named layer_key
                # rather than layer_prefix to avoid shadowing the argument
                layer_key = name[: match.end() - 1]  # e.g., 'model.layers.0'
                # Extract the parameter name after the layer identifier
                param_name = name[match.end() :]  # e.g., 'input_layernorm.weight'
                # Add the parameter name to the corresponding layer in the hierarchy
                hierarchy[layer_key].append(param_name)
            else:
                hierarchy[name].append("")

    return hierarchy
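
# Illustrative example (not part of the diff): given the parameter names
#   ["model.embed_tokens.weight",
#    "model.layers.0.input_layernorm.weight",
#    "model.layers.0.self_attn.q_proj.weight",
#    "model.layers.1.input_layernorm.weight"]
# _hierarchy returns
#   {"model.embed_tokens.weight": [""],
#    "model.layers.0": ["input_layernorm.weight", "self_attn.q_proj.weight"],
#    "model.layers.1": ["input_layernorm.weight"]}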


class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
    arch_name: str = Field(default="")
    parameter_names: List[str] = Field(default_factory=list)
    embed: List[str] = Field(default_factory=list)
    layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)
    prefix_tracker: Dict[str, str] = Field(default_factory=dict)

    def __init__(
        self,
        arch_name: str,
        parameter_names: List[str],
        prefix_tracker: Optional[Dict[str, str]] = None,
    ):
        super().__init__()
        self.arch_name = arch_name
        self.parameter_names = parameter_names
        self.layered_parameter_names = _hierarchy(self.parameter_names)
        self.prefix_tracker = prefix_tracker or {}
        self.embed = self._find_embed_params()

    def _find_embed_params(self) -> List[str]:
        """Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling."""
        embed_params = []
        for name in self.parameter_names:
            if any(embedding_name in name for embedding_name in ["lm_head", "embed"]):
                embed_params.append(name)
        return embed_params

    def name(self) -> str:
        """Returns the architecture name."""
        return self.arch_name

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """This architecture does not distinguish pre-weights."""
        return []

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """This architecture does not distinguish post-weights."""
        return []

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """
        Retrieves the weights for a specified layer, adjusting names for prefixes if applicable.
        """
        layer_name = list(self.layered_parameter_names.keys())[index]
        adjusted_layer_name = self._adjust_layer_name(layer_name, config)

        weights = [
            WeightInfo(
                name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name,
                is_embed=(layer_name in self.embed),
            )
            for param in self.layered_parameter_names[layer_name]
        ]
        return (
            weights
            if weights
            else [
                WeightInfo(
                    name=adjusted_layer_name, is_embed=(layer_name in self.embed)
                )
            ]
        )

    def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str:
        """Adjust layer names by removing any prefix as indicated in the prefix tracker."""
        if config and config.name_or_path in self.prefix_tracker:
            prefix = self.prefix_tracker.get(config.name_or_path, "")
            if layer_name.startswith(prefix):
                return layer_name[len(prefix) :]
        return layer_name

    def sliceable(self) -> bool:
        """Indicates if the architecture supports slicing."""
        return True

    def num_layers(self, config: PretrainedConfig) -> int:
        """Returns the number of layers based on layered parameter names."""
        return len(self.layered_parameter_names)
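
    # Illustrative example (not part of the diff): suppose a multimodal
    # checkpoint stores its language model under "language_model." while a
    # text-only checkpoint does not, giving
    #   parameter_names = ["language_model.model.layers.0.mlp.up_proj.weight"]
    #   prefix_tracker  = {"example/text-only-llm": "language_model."}
    # Then for a config whose name_or_path is "example/text-only-llm",
    # layer_weights(0, config) yields
    #   [WeightInfo(name="model.layers.0.mlp.up_proj.weight", is_embed=False)]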


class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
    definition: JSONArchitectureDefinition

@@ -365,7 +478,10 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
        return MixtralTensorNames.from_config(config)

    if arch_name not in NAME_TO_ARCH:
-        raise RuntimeError(f"Unsupported architecture {arch_name}")
+        warnings.warn(
+            f"Unsupported architecture {arch_name}, attempting automatic architecture generation"
+        )
+        return False

    candidates = list(NAME_TO_ARCH[arch_name])
    if len(candidates) == 1:
@@ -375,6 +491,136 @@
        if c.definition.expected_model_type == config.model_type:
            return c

-    raise RuntimeError(
+    warnings.warn(
         f"Unsupported model_type {config.model_type} for architecture {arch_name}"
     )
+    return False


def strip_prefix(name: str, prefixes: List[str]) -> str:
    """Remove any prefix in prefixes from the start of the name."""
    for prefix in prefixes:
        if name.startswith(prefix + "."):
            return name[len(prefix) + 1 :]
    return name
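
# Illustrative example (not part of the diff):
#   strip_prefix("language_model.lm_head.weight", ["language_model"])
#   -> "lm_head.weight"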


def is_ordered_sublist_with_prefix(
    list1: List[str], list2: List[str], prefixes: List[str]
) -> bool:
    """
    Check if list1 appears as a contiguous, ordered run within list2 after optional prefix removal.
    """
    stripped_list2 = [strip_prefix(name, prefixes) for name in list2]

    try:
        start_index = stripped_list2.index(list1[0])
        for i, item in enumerate(list1):
            if stripped_list2[start_index + i] != item:
                return False
        return True
    except (ValueError, IndexError):
        return False


def find_prefix_and_check_sublist(list1: List[str], list2: List[str]) -> Optional[str]:
    """
    Attempts to find a prefix from elements in list2 that makes list1 an ordered sublist of list2.
    """
    if len(list1) > len(list2):
        list1, list2 = list2, list1

    possible_prefixes = {item.split(".")[0] for item in list2 if "." in item}

    for prefix in possible_prefixes:
        if is_ordered_sublist_with_prefix(list1, list2, [prefix]):
            return prefix

    return None
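
# Illustrative example (not part of the diff):
#   list1 = ["model.embed_tokens.weight", "model.norm.weight"]
#   list2 = ["vision_tower.encoder.weight",
#            "language_model.model.embed_tokens.weight",
#            "language_model.model.norm.weight"]
# Candidate prefixes are {"vision_tower", "language_model"}; stripping
# "language_model" makes list1 an ordered sublist of list2, so
# "language_model" is returned.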


def find_prefixes_for_alignment(param_names: List[List[str]]) -> List[str]:
    """Determine the prefix each model needs in order to align with the longest parameter list."""
    prefixes = [""]
    for i in range(1, len(param_names)):
        if param_names[0] != param_names[i]:
            prefix = find_prefix_and_check_sublist(param_names[0], param_names[i])
            if not prefix:
                raise ValueError("Could not resolve model architecture automatically.")
        else:
            prefix = ""
        prefixes.append(prefix)
    return prefixes


def find_common_ordered_names(
    param_names: List[List[str]], prefixes: List[str]
) -> List[str]:
    """Identify parameter names common to all models, preserving the order of the first list."""
    common_names = set(param_names[0])
    for i in range(1, len(param_names)):
        prefix = f"{prefixes[i]}." if prefixes[i] else ""
        common_names.intersection_update({prefix + name for name in param_names[i]})
    return [name for name in param_names[0] if name in common_names]
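
# Illustrative example (not part of the diff):
#   param_names = [["language_model.lm_head.weight", "vision_tower.proj.weight"],
#                  ["lm_head.weight"]]
#   prefixes    = ["", "language_model"]
# The second list is re-prefixed to {"language_model.lm_head.weight"}, so the
# intersection keeps ["language_model.lm_head.weight"].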


def _get_model_parameter_names(repo_id: str) -> List[str]:
    """
    Get the parameter names of a model from a Hugging Face repo or local directory.
    """
    model_dir = _resolve_model_directory(repo_id)
    return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())


def _resolve_model_directory(repo_id: str) -> Path:
    """
    Resolve the model directory from a local path or by downloading from the Hugging Face Hub.
    """
    if Path(repo_id).is_dir():
        return Path(repo_id)

    try:
        return Path(snapshot_download(repo_id))
    except HfHubHTTPError:
        raise ValueError(f"Model {repo_id} not found on Hugging Face Hub.")
    except Exception as e:
        raise ValueError(f"Error locating model {repo_id}: {e}")


def _infer_architecture_info(merge_config):
    """
    Infers and returns architecture info, including parameter names and prefixes for alignment.
    """
    param_names = [
        _get_model_parameter_names(source_model.model.path)
        for source_model in merge_config.referenced_models()
    ]

    if all(param_names[0] == param_names[i] for i in range(1, len(param_names))):
        arch_name = merge_config.referenced_models()[0].model.path
        parameter_names = param_names[0]
        prefix_tracker = {}
    else:
        # Pair param_names with referenced models and sort by length, longest first
        paired_list = list(zip(param_names, merge_config.referenced_models()))
        paired_list.sort(key=lambda x: len(x[0]), reverse=True)
        param_names, referenced_models = zip(*paired_list)

        prefixes = find_prefixes_for_alignment(param_names)
        common_names = find_common_ordered_names(param_names, prefixes)

        prefix_tracker = {
            model.model.path: f"{prefix}." if prefix else ""
            for model, prefix in zip(referenced_models, prefixes)
        }

        arch_name = referenced_models[0].model.path
        parameter_names = common_names

    return [
        AutomaticArchitectureInfo(
            arch_name=arch_name,
            parameter_names=parameter_names,
            prefix_tracker=prefix_tracker,
        )
    ]
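
Taken together, these helpers let merges proceed for models whose architectures have no JSON definition: get_architecture_info now warns and returns False instead of raising, and run_merge (below) falls back to _infer_architecture_info. A minimal sketch of how the fallback path might be exercised from Python follows; the model names are placeholders, and the config uses mergekit's documented YAML schema:

import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge

# Two hypothetical checkpoints sharing an architecture that has no JSON
# definition; architecture info is inferred from their tensor names instead.
CONFIG_YAML = """
merge_method: linear
models:
  - model: example-org/multimodal-model-a
    parameters:
      weight: 0.5
  - model: example-org/multimodal-model-b
    parameters:
      weight: 0.5
dtype: float16
"""

config = MergeConfiguration.model_validate(yaml.safe_load(CONFIG_YAML))
run_merge(config, out_path="./merged", options=MergeOptions())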
mergekit/merge.py: 37 additions & 12 deletions

@@ -19,21 +19,33 @@
import os
import shutil
from collections import Counter
-from typing import Optional
+from pathlib import Path
+from typing import List, Optional

import tqdm
import transformers
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError

from mergekit._data import chat_templates
-from mergekit.architecture import ArchitectureInfo, get_architecture_info
+from mergekit.architecture import (
+    ArchitectureInfo,
+    AutomaticArchitectureInfo,
+    _infer_architecture_info,
+    get_architecture_info,
+)
from mergekit.card import generate_card
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
from mergekit.io.tasks import LoaderCache
from mergekit.options import MergeOptions
from mergekit.plan import MergePlanner
from mergekit.tokenizer import TokenizerInfo

# Overwritten by the environment variable HF_HOME if set
HF_HOME_DEFAULT = "~/.cache/huggingface"


def run_merge(
    merge_config: MergeConfiguration,
@@ -47,16 +59,7 @@ def run_merge(
    if not merge_config.models and not merge_config.slices:
        raise RuntimeError("No output requested")

-    model_arch_info = [
-        get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
-        for m in merge_config.referenced_models()
-    ]
-    if not options.allow_crimes:
-        if not all(a == model_arch_info[0] for a in model_arch_info[1:]):
-            raise RuntimeError(
-                "Must specify --allow-crimes to attempt to mix different architectures"
-            )
-    arch_info = model_arch_info[0]
+    arch_info = _load_arch_info(merge_config, options)

    # initialize loader cache and set options
    loader_cache = LoaderCache()
@@ -273,4 +276,26 @@ def _update_config_vocab(
    )


def _load_arch_info(merge_config, options):
    """
    Loads architecture information, handling cases where models lack predefined architecture info.
    """
    model_arch_info = [
        get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
        for m in merge_config.referenced_models()
    ]

    if not any(a is False for a in model_arch_info):
        if not options.allow_crimes and not all(
            a == model_arch_info[0] for a in model_arch_info[1:]
        ):
            raise RuntimeError(
                "Must specify --allow-crimes to attempt to mix different architectures"
            )
    else:
        model_arch_info = _infer_architecture_info(merge_config)

    return model_arch_info[0]


__all__ = ["MergeOptions", "run_merge"]
tests/common.py: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def run_and_check_merge(
    index = ShardedTensorIndex.from_disk(tmpdir)
    for weight_info in arch_info.all_weights(config):
        if weight_info.name not in index.tensor_paths:
-            raise RuntimeError(f"Output missing tensor {tensor_name}")
+            raise RuntimeError(f"Output missing tensor {weight_info.name}")

    if validate:
        validate(tmpdir)