-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tweak torch parameter registration mechanism #19908
base: master
Are you sure you want to change the base?
Changes from 8 commits
d303e73
c711f1a
66126fe
d0a17fa
68c2565
b7393b3
7c35d68
c3ba9b0
37ab7f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,117 @@ | ||
from typing import Iterator | ||
from typing import Tuple | ||
|
||
import torch | ||
|
||
from keras.src.backend.common.stateless_scope import in_stateless_scope | ||
from keras.src.ops.operation import Operation | ||
|
||
|
||
class TorchLayer(torch.nn.Module): | ||
def _post_build(self): | ||
# Do not track variables when in a stateless scope. | ||
# The variables are not initialized. | ||
if in_stateless_scope(): | ||
return | ||
self._track_variables() | ||
|
||
def _track_variables(self): | ||
"""Adaptation layer to make sure keras.layers.Layer works well with | ||
torch.nn.Module. Currently, the main modification are on parameter/module | ||
tracking and pointing torch.nn.Module.forward() to the right keras call. | ||
|
||
Module tracking: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For a section outside of the normal set of sections to be highlighted, use markdown syntax, e.g.
|
||
All sublayers are tracked as modules in Module._modules. All module level | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Text content does not need to be indented. |
||
api with recurse=True should work properly just like a torch.nn.Module. | ||
|
||
Variable tracking: | ||
Since keras has a different variable tracking mechanism, unlike modules, | ||
Modules._parameter doesn't automatically tracks variables created for torch | ||
layers. | ||
This is currently manually populated through _track_torch_params() that | ||
does following work: | ||
1. Populate all sublayers torch params by calling _track_torch_params() | ||
2. Create a single torch.nn.ParameterList() parameter with trainable, | ||
non trainable and seed generator states belongs to the current layer. | ||
|
||
Few additional points that user should be aware of: | ||
1. When torch backend is enabled KerasVariable.value is torch.nn.Parameter, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sure to introduce a line break (blank line) before a list or after a section title. |
||
this is not visible to torch since it is separately tracked in keras | ||
tracker. | ||
2. When torch parameter is exposed with _track_torch_params(), no copy is | ||
made to the torch parameter in keras tracker; so both keras tracker and | ||
torch module sees the same object it is just present in 2 different | ||
member variables. This also means any modification to keras variable, | ||
for instance, setting trainable is automatically populated to torch | ||
parameters. | ||
3. Since keras creates variables in a deterministic order, resulted torch | ||
parameter list will also in deterministic order with the order of | ||
trainable->non_trainable->seed_generator_states. Changing variable from | ||
trainable to non trainable won't move keras variable from one tracker to | ||
the another, so does the final populated torch_params. | ||
4. It is recommended for user to alternate variables through keras variable | ||
apis instead of alternate with torch_params since it is simpler with the | ||
keras variable api and it is backend agnostic. | ||
5. Any torch module operation should in theory works; for example | ||
state_dict() and load_state_dict() works if you want a more torch way of | ||
saving variables. | ||
6. Although not recommended, but you can use below code snippet to find the | ||
corresponding parameter in torch_params from a keras variable: | ||
parameters = [(pname, p) for pname, p in layer.named_parameters() \ | ||
if id(p) == id(variable.value)] | ||
7. For non trainable varialbes like mean and var in BatchNormalization, this | ||
is registered as part of torch_params as parameters instead of buffers. | ||
This is not really torch best practices but it is not really possible in | ||
keras to track since keras doesn't distinguish a variable that is a stats | ||
or just have gradient skipped. | ||
""" | ||
|
||
def _track_torch_params(self): | ||
for layer in self._layers: | ||
layer._track_torch_params() | ||
torch_params = [] | ||
for v in self._trainable_variables + self._non_trainable_variables: | ||
torch_params.append(v.value) | ||
for sg in self._seed_generators: | ||
torch_params.append(sg.state.value) | ||
|
||
# set torch_params attribute will have module automatically track | ||
# parameters. | ||
self.torch_params = torch.nn.ParameterDict( | ||
{variable.path: variable.value for variable in self.variables} | ||
self.torch_params = torch.nn.ParameterList(torch_params) | ||
|
||
def _all_layers_built(self): | ||
sublayers_built = all( | ||
layer._all_layers_built() for layer in self._layers | ||
) | ||
return self.built and sublayers_built | ||
|
||
def _torch_params_tracked(self): | ||
return hasattr(self, "torch_params") | ||
|
||
def named_parameters( | ||
def _populate_torch_params(self): | ||
if not self._all_layers_built(): | ||
raise RuntimeError( | ||
"Torch parameters are not tracked since all layers are not " | ||
"built. Did you forget to call model once?" | ||
) | ||
|
||
if not self._torch_params_tracked(): | ||
self._track_torch_params() | ||
|
||
def named_modules( | ||
self, | ||
memo=None, | ||
prefix: str = "", | ||
recurse: bool = True, | ||
remove_duplicate: bool = True, | ||
) -> Iterator[Tuple[str, torch.nn.Parameter]]: | ||
if not hasattr(self, "torch_params"): | ||
self._track_variables() | ||
return torch.nn.Module.named_parameters( | ||
self, prefix, recurse, remove_duplicate | ||
): | ||
# named_modules is the root of all torch parameters/module calls. | ||
self._populate_torch_params() | ||
return torch.nn.Module.named_modules( | ||
self, memo, prefix, remove_duplicate | ||
) | ||
|
||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False): | ||
self._populate_torch_params() | ||
return torch.nn.Module.state_dict( | ||
self, | ||
*args, | ||
destination=destination, | ||
prefix=prefix, | ||
keep_vars=keep_vars, | ||
) | ||
|
||
def load_state_dict(self, state_dict, strict=True, assign=False): | ||
self._populate_torch_params() | ||
return torch.nn.Module.load_state_dict(self, state_dict, strict, assign) | ||
|
||
def forward(self, *args, **kwargs): | ||
return Operation.__call__(self, *args, **kwargs) | ||
|
||
|
@@ -49,14 +127,39 @@ def _setattr_hook(self, name, value): | |
|
||
if not isinstance(self, TorchModuleWrapper): | ||
value = TorchModuleWrapper(value) | ||
# Torch module don't register list[Module] in its __setattr__, it uses | ||
# nn.ModuleList normally. In Keras3, we only need a way for the module | ||
# class to be tracked by torch since keras3 user can still do | ||
# self._layers to reference all layers instead of using | ||
# torch.nn.Module.named_members(). | ||
if ( | ||
isinstance(value, list) | ||
and all(isinstance(v, Layer) for v in value) | ||
and len(value) > 0 | ||
): | ||
for idx, v in enumerate(value): | ||
self.add_module(f"{name}_{idx}", v) | ||
|
||
return name, value | ||
|
||
def _post_track_variable(self, variable): | ||
if hasattr(self, "torch_params"): | ||
if variable.path not in self.torch_params: | ||
self.torch_params[variable.path] = variable.value | ||
def _post_track_variable(self, _): | ||
if self._torch_params_tracked(): | ||
if not self._all_layers_built(): | ||
raise ValueError( | ||
"Torch parameters are tracked but not all " | ||
"layers are built. This is an invalid state " | ||
"in pytorch backend and please raise an " | ||
"issue in github repo." | ||
) | ||
self._track_torch_params() | ||
|
||
def _post_untrack_variable(self, variable): | ||
if hasattr(self, "torch_params"): | ||
if variable.path in self.torch_params: | ||
self.torch_params.pop(variable.path) | ||
def _post_untrack_variable(self, _): | ||
if self._torch_params_tracked(): | ||
if not self._all_layers_built(): | ||
raise ValueError( | ||
"Torch parameters are tracked but not all " | ||
"layers are built. This is an invalid state " | ||
"in pytorch backend and please raise an " | ||
"issue in github repo." | ||
) | ||
self._track_torch_params() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use backticks around all code keywords in docstrings.