WIP: nnx support demo #20546
@cgarciae Keras running on top of nnx sometime in the future? 👀
@DavidLandup0 maybe :)
@@ -33,6 +34,14 @@ def _initialize(self, value):
        self._layout = None
        self._direct_assign(value)

    @property
    def _value(self):
I don't think this can work (at least not without some refactoring). self._value
is an attribute used in the superclass and, I think, in other places.
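To illustrate the concern, here is a minimal sketch of what happens in plain Python when a subclass turns an attribute into a read-only property while the superclass still assigns to it. The class names are hypothetical stand-ins, not the Keras variable classes, and the sketch assumes the superclass writes `self._value` directly.

```python
class BaseVariable:
    """Hypothetical superclass that stores `_value` as a plain attribute."""

    def __init__(self, value):
        self._value = value  # plain attribute assignment


class PropertyVariable(BaseVariable):
    """Subclass that redefines `_value` as a read-only property."""

    @property
    def _value(self):
        return self._raw


class SettableVariable(BaseVariable):
    """Same idea, but with a setter so superclass assignments keep working."""

    @property
    def _value(self):
        return self._raw

    @_value.setter
    def _value(self, new_value):
        self._raw = new_value


def try_construct(cls, value):
    """Return 'ok' if construction succeeds, else the exception type name."""
    try:
        cls(value)
    except AttributeError as exc:
        return type(exc).__name__
    return "ok"


print(try_construct(PropertyVariable, 1.0))
print(try_construct(SettableVariable, 1.0))
```

The read-only version fails as soon as the superclass constructor runs, which is why some refactoring (for example a setter, or renaming the backing attribute) would probably be needed.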
class JaxLayer(nnx.Object):
    def __init_subclass__(cls):
        super().__init_subclass__()
Do you need this? Doesn't it do that by default?
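For reference, a small sketch of the default behavior in plain Python: a class inherits `__init_subclass__` from its parent, so an override that only calls `super()` adds nothing. `Base`, `Child`, and `GrandChild` are hypothetical names; whether `nnx.Object` needs anything beyond this is not confirmed here.

```python
class Base:
    """Hypothetical parent whose hook records every subclass it sees."""

    registry = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        Base.registry.append(cls.__name__)


class Child(Base):
    # No __init_subclass__ override: Base's hook is inherited and still runs.
    pass


class GrandChild(Child):
    # The hook is found through the MRO for deeper subclasses as well.
    pass


print(Base.registry)
```

Both subclasses end up in the registry without defining the hook themselves.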
@@ -538,7 +541,9 @@ def add_weight(
                initializer = "zeros"
        initializer = initializers.get(initializer)
        with backend.name_scope(self.name, caller=self):
            value = initializer(shape, dtype)
Oh, you can't change this. With the TensorFlow backend, we need to delay the call to the initializer in some cases.
Why is this needed?
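A rough sketch of the difference between eager and delayed initialization, assuming "delay" means the variable receives the initializer callable and invokes it later. `LazyVariable` and `zeros` are hypothetical helpers written for this example; the actual Keras mechanism is not shown in this thread.

```python
class LazyVariable:
    """Hypothetical variable that keeps the initializer and calls it on demand."""

    def __init__(self, initializer, shape):
        self._initializer = initializer
        self._shape = shape
        self._value = None

    def materialize(self):
        # Run the initializer only on first use, then reuse the result.
        if self._value is None:
            self._value = self._initializer(self._shape)
        return self._value


def zeros(shape):
    """Toy 1-D initializer that counts how often it is called."""
    zeros.calls += 1
    return [0.0] * shape[0]


zeros.calls = 0

var = LazyVariable(zeros, (3,))
calls_before = zeros.calls   # initializer has not run yet
first = var.materialize()    # initializer runs here
second = var.materialize()   # cached value, no second call
```

Calling `initializer(shape, dtype)` inside `add_weight`, as in the hunk above, would produce the value immediately and leave nothing to defer.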
@keras_export("keras.Operation")
class Operation:
    def __init_subclass__(cls):
        super().__init_subclass__()
Why is this needed? There is no superclass.
@@ -97,6 +100,7 @@ def __new__(cls, *args, **kwargs):
        to manually implement `get_config()`.
        """
        instance = super(Operation, cls).__new__(cls)
        vars(instance)['_object__state'] = nnx.object.ObjectState()
Oh, this cannot go here. This class is cross-backend, and Operations
are required to be stateless. This is probably what is making all the tests fail.
Move it to JaxLayer.
NOTE: this is just for a demo.
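The suggestion to move the state into `JaxLayer` could look roughly like the sketch below. It uses plain Python stand-ins only: `BackendState` and `JaxLayerSketch` are hypothetical names, and `BackendState` merely takes the place of whatever object the nnx integration needs, so this is not the actual Keras or nnx code.

```python
class Operation:
    """Stand-in for the cross-backend base: __new__ attaches no backend state."""

    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)


class BackendState:
    """Hypothetical placeholder for backend-specific bookkeeping."""


class JaxLayerSketch(Operation):
    """Backend-specific subclass: the only place the extra state is attached."""

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls, *args, **kwargs)
        # Write through vars() so any custom __setattr__ is bypassed,
        # mirroring the approach in the diff above.
        vars(instance)["_backend_state"] = BackendState()
        return instance


plain_op = Operation()
jax_layer = JaxLayerSketch()
```

With this split, instances of the shared base stay free of backend state, and only objects built through the JAX-specific class carry it.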