
Feat auto round #1064

Open · wants to merge 16 commits into dev
Conversation

@pablomlago (Contributor) commented Oct 20, 2024

Reason for this PR

Implement AutoRound within Brevitas (see https://github.com/intel/auto-round, https://arxiv.org/pdf/2309.05516)

Changes Made in this PR

Incorporate AutoRound and refactor the learned round methods into a single common interface.

Testing Summary

New tests for the learned round utilities; results replicated from the AutoRound repo.

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

@pablomlago changed the title from "[DRAFT, DO NOT MERGE] Feat auto round" to "Feat auto round" on Nov 4, 2024
@nickfraser (Collaborator) commented:

@pablomlago, can you switch this to target dev instead of master?

@pablomlago changed the base branch from master to dev on November 4, 2024, 15:59
@Giuseppe5 (Collaborator) left a comment:

Preliminary review; I'll play with the code a bit while you address these comments.

Resolved review threads:
src/brevitas/core/function_wrapper/auto_round.py (outdated)
src/brevitas/optim/sign_sgd.py (5 threads)
loss, rec_loss, round_loss, b)


class AdaRound(LearnedRound):
Collaborator:

Let's rename this, not sure to what

return "loss = {:.4f}".format(loss)


class AutoRound(LearnedRound):
Collaborator:

Same as above, rename to something that makes the difference between this FloatToIntImpl and the one above clearer.


learned_round_llm_utils = LearnedRoundLLMUtils()
learned_round = AutoRound()
learned_round_optimiser = LearnedRoundOptimizer(
Collaborator:

Not a fan of this entrypoint. I'd rather have meaningful strings that are interpreted within the class and mapped to the correct classes.

Contributor Author:

Agreed. I created a builder method. I still don't think it's optimal, and I'd be happy to iterate on it.
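For reference, a minimal sketch of such a string-driven builder, with stand-in classes and an assumed function name (`build_learned_round` is not the PR's actual entrypoint):

```python
from typing import Type


# Stand-in classes for the sketch only; the real AdaRound/AutoRound live in the PR.
class LearnedRound:
    pass


class AdaRound(LearnedRound):
    pass


class AutoRound(LearnedRound):
    pass


_LEARNED_ROUND_METHODS = {
    "ada_round": AdaRound,
    "auto_round": AutoRound,
}


def build_learned_round(method: str) -> LearnedRound:
    # Meaningful strings are resolved internally to the concrete classes,
    # so callers never instantiate AdaRound/AutoRound directly.
    try:
        cls: Type[LearnedRound] = _LEARNED_ROUND_METHODS[method]
    except KeyError:
        raise ValueError(f"Unknown learned round method: {method}")
    return cls()
```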

if args.learned_round:
print("Applying learned round...")

learned_round_llm_utils = LearnedRoundLLMUtils()
Collaborator:

Ideally, we shouldn't need this class; we should try to merge everything with the vision one.
Anything that can't be merged should be user defined.

Contributor Author:

My idea with this design is to decouple the logic of the PTQ algorithm, which is contained in LearnedRoundOptimizer, from anything that is model/dataloader-specific, e.g. capturing inputs/outputs for a block. This needs to be managed by specific utilities, which adhere to the interface LearnedRoundModelUtils (thus following the strategy pattern). By doing so, a potential change in the vision models would not break the LLM entrypoint, and vice versa.
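Roughly, the decoupling described above looks like the following simplified sketch; the actual LearnedRoundModelUtils interface in the PR defines more methods, and `apply_learned_round` is an assumed name:

```python
from abc import ABC, abstractmethod

from torch import nn


# Simplified sketch of the strategy pattern described above; the real interface
# in this PR also covers cache initialization and block input/output capture.
class LearnedRoundModelUtils(ABC):

    @abstractmethod
    def init_model_learned_round(self, model: nn.Module) -> None:
        ...

    @abstractmethod
    def finish_model_learned_round(self, model: nn.Module) -> None:
        ...


class LearnedRoundOptimizer:
    # The PTQ algorithm only depends on the interface, so vision- and
    # LLM-specific capture logic can evolve without touching this class.

    def __init__(self, learned_round_utils: LearnedRoundModelUtils) -> None:
        self.learned_round_utils = learned_round_utils

    def apply_learned_round(self, model: nn.Module) -> None:
        # `apply_learned_round` is a placeholder name for the sketch.
        self.learned_round_utils.init_model_learned_round(model)
        # ... per-block rounding optimization would run here ...
        self.learned_round_utils.finish_model_learned_round(model)
```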

@@ -0,0 +1,310 @@
"""
Collaborator:

Still confused about this license (not sure where my previous comment went)

return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t))


class AdaRoundLoss(LearnedRoundLoss):
Collaborator:

Change name to something more meaningful

@@ -52,6 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return p


# TODO: Change name to AdaRoundSte for consistency
Collaborator:

Nope don't change

Collaborator:

Remove comment

@@ -92,3 +94,36 @@ def _load_from_state_dict(
value_key = prefix + 'value'
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)


class AutoRoundSte(brevitas.jit.ScriptModule):
Collaborator:

Is there a way to merge this into the previous class?
In general, it would be nice to have a single learned round class that is general enough to support the different types of learned round

Collaborator:

This should also simplify the rest of the work in the other files

Contributor Author:

Currently we have different float_to_int_impl implementations (round, ceil, floor, ...), each with its corresponding STE (RoundSte, FloorSte, ...). It therefore seems sensible to me to give the same treatment to those float_to_int_impl implementations that involve learned parameters, rather than aggregating them within a single general class, while keeping separate classes for the rounding methods that are not learnable. Moreover, the learned round methods we have right now only have a single parameter tensor, but this need not be the case for future methods, so by trying to aggregate the learned round methods under a common umbrella now, we might make it harder to integrate other methods in the future.
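For illustration, a minimal learned-rounding STE could look like the sketch below; `LearnedOffsetRoundSte` is an assumed name, and the actual AutoRoundSte in the PR is a brevitas.jit.ScriptModule whose parametrization may differ:

```python
import torch
from torch import nn


# Minimal sketch of a learned-rounding straight-through estimator (STE).
class LearnedOffsetRoundSte(nn.Module):

    def __init__(self, shape) -> None:
        super().__init__()
        # One learned offset per element, kept within [-0.5, 0.5] at forward time.
        self.value = nn.Parameter(torch.zeros(shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + torch.clamp(self.value, -0.5, 0.5)
        # Straight-through estimator: round() in the forward pass,
        # identity gradient in the backward pass.
        return x + (torch.round(x) - x).detach()
```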



class DataSaverHook:
class LearnedRoundVisionUtils(LearnedRoundModelUtils):
Collaborator:

As said somewhere else, it'd be nice to have a single class to handle both CNN/LLM or anything else really


self.llm_cache_state = model.config.use_cache
model.config.use_cache = False

def finish_model_learned_round(self, model: nn.Module) -> None:
Collaborator:

This would be part of some function implemented in the LLM utils file, not of the PTQ algorithm.

Collaborator:

This method is not needed in that case

disable_quant_class.enable_param_quantization(model, False)
restore_return_quant_tensor(model, return_quant_tensor_state)

def init_model_learned_round(self, model: nn.Module) -> None:
Collaborator:

Same as below

model.config.use_cache = self.llm_cache_state
self.llm_cache_state = None

def init_cache(self) -> Any:
Collaborator:

Same as above


return (args, kwargs), outs

def run_forward(
Collaborator:

This shouldn't be here. Rather, the user should be able to specify whatever interface they want for the input to the model.
We should provide the interface to accept a function with a certain signature, and the user decides what happens inside that function.

Possible signature

def model_forward(model, model_args, model_kwargs):

Contributor Author:

This signature is defined in LearnedRoundModelUtils.
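For illustration, a user-supplied callable matching that signature could be as simple as the following hypothetical example (not code from the PR):

```python
import torch
from torch import nn


# Hypothetical user-defined forward function matching the
# model_forward(model, model_args, model_kwargs) signature discussed above.
def model_forward(model: nn.Module, model_args: tuple, model_kwargs: dict) -> torch.Tensor:
    # The user decides how the captured (args, kwargs) pair is fed to the block.
    return model(*model_args, **model_kwargs)
```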

self,
loss: torch.Tensor,
) -> torch.Tensor:
return loss * 1000
Collaborator:

Hardcoded stuff, bad.
Why is this here?

Contributor Author:

This is intended to help prevent gradient underflow when training in float16 (without it, there's a +1.3 perplexity increase). Agreed that it definitely should not be hard-coded.
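One way to avoid the hard-coded factor, sketched here with assumed names (`LossScaler` and its arguments are not part of the PR), is to make the scale configurable and unscale gradients before the optimizer step; torch.cuda.amp.GradScaler offers similar functionality with dynamic scaling:

```python
from typing import Iterable

import torch


# Sketch with assumed names: a configurable loss scale to keep float16 gradients
# out of the underflow range, instead of a hard-coded `loss * 1000`.
class LossScaler:

    def __init__(self, loss_scaling_factor: float = 1000.0) -> None:
        self.loss_scaling_factor = loss_scaling_factor

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        # Scale the loss up before backward() so small gradients stay representable.
        return loss * self.loss_scaling_factor

    def unscale_(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        # Undo the scaling on accumulated gradients before the optimizer step.
        for p in parameters:
            if p.grad is not None:
                p.grad.div_(self.loss_scaling_factor)
```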

def default_block_check_fn(self, module: nn.Module, module_name: str) -> bool:
return isinstance(module, LlamaDecoderLayer) or isinstance(module, OPTDecoderLayer)

class _DataSaverHookLLM:
Collaborator:

Is this different from the CNN version? How?


def solve_learned_round_method_cls(method_type) -> LearnedRound:
if method_type == "ada_round":
return AdaRound
Collaborator:

No AdaRound/AutoRound.

LearnedRound with options Sigmoid, HardSigmoid, Linear
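A sketch of what the suggested option-driven entrypoint could look like; the enum and resolver names below are assumptions, not the PR's API:

```python
from enum import Enum


# Assumed names for illustration: one set of options selecting the rounding
# parametrization, instead of separate AdaRound / AutoRound classes.
class LearnedRoundImplType(str, Enum):
    SIGMOID = "sigmoid"            # e.g. AdaRound-style soft rounding
    HARD_SIGMOID = "hard_sigmoid"
    LINEAR = "linear"              # e.g. AutoRound-style additive offset


def solve_learned_round_impl_type(method_type: str) -> LearnedRoundImplType:
    # Meaningful strings resolve to an option, which a single LearnedRound
    # implementation can then interpret internally.
    return LearnedRoundImplType(method_type.lower())
```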

@Giuseppe5 added the "next release" label (PRs which should be merged for the next release) on Nov 7, 2024
Labels: next release
Projects: none yet
3 participants