Change layers_pattern logic
Addresses part of huggingface#2155.

Description

So far, the layers_pattern argument would only work if there was a
prefix before the pattern. As an example, if the module name is:

decoder.layer.0.attn.to_q

and we pass layers_pattern="layer", this would match. However, if the
module name was:

layer.0.attn.to_q

it would not work.

Usually, when we create a model with AutoModelForFoo.from_pretrained,
the "layer" part would never be first. However, if we load a model
directly, e.g. through LlamaModel.from_pretrained, there is actually no
prefix. As a consequence, we get no match there.

With this PR, the prefix is made optional, so that the second module
name also matches.
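For illustration, here is a minimal standalone snippet showing how the old and new regexes (taken from the diff below) behave on the two module names from the example above:

```python
import re

key_with_prefix = "decoder.layer.0.attn.to_q"
key_no_prefix = "layer.0.attn.to_q"
pattern = "layer"

# Old logic: the pattern must be preceded by at least one dot-terminated segment.
old_regex = rf".*\.{pattern}\.(\d+)\."
print(bool(re.match(old_regex, key_with_prefix)))  # True
print(bool(re.match(old_regex, key_no_prefix)))    # False: no prefix, no match

# New logic: the prefix group (.*\.)? is optional.
new_regex = rf"(.*\.)?{pattern}\.(?P<idx>\d+)\."
print(bool(re.match(new_regex, key_with_prefix)))  # True
print(bool(re.match(new_regex, key_no_prefix)))    # True
```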

Status

I'm not sure yet if this should be merged, as it is technically
backwards incompatible. Users can still target the desired modules by
carefully crafting a regex for target_modules so that it matches only
the desired layer indices. However, this is tedious, and layers_pattern
was introduced precisely to avoid it.
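For reference, that workaround looks roughly like this (a sketch only: the module names are hypothetical, and it assumes PEFT's behavior of matching a string target_modules against the full module name as a regex):

```python
from peft import LoraConfig

# Target to_q only in layers 0-2 by encoding the layer indices in the
# regex itself, instead of using layers_to_transform/layers_pattern.
config = LoraConfig(target_modules=r"layer\.[0-2]\.attn\.to_q")
```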
BenjaminBossan committed Oct 17, 2024
1 parent 93ddb10 commit 3340650
Showing 2 changed files with 11 additions and 10 deletions.
src/peft/tuners/tuners_utils.py (7 additions, 4 deletions)
```diff
@@ -979,18 +979,21 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
         # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave
         # For now, empty layers_pattern means any layer pattern is ok
         if layers_pattern is None or len(layers_pattern) == 0:
-            layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key)
+            match = re.match(r".*\.[^.]*\.(?P<idx>\d+)\.", key)
         else:
             layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern
             for pattern in layers_pattern:
-                layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key)
-                if layer_index is not None:
+                match = re.match(rf"(.*\.)?{pattern}\.(?P<idx>\d+)\.", key)
+                if match is not None:
                     break
 
+        if match:
+            layer_index = match.groupdict().get("idx")
+
         if layer_index is None:
             target_module_found = False
         else:
-            layer_index = int(layer_index.group(1))
+            layer_index = int(layer_index)
             if isinstance(layer_indexes, int):
                 target_module_found = layer_index == layer_indexes
             else:
```
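Pulled out of its surrounding function, the new matching flow works roughly like this (a standalone sketch, not the exact PEFT code; the helper name is made up):

```python
import re

def find_layer_index(key: str, layers_pattern: list[str] | None) -> int | None:
    """Return the layer index embedded in `key`, or None if nothing matches."""
    match = None
    if not layers_pattern:
        # No pattern given: accept any ".<segment>.<digits>." shape.
        match = re.match(r".*\.[^.]*\.(?P<idx>\d+)\.", key)
    else:
        for pattern in layers_pattern:
            # The optional (.*\.)? group is the change: the prefix may be absent.
            match = re.match(rf"(.*\.)?{pattern}\.(?P<idx>\d+)\.", key)
            if match is not None:
                break
    return int(match.group("idx")) if match else None

print(find_layer_index("layer.0.attn.to_q", ["layer"]))          # 0 (newly matches)
print(find_layer_index("decoder.layer.7.attn.to_q", ["layer"]))  # 7
```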
tests/test_tuners_utils.py (4 additions, 6 deletions)
```diff
@@ -1,6 +1,3 @@
-#!/usr/bin/env python3
-
-# coding=utf-8
 # Copyright 2023-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -92,6 +89,10 @@
     ("foo.bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True),
     ("foo.bar.1.baz", ["baz", "spam"], [1], ["bar"], True),
     ("foo.bar.1.baz", ["baz", "spam"], [0, 1, 2], ["bar"], True),
+    ("bar.1.baz", ["baz"], [0, 2], ["bar"], False),
+    ("bar.1.baz", ["baz"], [0, 1, 2], ["foo"], False),
+    ("bar.1.baz", ["baz"], [0, 2], ["bar"], False),
+    ("bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True),
     # empty layers_to_transform
     ("foo.bar.7.baz", ["baz"], [], ["bar"], True),
     ("foo.bar.7.baz", ["baz"], None, ["bar"], True),
@@ -119,14 +120,11 @@
     # is one of the target nn.modules
     ("foo.bar.1.baz", ["baz"], [1], ["baz"], False),
-    # here, layers_pattern is 'bar', but only keys that contain '.bar' are valid.
-    ("bar.1.baz", ["baz"], [1], ["bar"], False),
     ("foo.bar.001.baz", ["baz"], [1], ["bar"], True),
     ("foo.bar.1.spam.2.baz", ["baz"], [1], ["bar"], True),
     ("foo.bar.2.spam.1.baz", ["baz"], [1], ["bar"], False),
     # some realistic examples: module using nn.Sequential
     # for the below test case, key should contain '.blocks' to be valid, because of how layers_pattern is matched
-    ("blocks.1.weight", ["weight"], [1], ["blocks"], False),
     ("blocks.1.bias", ["weight"], [1], ["blocks"], False),
     ("mlp.blocks.1.weight", ["weight"], [1], ["blocks"], True),
     ("mlp.blocks.1.bias", ["weight"], [1], ["blocks"], False),
 ]
```
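Each tuple reads as (key, target_modules, layers_to_transform, layers_pattern, expected_result). A newly added case such as ("bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True) exercises the check roughly like this (a sketch, assuming check_target_module_exists is called directly as in these tests):

```python
from peft import LoraConfig
from peft.tuners.tuners_utils import check_target_module_exists

config = LoraConfig(
    target_modules=["baz"],
    layers_to_transform=[0, 1, 2],
    layers_pattern=["bar"],
)
# With this PR, the missing prefix before "bar" no longer prevents a match.
print(bool(check_target_module_exists(config, "bar.1.baz")))  # True
```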
Expand Down
