From 3340650e847e98ccac7b0d1c83089e88d5476e14 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 17 Oct 2024 17:00:09 +0200 Subject: [PATCH] Change layers_pattern logic Addresses part of #2155. Description So far, the layers_pattern argument would only work if there was a prefix to the pattern. As an example, if the module name is: decoder.layer.0.attn.to_q and we pass layers_pattern="layer", this would match. However, if the module name was: layer.0.attn.to_q it would not work. Usually, when we create a model with AutoModelForFoo.from_pretrained, the "layer" part would never be first. However, if we load a model directly, e.g. through LlamaModel.from_pretrained, there is actually no prefix. As a consequence, we get no match there. With this PR, the prefix is made optional, so that the second pattern also matches. Status I'm not sure yet if this should be merged, as it is technically backwards incompatible. Users can still target the desired modules by carefully crafting a regex for target_modules so that it only matches the desired layer indices. However, this is tedious and layers_pattern was introduced to avoid having to do this. 
--- src/peft/tuners/tuners_utils.py | 11 +++++++---- tests/test_tuners_utils.py | 10 ++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 405277b6b5..277996b376 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -979,18 +979,21 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave # For now, empty layers_pattern means any layer pattern is ok if layers_pattern is None or len(layers_pattern) == 0: - layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + match = re.match(r".*\.[^.]*\.(?P<idx>\d+)\.", key) else: layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern for pattern in layers_pattern: - layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) - if layer_index is not None: + match = re.match(rf"(.*\.)?{pattern}\.(?P<idx>\d+)\.", key) + if match is not None: break + if match: + layer_index = match.groupdict().get("idx") + if layer_index is None: target_module_found = False else: - layer_index = int(layer_index.group(1)) + layer_index = int(layer_index) if isinstance(layer_indexes, int): target_module_found = layer_index == layer_indexes else: diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 06a47deb26..e713828c22 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - -# coding=utf-8 # Copyright 2023-present the HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -92,6 +89,10 @@ ("foo.bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True), ("foo.bar.1.baz", ["baz", "spam"], [1], ["bar"], True), ("foo.bar.1.baz", ["baz", "spam"], [0, 1, 2], ["bar"], True), + ("bar.1.baz", ["baz"], [0, 2], ["bar"], False), + ("bar.1.baz", ["baz"], [0, 1, 2], ["foo"], False), + ("bar.1.baz", ["baz"], [0, 2], ["bar"], False), + ("bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True), # empty layers_to_transform ("foo.bar.7.baz", ["baz"], [], ["bar"], True), ("foo.bar.7.baz", ["baz"], None, ["bar"], True), @@ -119,14 +120,11 @@ # is one of the target nn.modules ("foo.bar.1.baz", ["baz"], [1], ["baz"], False), # here, layers_pattern is 'bar', but only keys that contain '.bar' are valid. - ("bar.1.baz", ["baz"], [1], ["bar"], False), ("foo.bar.001.baz", ["baz"], [1], ["bar"], True), ("foo.bar.1.spam.2.baz", ["baz"], [1], ["bar"], True), ("foo.bar.2.spam.1.baz", ["baz"], [1], ["bar"], False), # some realistic examples: module using nn.Sequential # for the below test case, key should contain '.blocks' to be valid, because of how layers_pattern is matched - ("blocks.1.weight", ["weight"], [1], ["blocks"], False), - ("blocks.1.bias", ["weight"], [1], ["blocks"], False), ("mlp.blocks.1.weight", ["weight"], [1], ["blocks"], True), ("mlp.blocks.1.bias", ["weight"], [1], ["blocks"], False), ]