-
Notifications
You must be signed in to change notification settings - Fork 346
Pathways: reuse setup_spmd for pathways init #1248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
axlearn/common/utils_spmd.py
Outdated
pathwaysutils.initialize() | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do a return
here to avoid the else
and extra nesting below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
jax_backend=FLAGS.jax_backend, | ||
initialization_timeout=FLAGS.initialization_timeout, | ||
) | ||
setup_spmd( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose setup_spmd
isn't the most appropriate name anymore given pathways, but this seems OK for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setup_controller? but happy to stick with setup_spmd for now too. I'm afraid of breaking others by renaming setup functions like these. Not sure if people have custom launch.py functions.
@@ -44,6 +44,13 @@ def setup( | |||
# Use a GSPMD-friendly PRNG implementation. | |||
jax.config.update("jax_default_prng_impl", "rbg") | |||
|
|||
if jax_backend == "proxy": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update the docstring for jax_backend
?
Also clarify the expected inputs for the other args in this case? Should they always be None like in the TPU case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the docstring to explicitely call out that the other args will be ignored. They don't have to be None though.
No description provided.