Rotation based equalization #1061
base: dev
Conversation
Current limitations:
Could you add some basic tests for the LLM entry-point as well? Both for RMSNorm & rotations?
```python
try:
    import fast_hadamard_transform
except:
    fast_hadamard_transform = None
```
Maybe print a warning if fast_hadamard_transform can't be loaded? Should we also add an extras option, `brevitas[fast_hadamard]` or `brevitas[fast_equalize]`, whichever makes the most sense?
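For reference, a minimal sketch of what the optional import with a warning could look like; the exact message and fallback behaviour are assumptions, not what the PR currently does:

```python
import warnings

try:
    import fast_hadamard_transform
except ImportError:
    # Fall back gracefully instead of failing silently.
    fast_hadamard_transform = None
    warnings.warn(
        "fast_hadamard_transform is not installed; Hadamard-based rotations "
        "will use a slower path or be unavailable.")
```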
```python
@property
def is_valid(self):
    return self.max_shape_srcs == self.max_shape_sinks
```
Comments explaining what these do?
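As an illustration, a commented version might read as follows; the interpretation of the two attributes is inferred from their names and may not match the intended semantics exactly:

```python
@property
def is_valid(self):
    # The region can only be equalized if the largest channel dimension found
    # across its source layers matches the one found across its sink layers;
    # otherwise there is no common axis along which to apply the rotation/scale.
    return self.max_shape_srcs == self.max_shape_sinks
```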
```python
# Exit if source and sink have different sizes
if max_shape_srcs != max_shape_sinks and len(region.srcs) > 0:
    return _no_equalize()
```
My intuition is that the purpose of this code is to ensure that all sources and sinks have compatible shapes - not quite following why this isn't needed anymore...
```python
    return torch.matmul(tensor, ort)


def random_orthogonal_matrix(size):
```
Did this come from somewhere?
Yes
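For context, the usual recipe for sampling a random orthogonal matrix (used e.g. in rotation-based quantization work such as QuaRot) is a QR decomposition of a Gaussian matrix with a sign correction; a minimal sketch, not necessarily the exact code referenced here:

```python
import torch

def random_orthogonal_matrix(size):
    # Sample a Gaussian matrix and orthogonalize it via QR decomposition.
    # float64 keeps the result numerically close to exactly orthogonal.
    m = torch.randn(size, size, dtype=torch.float64)
    q, r = torch.linalg.qr(m)
    # Sign correction so the result is uniformly (Haar) distributed
    # over the orthogonal group.
    q *= torch.sign(torch.diag(r)).unsqueeze(0)
    return q
```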
```python
set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
rewriters = [
    ModuleToModuleByClass(
        rms_cls, torch.nn.RMSNorm, normalized_shape=config.hidden_size, eps=config.rms_norm_eps)
```
torch version guard required?
Yes
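`torch.nn.RMSNorm` only exists in recent PyTorch releases (added around 2.4), so the RMSNorm replacement path would need a guard along these lines; a rough sketch only:

```python
import torch

# torch.nn.RMSNorm is only available in recent PyTorch releases,
# so check for it before building the RMSNorm rewriters.
if not hasattr(torch.nn, 'RMSNorm'):
    raise RuntimeError(
        "Replacing custom RMSNorm layers with torch.nn.RMSNorm requires a "
        "PyTorch version that provides torch.nn.RMSNorm (>= 2.4).")
```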
Other than the Hadamard matrix instantiation issue, this LGTM (once the above changes are applied).
```python
if require_fx:
    model = get_fx(model)
    with torch.no_grad():
        model, guards = torch._dynamo.export(model)(**calibration_loader[0])
```
This has an impact on how dataloaders are created; we don't need extra kwargs for `attention_mask`.
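Since the exported program is called with `**calibration_loader[0]`, each calibration sample is just a dict of keyword arguments for the model's forward; the keys and shapes below are purely illustrative:

```python
import torch

# Each entry is unpacked as keyword arguments of the model's forward; no
# attention_mask has to be pre-built, because dynamo traces the forward with
# only the inputs that are actually provided.
calibration_loader = [{"input_ids": torch.randint(0, 32000, (1, 128))}]
```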
Reason for this PR
Implement rotation-based equalization for weights and activations.
Highlights:
Changes Made in this PR
The graph-based region algorithm has been extended and generalized to support different sets of supported layers depending on the type of equalization applied.
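Conceptually, this means the region-extraction step is parameterized by the set of layer types it may traverse, which differs per equalization mode. A purely illustrative sketch with hypothetical names (not the actual Brevitas API):

```python
import torch.nn as nn

# Hypothetical mapping from equalization mode to the layer types the graph
# pass is allowed to include when building regions; the real sets differ.
SUPPORTED_LAYERS = {
    'scale': (nn.Linear, nn.Conv2d, nn.LayerNorm),
    'rotation': (nn.Linear, nn.Embedding, nn.LayerNorm),
}

def extract_regions(graph_model, mode='scale'):
    supported = SUPPORTED_LAYERS[mode]
    regions = []
    for name, module in graph_model.named_modules():
        # The real pass walks the FX graph to group sources and sinks; here we
        # only illustrate how `mode` selects which layer types are considered.
        if isinstance(module, supported):
            regions.append(name)
    return regions
```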
Important considerations:
Testing Summary
Added one test that checks:
Risk Highlight
Checklist
The PR is opened against the `dev` branch.