Feat auto round #1064
base: dev
Conversation
@pablomlago, can you switch this target?
Preliminary review; I'll play with the code a bit in the meantime while you address these comments.
loss, rec_loss, round_loss, b)

class AdaRound(LearnedRound):
Let's rename this, not sure to what
return "loss = {:.4f}".format(loss) | ||
|
||
|
||
class AutoRound(LearnedRound): |
Same as above: rename to something that makes the difference between this FloatToIntImpl and the one above clearer.
src/brevitas_examples/llm/main.py
Outdated
learned_round_llm_utils = LearnedRoundLLMUtils()
learned_round = AutoRound()
learned_round_optimiser = LearnedRoundOptimizer(
Not a fan of this entrypoint. I'd rather have strings that have a meaning, which are interpreted within the class and mapped to the correct classes.
Agreed. I created a builder method. I still don't think it's optimal, and I'd be happy to iterate on it.
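As a rough sketch, the builder could resolve meaningful strings to the learned round classes internally (the function name and mapping below are hypothetical; the stub classes just keep the snippet self-contained):

```python
# Hypothetical sketch of a string-to-class builder for learned round methods.
class AdaRound:  # stub standing in for the class in this PR
    pass


class AutoRound:  # stub standing in for the class in this PR
    pass


_METHOD_MAP = {"ada_round": AdaRound, "auto_round": AutoRound}


def learned_round_builder(method: str):
    # Meaningful strings are interpreted here, so callers never
    # import the concrete classes directly.
    if method not in _METHOD_MAP:
        raise ValueError(f"Unknown learned round method: {method}")
    return _METHOD_MAP[method]()
```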
src/brevitas_examples/llm/main.py
Outdated
if args.learned_round:
    print("Applying learned round...")

    learned_round_llm_utils = LearnedRoundLLMUtils()
Ideally, we shouldn't need this class; we should try to merge everything with the vision one. Anything that can't be merged should be user-defined.
My idea with this design is to decouple the logic of the PTQ algorithm, which is contained in LearnedRoundOptimizer, from anything that is model/dataloader-specific, e.g. capturing the inputs/outputs of a block. That part is managed by specific utilities, which adhere to the interface LearnedRoundModelUtils (thus following the strategy pattern). This way, a potential change to the vision models would not break the LLM entrypoint, and vice versa.
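As a rough sketch of that interface, using the method names that appear in this diff (bodies and exact signatures are illustrative):

```python
# Sketch of the strategy interface: LearnedRoundOptimizer only talks to
# these methods, so vision/LLM specifics live in their own subclasses.
from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn


class LearnedRoundModelUtils(ABC):

    @abstractmethod
    def init_model_learned_round(self, model: nn.Module) -> None:
        """Model-specific setup, e.g. disabling the LLM KV cache."""

    @abstractmethod
    def finish_model_learned_round(self, model: nn.Module) -> None:
        """Restore whatever init_model_learned_round changed."""

    @abstractmethod
    def init_cache(self) -> Any:
        """Return a container for captured block inputs/outputs."""
```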
@@ -0,0 +1,310 @@
"""
Still confused about this license (not sure where my previous comment went)
return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))

class AdaRoundLoss(LearnedRoundLoss):
Change name to something more meaningful
@@ -52,6 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
    return p

# TODO: Change name to AdaRoundSte for consistency
Nope, don't change.
Remove comment
@@ -92,3 +94,36 @@ def _load_from_state_dict(
    value_key = prefix + 'value'
    if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
        missing_keys.remove(value_key)

class AutoRoundSte(brevitas.jit.ScriptModule):
Is there a way to merge this into the previous class?
In general, it would be nice to have a single learned round class that is general enough to support the different types of learned round
This should also simplify the rest of the work in the other files
Currently we have different float_to_int_impl options (round, ceil, floor, ...), each with its corresponding STE (RoundSte, FloorSte, ...). Therefore, it seems sensible to me to give the same treatment to those float_to_int_impl implementations that involve learned parameters, rather than aggregating them within a single general class, while keeping separate classes for the rounding methods that are not learnable. Moreover, the learned round methods we have right now only have a single parameter tensor, but this does not need to be the case for future methods, so by trying to aggregate the learned round methods under a common umbrella now, we might make it harder to integrate other methods in the future.
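To make the parallel concrete, here is a plain-PyTorch sketch of one learned STE with a single parameter tensor, assuming an AutoRound-style additive perturbation before rounding (illustrative only, not the brevitas.jit implementation):

```python
import torch
import torch.nn as nn


def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Round in the forward pass, identity gradient in the backward pass.
    return x + (torch.round(x) - x).detach()


class LearnedRoundSteSketch(nn.Module):

    def __init__(self, weight_shape) -> None:
        super().__init__()
        # A single learned tensor today; future methods may need more,
        # which is why each method keeps its own STE class.
        self.value = nn.Parameter(torch.zeros(weight_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The learned offset perturbs the input before straight-through rounding.
        return round_ste(x + self.value)
```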
class DataSaverHook:
class LearnedRoundVisionUtils(LearnedRoundModelUtils):
As said somewhere else, it'd be nice to have a single class to handle both CNNs/LLMs, or anything else really.
See #1064 (comment).
self.llm_cache_state = model.config.use_cache
model.config.use_cache = False

def finish_model_learned_round(self, model: nn.Module) -> None:
This should be part of some function implemented in the LLM utils file, not of the PTQ algorithm.
This method is not needed in that case
disable_quant_class.enable_param_quantization(model, False)
restore_return_quant_tensor(model, return_quant_tensor_state)

def init_model_learned_round(self, model: nn.Module) -> None:
Same as below
model.config.use_cache = self.llm_cache_state
self.llm_cache_state = None

def init_cache(self) -> Any:
Same as above
return (args, kwargs), outs

def run_forward(
This shouldn't be here. Rather, the user should be able to specify whatever interface they want for the input to the model. We should provide an interface that accepts a function with a certain signature, and the user decides what happens inside that function. Possible signature:
def model_forward(model, model_args, model_kwargs):
This signature is defined in LearnedRoundModelUtils.
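For example, a user-supplied function matching that signature could be as simple as (illustrative only):

```python
import torch


def model_forward(model, model_args, model_kwargs):
    # The user decides what happens here; this sketch just runs the
    # model under no_grad during calibration.
    with torch.no_grad():
        return model(*model_args, **model_kwargs)
```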
    self,
    loss: torch.Tensor,
) -> torch.Tensor:
    return loss * 1000
Hardcoded stuff, bad.
Why is this here?
This is intended to help prevent gradient underflow when training in float16 (without it, there's a +1.3 perplexity increase). Agreed that it should not be hard-coded, for sure.
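A possible way to avoid the hard-coded constant is plain static loss scaling with a configurable factor; the `loss_scale` knob below is hypothetical:

```python
import torch


class ScaledLoss:
    """Static loss scaling to keep float16 gradients from underflowing."""

    def __init__(self, loss_scale: float = 1000.0) -> None:
        # Hypothetical parameter replacing the hard-coded * 1000.
        self.loss_scale = loss_scale

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        # Gradients must be unscaled (or the learning rate adjusted)
        # by the same factor after backward.
        return loss * self.loss_scale
```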
def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool:
    return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer)

class _DataSaverHookLLM:
Is this different from the CNN version? How?
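For context, a minimal sketch of the hook mechanism both versions rely on; the LLM variant mainly differs in also capturing kwargs (attention masks, position ids, ...), which PyTorch's `with_kwargs=True` forward hooks support (illustrative, not the PR's exact code):

```python
class DataSaverSketch:
    # Captures a block's inputs/outputs during a calibration forward pass.

    def __init__(self) -> None:
        self.args, self.kwargs, self.output = None, None, None

    def __call__(self, module, args, kwargs, output):
        # Register with: block.register_forward_hook(hook, with_kwargs=True)
        self.args, self.kwargs, self.output = args, kwargs, output
```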
def solve_learned_round_method_cls(method_type) -> LearnedRound:
    if method_type == "ada_round":
        return AdaRound
No AdaRound/AutoRound. Instead, a single LearnedRound with options Sigmoid, HardSigmoid, Linear.
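A sketch of that direction, using the option names from the comment above (the enum and constructor are illustrative):

```python
from enum import Enum


class LearnedRoundImplType(Enum):
    SIGMOID = "sigmoid"
    HARD_SIGMOID = "hard_sigmoid"
    LINEAR = "linear"


class LearnedRound:

    def __init__(self, impl_type: LearnedRoundImplType) -> None:
        # The option is resolved internally to the matching implementation,
        # so callers never touch AdaRound/AutoRound-specific classes.
        self.impl_type = impl_type
```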
Reason for this PR
Implement AutoRound within Brevitas (see https://github.com/intel/auto-round, https://arxiv.org/pdf/2309.05516)
Changes Made in this PR
Incorporated AutoRound and refactored the learned round methods into a single common interface.
Testing Summary
New tests for the learned round utilities; results from the AutoRound repo are replicated.
Risk Highlight
Checklist
Pull request is created against the dev branch.