Skip to content

WIP: nnx support demo #20546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Conversation

cgarciae
Copy link

@cgarciae cgarciae commented Nov 25, 2024

NOTE: this is just for a demo.

@DavidLandup0
Copy link
Contributor

@cgarciae Keras running on top of nnx sometime in the future? 👀

@cgarciae
Copy link
Author

@DavidLandup0 maybe :)

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae

Thanks for the PR!

It says demo. Besides making all the tests pass, what else is needed?

@@ -33,6 +34,14 @@ def _initialize(self, value):
self._layout = None
self._direct_assign(value)

@property
def _value(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this can work (at least without some refactoring). self._value is an attribute used in the super class and other places I think.


class JaxLayer(nnx.Object):
def __init_subclass__(cls):
super().__init_subclass__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need this? It doesn't do that by default?

@@ -538,7 +541,9 @@ def add_weight(
initializer = "zeros"
initializer = initializers.get(initializer)
with backend.name_scope(self.name, caller=self):
value = initializer(shape, dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you can't change this. With the Tensorflow backend, in some cases, we need to delay the call to initializers.

Why is this needed?



@keras_export("keras.Operation")
class Operation:
def __init_subclass__(cls):
super().__init_subclass__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? There is no super class.

@@ -97,6 +100,7 @@ def __new__(cls, *args, **kwargs):
to manually implement `get_config()`.
"""
instance = super(Operation, cls).__new__(cls)
vars(instance)['_object__state'] = nnx.object.ObjectState()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this cannot go here. This class is cross-backend and Operations are required to be stateless. This is probably what is making all tests fail.

Move to JaxLayer.

@github-project-automation github-project-automation bot moved this from Assigned Reviewer to Reviewer Requested Changes in PR Queue Mar 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Reviewer Requested Changes
Development

Successfully merging this pull request may close these issues.

4 participants