@zinccat
Contributor

@zinccat zinccat commented Oct 8, 2024

Porting RNN from Linen to NNX

Fixes #4259

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).
  • This change is discussed in a GitHub issue/discussion
  • The documentation and docstrings adhere to the
    documentation guidelines.
  • This change includes necessary high-coverage tests.
    (No quality testing = no merge!)

@google-cla

google-cla bot commented Oct 8, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

return_carry: bool = False,
reverse: bool = False,
keep_order: bool = False,
unroll: int = 1,
Collaborator

Maybe we could accept rngs here to provide a default that can optionally be overridden during __call__?

Suggested change
- unroll: int = 1,
+ unroll: int = 1,
+ rngs: rnglib.Rngs | None = None,

Contributor Author

Changed, and set the default to rngs(0).
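
As an illustration of the pattern discussed in this thread (a default set at construction that a call can optionally override), here is a minimal sketch. `ToyRngs` and `ToyRNN` are hypothetical stand-ins built on Python's `random` module; they are not the Flax `Rngs` class or the RNN layer from this PR, and the real implementation may differ.

```python
import random
from typing import Optional


class ToyRngs:
    """Hypothetical stand-in for an RNG container (not Flax's Rngs)."""

    def __init__(self, seed: int):
        self._random = random.Random(seed)

    def next_value(self) -> float:
        return self._random.random()


class ToyRNN:
    """Hypothetical layer that stores a default RNG source at construction."""

    def __init__(self, rngs: Optional[ToyRngs] = None):
        # Fall back to a fixed seed when the caller passes nothing,
        # loosely mirroring the "default to rngs(0)" reply above.
        self.rngs = rngs if rngs is not None else ToyRngs(0)

    def __call__(self, x: float, *, rngs: Optional[ToyRngs] = None) -> float:
        # A per-call `rngs` takes precedence over the stored default.
        active = rngs if rngs is not None else self.rngs
        return x + active.next_value()


layer = ToyRNN()
default_out = layer(1.0)                     # draws from the stored default
override_out = layer(1.0, rngs=ToyRngs(42))  # draws from the per-call override
```

The per-call override leaves the stored default untouched, so callers that never pass `rngs` keep a reproducible stream.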

Comment on lines 675 to 688
def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]:
  carry, y = self.cell(carry, x)
  if slice_carry:
    return carry, (carry, y)
  return carry, y

scan = nnx.scan(
  scan_fn,
  in_axes=(Carry, time_axis),
  out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
  unroll=self.unroll,
)

scan_output = scan(carry, inputs)
Collaborator

Currently, self is captured by the closure in scan_fn. We need to pass cell to the scanned function as an explicit input instead.

Suggested change
- def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]:
-   carry, y = self.cell(carry, x)
-   if slice_carry:
-     return carry, (carry, y)
-   return carry, y
- scan = nnx.scan(
-   scan_fn,
-   in_axes=(Carry, time_axis),
-   out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
-   unroll=self.unroll,
- )
- scan_output = scan(carry, inputs)
+ def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array]:
+   carry, y = cell(carry, x)
+   if slice_carry:
+     return carry, (carry, y)
+   return carry, y
+ state_axes = nnx.StateAxes({...: Carry})
+ scan = nnx.scan(
+   scan_fn,
+   in_axes=(state_axes, Carry, time_axis),
+   out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
+   unroll=self.unroll,
+ )
+ scan_output = scan(self.cell, carry, inputs)
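
The idea behind this suggestion (thread the cell's state through the scan explicitly instead of capturing it, and optionally emit the carry at every step) can be sketched with plain `jax.lax.scan`. This is only a rough analogy under stated assumptions: it does not use `nnx.scan` or `nnx.StateAxes`, and `toy_cell`, `run`, and the `state` dictionary are hypothetical names invented for the example.

```python
import jax
import jax.numpy as jnp


def toy_cell(state, carry, x):
    """Hypothetical cell: `state` holds a weight and a call counter."""
    new_carry = jnp.tanh(state["w"] * carry + x)
    # The cell updates its own state; the update is only visible to the
    # caller because `state` is threaded through the loop explicitly.
    new_state = {"w": state["w"], "calls": state["calls"] + 1}
    return new_state, new_carry, new_carry


def run(state, carry, inputs, slice_carry=False):
    def scan_fn(loop_carry, x):
        state, carry = loop_carry  # state is an input, not a capture
        state, carry, y = toy_cell(state, carry, x)
        # With slice_carry, each step also emits its carry, so the
        # stacked output contains every intermediate carry.
        out = (carry, y) if slice_carry else y
        return (state, carry), out

    (state, carry), outputs = jax.lax.scan(scan_fn, (state, carry), inputs)
    return state, carry, outputs


state = {
    "w": jnp.asarray(0.5, dtype=jnp.float32),
    "calls": jnp.asarray(0, dtype=jnp.int32),
}
inputs = jnp.ones((4, 3))  # (time, features); time is axis 0 here
carry0 = jnp.zeros((3,))
state, final_carry, (carries, ys) = run(state, carry0, inputs, slice_carry=True)
```

Placing the state in the loop carry plays a role loosely comparable to marking all of the cell's state as `Carry` in the suggestion: updates made inside the step function survive the loop.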

*,
merge_fn: Callable[[Array, Array], Array] = _concatenate,
time_major: bool = False,
return_carry: bool = False,
Collaborator

Accept rngs here as well.

@cgarciae
Collaborator

Thanks @zinccat for doing this, this is amazing!
Left a few comments.

@zinccat
Contributor Author

zinccat commented Oct 10, 2024

Thanks for the review! I will fix it soon.

@IvyZX
Collaborator

IvyZX commented Oct 10, 2024

Thank you for making the change! You probably need to rebase to the current head to resolve the Read the Docs build error.

@zinccat
Contributor Author

zinccat commented Oct 16, 2024

hi, any updates on this?

@IvyZX
Collaborator

IvyZX commented Oct 16, 2024

Sorry about the delay! Merging them rn

@zinccat
Contributor Author

zinccat commented Oct 16, 2024

thanks!

@copybara-service copybara-service bot merged commit 3bf732c into google:main Oct 16, 2024
12 checks passed