Add Orthogonal Subspace Fine-Tuning (OSF) Tuner for Parameter-Efficient Continual Learning #2685
Conversation
Nice! Thanks for the thorough update, that's a good step forward.
A minor nit: several files are missing the copyright notice, please make sure to include it in new source files (and make sure it is not outdated, i.e. includes the current year).
I like that you already implemented several (custom) tests, I think that's super helpful. Let's also add some tests to test_decoder_models.py and test_encoder_decoder_models.py, similar to the test in test_custom_models.py, when you think the implementation can move forward in testing (a rough sketch of such an entry follows below). Let's move the skips for convolutions to testing_common.py; there are already similar exceptions in place.
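For illustration, an OSF entry in the custom-model test matrix could look roughly like the sketch below; the tuple layout (test name, model id, config class, config kwargs) is assumed to mirror the existing LoRA entries in test_custom_models.py, and the module names are placeholders:

```python
from peft import OSFConfig  # assumes OSFConfig is exported at the package level

# Hypothetical test-case entries; names and targeted modules are illustrative only.
OSF_TEST_CASES = [
    ("OSF MLP lin0", "MLP", OSFConfig, {"target_modules": ["lin0"]}),
    ("OSF MLP multiple layers", "MLP", OSFConfig, {"target_modules": ["lin0", "lin1"]}),
]
```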
Two bigger topics:
- ModelWithOSF seems to re-invent PEFT functionality inside PEFT, specifically the layer targeting + replacement portion. Let's streamline OSF with the other tuners, i.e. have implementations for specific layers and implement inject_adapter, _create_new_module and _create_and_replace to make it easier to branch out to other layer types / quantizations. The LoRA implementation may be helpful, e.g. peft.tuners.lora.layers.LoraLayer contains specific layers for Linear and Conv*d specifics (no need to implement Conv now, of course). I can see that this conflicts with using a dict for specifying the top-k ranks per module. How about using target_modules and a singular value for the topk rank (e.g., config.topk_r) which can default to None (-> uses 50% of min(shape))? Every targeted module gets that topk rank or an automatic 50% one. We could also add something like rank_pattern from LoRA to define exceptions (see lora.model.py -> _create_and_replace). WDYT?
Example config:
OSFConfig(
target_modules='all-linear',
topk_r=None,
rank_pattern={
'q_proj': 10,
}
)
- It's not possible to use more than one OSF adapter since the base model is modified and we therefore cannot switch between adapters (this could be handy in pipeline scenarios where one model is used at several places with different adapters, for example). I left a comment at decompose_weight_matrix to discuss this.
Once we're done with the general implementation I think it'd be super if we could add an experiment to the MetaMathQA comparison suite so that we can compare OSF directly to other implementations.
Awesome, will definitely evaluate our method once the implementation is complete to benchmark OSF against other methods in PEFT.
@githubnemo great suggestion! In response to the first bigger topic raised, I have implemented the minimal PEFT integration changes. What we implemented:
Scope decisions we made:
Key files changed:
These changes integrate the OSF method modularly into PEFT.
Thanks for the detailed feedback and your changes.
I think that the re-structuring of OSFModel is almost complete and most of the comments are rather minor. As far as I can see, the ad-hoc ModelWithOSF is replaced by OSFModel and OSFLayer and can be removed - good progress!
I think this is a good time to remove outdated code, merge with main, run make style and run the tests to see if there's still something going horribly wrong.
Let's discuss whether we want to implement the importance score now or leave it up for implementation later. If I'm not mistaken, the importance score can technically be added later since it would compute the effective rank of layers based on two new hyper-parameters, so in that sense it is modular. Since it is quite a crucial part of the paper and is touted to improve multi-task learning (arguably one of the big selling points of OSF), I wonder if it should be included from the get-go. What's your opinion on that?
Regardless, I think we can add a MetaMathQA experiment rather soon and check if there are major problems with memory consumption or runtime.
- Complete continual learning scenario with multiple tasks
- Demonstration of OSF's catastrophic forgetting prevention
- Configuration examples (target_modules, effective_rank, rank_pattern)
- Performance comparison with baseline methods
I think the performance comparison with baseline methods - at least for single tasks - is best done in the PEFT method comparison (MetaMathQA). Of course, feel free to provide a comparison with methods that support multi-task learning if it fits into the example without too much effort.
src/peft/tuners/osf/model.py
Outdated
def unload(self):
    raise NotImplementedError("OSF models cannot be unloaded yet")

def merge_adapter(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")

def unmerge_adapter(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")

def merge_and_unload(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")
{merge_and_}unload and {un}merge_adapter are still open, commenting so I don't forget :)
If we are keeping OSF non-mergeable for now, no code change is required here.
src/peft/tuners/osf/model.py
Outdated
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
    for n, p in model.named_parameters():
        if "svd_params" not in n and not n.endswith(("_U_low", "_S_low", "_V_low")):
Let's also check if self.prefix is in the parameter name to reduce the risk of overriding similarly named parameters.
Updated _mark_only_adapters_as_trainable to include the OSF prefix guard.
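For reference, a minimal sketch of what the guarded version might look like; the prefix check is an assumption based on the suggestion above, and the exact value of self.prefix and the parameter suffixes are taken from the diff and may differ in the final implementation:

```python
from torch import nn

def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
    for n, p in model.named_parameters():
        # Freeze every parameter that does not carry the OSF prefix together with
        # the low-rank suffixes; only the OSF low-rank factors stay trainable.
        if self.prefix not in n or not ("svd_params" in n or n.endswith(("_U_low", "_S_low", "_V_low"))):
            p.requires_grad = False
```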
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
    self.base_layer = base_layer
    self.effective_rank = {}
Just for my understanding (no change necessary): we diverge in naming from LoRA's r parameter here because there's still the option of adding the importance weighting, and if we'd add that then:
- effective_rank overrides the importance metric / layer-wise rank
- target and minimum rank become additional hyper-params to compute the effective rank of layers according to their importance

Do I understand this correctly?
Yes, that is exactly right, and effective_rank is also more conceptually descriptive here for what we are trying to do with OSF in identifying the important subspace (the effective rank of the matrix).
src/peft/utils/osf_utils.py
Outdated
svd = {
    "U_high": U[:, :k].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "S_high": S[:k].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "V_high": Vt[:k, :].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "U_low": nn.Parameter(U[:, k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "S_low": nn.Parameter(S[k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "V_low": nn.Parameter(Vt[k:, :].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "rank_high": k,
}
Thank you for the detailed explanation!
The sequential dependency of later-added adapters on previous adapters removes a lot of the convenience gained by being able to remove individual adapters, I agree.
I'm OK with not implementing this.
@githubnemo added MetaMathQA experiment results. OSF achieves the highest accuracy at 55.72% among all PEFT methods in the benchmark! 😊 Top results for comparison:
Memory consumption and runtime look okay thus far as well.
@NikhilNayak-debug very nice results! :) Is this ready for review from your side? If so, could you merge main and resolve the merge conflicts? This saves one review cycle.
@githubnemo thank you. I have rebased the branch on top of the latest upstream main and resolved the conflicts. This is ready for review now.
Sorry for the late reply, I was at a conference.
The changes look very good! There was quite a large PR merged in the meantime that refactored a good portion of the BaseTuner infrastructure (#2771), which means that you need a lot less code now - I hope I highlighted all occurrences.
I'm currently in the process of reproducing the MetaMathQA results you posted. One thing I noticed is that more layers are targeted and the default effective rank (min(shape) // 2) is used, which uses way more parameters than other methods. While it is certainly good to see that OSF is better than full fine-tuning, it would be a fairer comparison to match the trainable parameter counts of the other methods.
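As a quick way to check this, trainable parameter counts can be compared directly with plain PyTorch (no OSF-specific API assumed; peft_model stands for any PEFT-wrapped model):

```python
def count_parameters(model):
    # Count trainable vs. total parameters of any torch.nn.Module.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

trainable, total = count_parameters(peft_model)
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```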
train_task(model, task_2_data)

# Task 3: recompute again and expand preserved subspace further
base_model = model.base_model.model
Let's implement OSFModel.unload and use it here so we don't have to assume base model paths (we don't have to support merging to support unloading).
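A rough sketch of what such an unload could look like, assuming each OSF layer keeps the wrapped module in base_layer (as in the __init__ diff above); this is an illustration under those assumptions, not the actual implementation:

```python
def unload(self):
    # Walk the model and swap every OSF wrapper back for its original module,
    # without merging any trained factors.
    for key, module in list(self.model.named_modules()):
        if isinstance(module, OSFLayer):
            parent_key, _, target_name = key.rpartition(".")
            parent = self.model.get_submodule(parent_key) if parent_key else self.model
            setattr(parent, target_name, module.base_layer)
    return self.model
```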
def inject_adapter(
    self,
    model: nn.Module,
    adapter_name: str,
    autocast_adapter_dtype: bool = True,
    low_cpu_mem_usage: bool = False,
) -> None:
    # Delegate to BaseTuner to perform standard target discovery and replacement
    return super().inject_adapter(
        model,
        adapter_name,
        autocast_adapter_dtype=autocast_adapter_dtype,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
Let's remove the inject_adapter method completely to avoid errors when the underlying API changes (as is the case with #2637 where inject_adapter gets a new keyword argument).
@staticmethod
def _check_target_module_exists(osf_config, key):
    return check_target_module_exists(osf_config, key)
This can now be removed since it equals the BaseTuner implementation.
def _set_adapter_layers(self, enabled: bool = True) -> None:
    pass

def enable_adapter_layers(self) -> None:
    self._set_adapter_layers(True)

def disable_adapter_layers(self) -> None:
    self._set_adapter_layers(False)

def set_adapter(self, adapter_name):
    self.active_adapter = adapter_name
_set_adapter_layers, enable_adapter_layers, disable_adapter_layers and set_adapter can be removed now that BaseTuner provides those.
def unload(self):
    raise NotImplementedError("OSF models cannot be unloaded yet")
The BaseTuner implementation provides this already.
@contextmanager
def hub_online_once(model_id: str):
    """Set env[HF_HUB_OFFLINE]=1 (and patch transformers/hugging_face_hub to think that it was always that way)
    for model ids that were seen already so that the hub is not contacted twice for the same model id in said context.

    The cache (`HUB_MODEL_ACCESSES`) also tracks the number of cache hits per model id.

    The reason for doing a context manager and not patching specific methods (e.g., `from_pretrained`) is that there
    are a lot of places (`PeftConfig.from_pretrained`, `get_peft_state_dict`, `load_adapter`, ...) that possibly
    communicate with the hub to download files / check versions / etc.

    Note that using this context manager can cause problems when used in code sections that access different resources.
    Example:

    ```
    def test_something(model_id, config_kwargs):
        with hub_online_once(model_id):
            model = ...from_pretrained(model_id)
            self.do_something_specific_with_model(model)
    ```

    It is assumed that `do_something_specific_with_model` is an abstract method that is implemented by several tests.
    Imagine the first test simply does `model.generate([1,2,3])`. The second call from another test suite however uses
    a tokenizer (`AutoTokenizer.from_pretrained(model_id)`) - this will fail since the first pass was online but didn't
    use the tokenizer and we're now in offline mode and cannot fetch the tokenizer. The recommended workaround is to
    extend the cache key (`model_id` passed to `hub_online_once` in this case) by something in case the tokenizer is
    used, so that these tests don't share a cache pool with the tests that don't use a tokenizer.
    """
    global HUB_MODEL_ACCESSES
    override = {}

    try:
        if model_id in HUB_MODEL_ACCESSES:
            override = {"HF_HUB_OFFLINE": "1"}
            HUB_MODEL_ACCESSES[model_id] += 1
        else:
            if model_id not in HUB_MODEL_ACCESSES:
                HUB_MODEL_ACCESSES[model_id] = 0
        with (
            # strictly speaking it is not necessary to set the environment variable since most code that's out there
            # is evaluating it at import time and we'd have to reload the modules for it to take effect. It's
            # probably still a good idea to have it if there's some dynamic code that checks it.
            mock.patch.dict(os.environ, override),
            mock.patch("huggingface_hub.constants.HF_HUB_OFFLINE", override.get("HF_HUB_OFFLINE", False) == "1"),
            mock.patch("transformers.utils.hub._is_offline_mode", override.get("HF_HUB_OFFLINE", False) == "1"),
        ):
            yield
    except Exception:
        # in case of an error we have to assume that we didn't access the model properly from the hub
        # for the first time, so the next call cannot be considered cached.
        if HUB_MODEL_ACCESSES.get(model_id) == 0:
            del HUB_MODEL_ACCESSES[model_id]
        raise
This is probably a merge artifact. hub_online_once belongs to testing_utils.py and is imported at the top.
if effective_rank is None:
    # Default to 50% of min dimension
    effective_rank = min(target.weight.shape) // 2
This is already handled in the layers, so I think it's fine to remove this implementation.
svd = {
    "U_high": U[:, :k].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "S_high": S[:k].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "V_high": Vt[:k, :].contiguous().detach().to(device=device_local, dtype=orig_dtype),
    "U_low": nn.Parameter(U[:, k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "S_low": nn.Parameter(S[k:].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "V_low": nn.Parameter(Vt[k:, :].contiguous().detach().to(device=device_local, dtype=orig_dtype)),
    "rank_high": k,
}
Maybe I'm misunderstanding something, but I understand effective_rank to be the best-guess estimate of the effective rank of the weight matrix - the rest is assumed 'free' for us to modify (thus, the "low" portion has S.shape[0] - effective_rank space for the trained task).
If that is the case, then the example in the documentation for continual learning is wrong as it decreases the effective rank over time (therefore growing the low part, shrinking the high part).
In case the above is correct, the effective rank parameter should be explained better in the config to avoid mistakes. Maybe it is even worth thinking about swapping the semantics to be more akin to LoRA's r, since I can imagine a few people being surprised by using effective_rank=1 and running into OOM errors. It might also be easier to start from the bottom up and take more space than to guess a high number and decrease it. WDYT?
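To make the OOM point concrete, here is a small, self-contained illustration of the current split semantics, mirroring the slicing in decompose_weight_matrix shown above (the 1024x1024 shape is just an example):

```python
import torch

W = torch.randn(1024, 1024)
effective_rank = 1  # size of the frozen "high" part
k = effective_rank
U, S, Vt = torch.linalg.svd(W, full_matrices=False)

frozen = U[:, :k].numel() + S[:k].numel() + Vt[:k, :].numel()
trainable = U[:, k:].numel() + S[k:].numel() + Vt[k:, :].numel()
# With effective_rank=1 nearly the whole matrix ends up in the trainable low part.
print(f"frozen: {frozen:,}  trainable: {trainable:,}")
```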
def merge_adapter(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")

def unmerge_adapter(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")

def merge_and_unload(self, *args, **kwargs):
    raise NotImplementedError("OSF models do not support merging")
IIUC only unmerge is unsupported. Since #2771 is now merged we can remove these implementations and use the BaseTuner implementation, which calls layer.merge under the hood.
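For reference, merging an OSF layer would essentially mean materializing the dense weight from the stored factors; a minimal sketch, assuming the factor names from decompose_weight_matrix shown above:

```python
import torch

def reconstruct_weight(svd: dict) -> torch.Tensor:
    # W ≈ U_high @ diag(S_high) @ V_high + U_low @ diag(S_low) @ V_low
    high = svd["U_high"] @ torch.diag(svd["S_high"]) @ svd["V_high"]
    low = svd["U_low"] @ torch.diag(svd["S_low"]) @ svd["V_low"]
    return high + low
```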
Summary
This PR adds a new parameter-efficient fine-tuning method called Orthogonal Subspace Fine-Tuning (OSF) to the PEFT library. OSF enables continual learning in LLMs by freezing the high-rank subspace of weight matrices and fine-tuning only the low-rank directions. This approach constrains updates to be orthogonal to previously important directions, thereby mitigating catastrophic forgetting without increasing parameter count.
Issue for this PR on PEFT repository
Tracked in PEFT Issue #2648
Key Features
- Implements a new OSFConfig, OSFModel, and tuner class under src/peft/tuners/osf/ following PEFT's standard API
- Integrates seamlessly with the get_peft_model API (see the usage sketch after this list)
- Adds utility functions for:
- Automatically enforces orthogonality constraints during training without requiring optimizer wrapping
- Will include tests for saving, loading, and applying the OSF adapter in tests/test_custom_models.py
- Exports relevant modules at the package level for easier use with other PEFT components
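A hedged usage sketch for the get_peft_model integration; the model name is a placeholder and the config field names (effective_rank, rank_pattern) follow the configuration examples mentioned earlier in this thread:

```python
from transformers import AutoModelForCausalLM
from peft import OSFConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
config = OSFConfig(
    target_modules="all-linear",
    effective_rank=None,          # None -> defaults to 50% of min(shape) per targeted layer
    rank_pattern={"q_proj": 10},  # per-module overrides
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```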
Notes
Background
This implementation is based on the method described in our paper:
Sculpting Subspaces: Constrained Full Fine-Tuning in LLMs for Continual Learning
Paper on arXiv · Project Repository