Skip to content

Conversation

@samanklesaria
Copy link
Collaborator

What does this PR do?

  • Adds an optional output_sharding to standard layers just like in jax for use with explicit sharding.

@samanklesaria samanklesaria force-pushed the output_sharding branch 2 times, most recently from 1195087 to 96f0a64 Compare November 19, 2025 19:21
@samanklesaria samanklesaria changed the title Add out_sharding argument to call methods for standard layers Add out_sharding argument to call methods for layers with jax calls that support it Nov 19, 2025
with jax.set_mesh(mesh):
replicated_array = jnp.arange(4).reshape(2, 2)
sharded_array = reshard(replicated_array, P("X", None))
l = nnx.Linear(2,4, rngs=nnx.Rngs(0))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: l hard to read

Comment on lines 415 to +416
dot_general_kwargs = {}
dot_general_kwargs['out_sharding'] = out_sharding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update this to be done in one line?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants