minor changes to OFT to make it faster #2805
```diff
@@ -114,6 +114,7 @@ def _pytorch_skew_symmetric_inv(self, matrix, block_size):
         vec = matrix[:, self.rows, self.cols]
         return vec

+    @torch.compile
     def _cayley_batch(
         self, Q: torch.Tensor, block_size: int, use_cayley_neumann: bool = True, num_neumann_terms: int = 5
     ) -> torch.Tensor:
```
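For readers unfamiliar with the pattern: decorating a method with `@torch.compile` compiles it lazily on its first call. Below is a minimal, self-contained sketch of the same pattern on a toy module; it is an illustration only, not the PEFT layer.

```python
import torch
from torch import nn


class Demo(nn.Module):
    """Toy module illustrating a compiled method; unrelated to the real OFT layer."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim, dim))

    @torch.compile  # compiled lazily on the first call
    def _rotate(self, x: torch.Tensor) -> torch.Tensor:
        # apply a skew-symmetric transform, loosely mimicking what OFT does
        return x @ (self.w - self.w.transpose(-1, -2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._rotate(x)


demo = Demo()
y = demo(torch.randn(4, 32))  # first call triggers compilation
```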
```diff
@@ -139,9 +140,11 @@ def _cayley_batch(
             R.add_(Q_squared, alpha=2.0)

             Q_power = Q_squared
-            for i in range(3, num_neumann_terms):
+            for _ in range(3, num_neumann_terms - 1):
                 Q_power = torch.bmm(Q_power, Q_skew)
                 R.add_(Q_power, alpha=2.0)
+            Q_power = torch.bmm(Q_power, Q_skew)
+            R.add_(Q_power)
         else:
             id_mat = (
                 torch.eye(Q_skew.shape[-1], device=Q_skew.device)
```
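For context on what this loop computes, here is an unoptimized, standalone reference for the Cayley–Neumann approximation; the function name and the exact term bookkeeping are illustrative, not the PEFT implementation. The Cayley transform R = (I + Q)(I − Q)⁻¹ is approximated by replacing (I − Q)⁻¹ with a truncated Neumann series.

```python
import torch


def cayley_neumann_reference(Q_skew: torch.Tensor, num_powers: int = 4) -> torch.Tensor:
    """Reference Cayley-Neumann approximation for a batch of skew-symmetric
    blocks Q_skew of shape (b, n, n). Illustrative, not the optimized PEFT code."""
    b, n, _ = Q_skew.shape
    eye = torch.eye(n, device=Q_skew.device, dtype=Q_skew.dtype).repeat(b, 1, 1)

    # Truncated Neumann series: (I - Q)^{-1} ~= I + Q + Q^2 + ... + Q^(num_powers - 1)
    neumann = eye.clone()
    Q_power = eye.clone()
    for _ in range(1, num_powers):
        Q_power = torch.bmm(Q_power, Q_skew)
        neumann = neumann + Q_power

    # Cayley transform: R = (I + Q)(I - Q)^{-1}. Expanding the product shows every
    # intermediate power of Q gets coefficient 2 while the highest power gets
    # coefficient 1 -- the same structure the loop in the diff builds in-place,
    # without materializing the intermediate sum.
    return torch.bmm(eye + Q_skew, neumann)
```

This is presumably why the final `R.add_(Q_power)` in the diff is added without `alpha=2.0`: in the expanded product, only the highest-order term appears with coefficient 1.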
```diff
@@ -248,6 +251,10 @@ def forward(self, x):
         if required_dtype != self.weight.dtype:
             x = x.to(self.weight.dtype)

+        if self.rows.device != self.weight.device:
+            self.rows = self.rows.to(self.weight.device)
+            self.cols = self.cols.to(self.weight.device)
+
```
```python
self.oft_R[adapter_name] = OFTRotationModule(
    r if not block_share else 1,
    n_elements,
    oft_block_size,
    self.in_features,
    coft=coft,
    eps=eps,
    block_share=block_share,
    use_cayley_neumann=use_cayley_neumann,
    num_cayley_neumann_terms=num_cayley_neumann_terms,
)
```
As these tensors are quite small, I think the memory impact is negligible, but we avoid moving them during each forward call. I would still keep this code here to be safe (it's a no-op if they're already on the right device).
Yes, we could also do this, but when it is initialized, I think it is on CPU? (I checked the device of the base layer weight; it is on CPU.) So how should I know the correct device? Best,
I think at this point, the weights of the base model should generally already be on the accelerator. E.g. when I run these tests with a GPU:

```sh
pytest tests/test_custom_models.py -k "test_forward_float16 and oft and not boft"
```

and I add this check before `OFTRotationModule` is assigned:

```python
base_layer_device = self.get_base_layer().device
assert base_layer_device.type == "cuda"
```

it passes. So we can add the device as an input argument to `OFTRotationModule` and then ensure that it's moved to the right device. WDYT?
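A hedged sketch of how that could look (the class below is a trimmed-down, hypothetical stand-in, not the real `OFTRotationModule`, which takes many more arguments): creating the index tensors on the base layer's device, or registering them as buffers so that `module.to(device)` moves them automatically, would make the per-forward device check unnecessary.

```python
import torch
from torch import nn


class OFTRotationModuleSketch(nn.Module):
    """Hypothetical, trimmed-down stand-in used only to illustrate device handling."""

    def __init__(self, block_size: int, device=None):
        super().__init__()
        # Upper-triangular index pairs of each block (assumed layout, for illustration).
        rows, cols = torch.triu_indices(block_size, block_size, offset=1)
        if device is not None:
            rows, cols = rows.to(device), cols.to(device)
        # Buffers follow module.to(...) / .cuda() calls, so no lazy check in forward
        # is needed; persistent=False keeps them out of the state dict.
        self.register_buffer("rows", rows, persistent=False)
        self.register_buffer("cols", cols, persistent=False)


# Hypothetical call site, following the suggestion above:
# base_layer_device = ...  # device of the base layer's weight
# self.oft_R[adapter_name] = OFTRotationModuleSketch(oft_block_size, device=base_layer_device)
```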
Right now, we don't `torch.compile` any code but instead leave it to the user to choose whether they want to compile or not. Testing on my machine with a 4090, I also don't see a big difference (6582 vs 6532 tokens/sec). Did you find this to help a lot in your setting? If not, I'd suggest removing it.
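For concreteness, the user-side alternative looks roughly like this; the model below is a placeholder `nn.Sequential`, not a PEFT model, but any `nn.Module`, including a PEFT-wrapped one, behaves the same way.

```python
import torch
from torch import nn

# Placeholder for a PEFT model with an OFT adapter; any nn.Module behaves the same.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Compiling at the call site keeps the decision -- and options such as
# mode="reduce-overhead" or dynamic=True -- in the user's hands, rather than
# hard-coding a @torch.compile decorator inside the library.
compiled = torch.compile(model)
out = compiled(torch.randn(4, 16))  # first call triggers compilation
```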
Thanks for the quick reply! The thing is that I noticed it speeds things up a bit (I deliberately did not add specific configurations like `dynamic` etc. to avoid errors), but we can leave it out.
General question: if we have custom kernels (for example, a Triton kernel), is it allowed to add them?
How much difference does it make for you? If it's not significantly more than what I observed, I'd say let's remove it for now.

Good question. So at HF, there is a general push to make it easier to optionally make use of kernels: https://huggingface.co/blog/hello-hf-kernels. In transformers, there is already an integration, though the API is still at an early stage (I think there is no option for fine-grained control over which kernels to use).

For PEFT, we plan to add the same functionality, but we're still working out how exactly to implement it. Once it's there, we will happily accept kernels for specific PEFT methods (and migrate existing kernels there). Until then, feel free to add your OFT kernel to the HF kernel hub; we could use it for testing the kernels integration.
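For anyone following along, the usage pattern from that blog post looks roughly like this; this is a sketch based on the blog, the repo id and function are the example used there (not anything OFT-specific), and the exact API may have evolved since.

```python
import torch
from kernels import get_kernel  # `kernels` package for the Hugging Face kernel hub

# Fetch a pre-built kernel from the Hub (example repo from the blog post).
activation = get_kernel("kernels-community/activation")

x = torch.randn(8, 16, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
activation.gelu_fast(out, x)  # the kernel writes its result into `out`
```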
Thanks for the information :) In another project I am building kernels for this orthogonalization, which makes it faster and more memory-efficient than the current torch.compile, but sure, we can add it later :)
Let's just remove the torch.compile for now.