-
Notifications
You must be signed in to change notification settings - Fork 72
[WIP] Migrate JAX workloads from pmap to jit #848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
priyakasimbeg
wants to merge
81
commits into
dev
Choose a base branch
from
jit_switch
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
81 commits
Select commit
Hold shift + click to select a range
bf61255
Merge pull request #825 from mlcommons/dev
priyakasimbeg 9653f18
Merge pull request #843 from mlcommons/dev
priyakasimbeg ae48ccd
Use jax.jit for sharding initial steps
rka97 eb5cac7
Use jax.jit for adamw
rka97 82977da
Pass yapf checks
rka97 99545d4
CIFAR workload sharding
rka97 018711a
librispeech_conformer now running
rka97 fbeb5f1
fix formatting
rka97 6e4e7b0
shard default
rka97 4a2c02d
start imagenet
rka97 47beba1
remove bn sync in imagenet (jit handles it automatically)
rka97 3a18f19
ImageNet-ViT also works
rka97 bd0f565
Start working on WMT. OOM error
rka97 3044efb
post-rebase, still on wmt
rka97 e301c49
cache sharding fix
rka97 e5ed97a
Merge branch 'dev' into jit_switch
priyakasimbeg 4fcf984
target_setting_algorithms sharding, compilation caching
rka97 d147e39
Update tests to correct batch size
rka97 a2b61be
yapf and isort checks..
rka97 be11c23
Merge branch 'jit_switch' of https://github.com/mlcommons/algorithmic…
priyakasimbeg e2a3b5f
Merge branch 'dev' into jit_switch
priyakasimbeg a80f4ec
switch fastmri from pmap to jit
priyakasimbeg c39ca51
migrate criteo workload
priyakasimbeg 06377d9
update utils function used for sharding conformer
priyakasimbeg 9cbe7d9
update conformer and deepspeech
priyakasimbeg c6ecd67
debugging
priyakasimbeg f35690d
debuging
priyakasimbeg 848b50c
reformatting
priyakasimbeg fb62eae
reformatting
priyakasimbeg fe3f9f0
reformatting
priyakasimbeg 004afbd
reformatting
priyakasimbeg f1db3d3
reformatting
priyakasimbeg c208cc7
sharding deepspeech
priyakasimbeg 2e4cc9e
ogbg jit migration
priyakasimbeg d3a06fc
deepspeech jit changes
priyakasimbeg 2cfa2a9
set jax to 0.5.1
priyakasimbeg 70705a7
merge
priyakasimbeg 75d6315
upgrade jax to 0.5.3
priyakasimbeg 1df0690
change bsz back
priyakasimbeg c1d0c66
formatting
priyakasimbeg 1b9466c
remove debugging statements from submission_runner.py
priyakasimbeg 7a71cf0
pyproject.toml
priyakasimbeg 9e1f337
clean up ogbg
priyakasimbeg a1d0abd
clean up ogbg
priyakasimbeg adb2b7e
Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie…
priyakasimbeg 99caa03
clean up mnist workload.py
priyakasimbeg b14174b
refactoring & clean up
priyakasimbeg a3a9b9f
simplify changes in cifar jax
priyakasimbeg 0a340a2
small fix
priyakasimbeg 60c1cce
rename sharding utils
priyakasimbeg 1edb724
fix sharding rename
priyakasimbeg 49864fb
refactoring
priyakasimbeg 7820ac6
modifications to cifar
priyakasimbeg 0a2043c
fix
priyakasimbeg 95037bf
clean up and small fixes
priyakasimbeg e79c761
add test for sharding invariance
priyakasimbeg 110e792
fix
priyakasimbeg 9c91c65
Update pyproject.toml
priyakasimbeg 21bb997
Update workload.py
priyakasimbeg eb56919
Update workload.py
priyakasimbeg c489749
Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie…
priyakasimbeg 1277cc2
upgrade jax
priyakasimbeg def4ac5
update dockerfile
priyakasimbeg 450cbee
remove extra installs
priyakasimbeg 89718e7
update jax version
priyakasimbeg 7dcf5af
update install commands for pytorch cpu only
priyakasimbeg 4335688
update dockerfile
priyakasimbeg 8d1fe7e
update dockerfile
priyakasimbeg 240e2e5
update dockerfile
priyakasimbeg cc8d604
update dockerfile
priyakasimbeg fe56eaf
update dockerfile
priyakasimbeg de4c38b
modify initial model_state
priyakasimbeg 5b7fb31
docker build script change
priyakasimbeg 57b8fe6
temporarily use pre-releases for jax install
priyakasimbeg e23e99a
fix to pyproject.toml
priyakasimbeg 505fab2
chnage defaults for job config script
priyakasimbeg 4acaffe
fix docker image
priyakasimbeg 3481f0e
jax deprecation fix for jax.tree_map
priyakasimbeg 447d621
try to fix jax installation
priyakasimbeg a3df78c
temporary pip install change for jax gpu nightly
priyakasimbeg 8aa3ffc
add step_time to summary df
priyakasimbeg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Utilities for dealing with sharding in JAX.""" | ||
|
||
import jax | ||
from jax.sharding import NamedSharding, PartitionSpec as P | ||
|
||
|
||
def get_replicate_sharding(): | ||
"""Returns a sharding spec that replicates data across all devices.""" | ||
mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) | ||
return NamedSharding(mesh, P()) | ||
|
||
|
||
def get_batch_dim_sharding(): | ||
"""Returns a sharding spec that shards data along the first axis.""" | ||
mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) | ||
return NamedSharding(mesh, P('batch')) | ||
|
||
|
||
def shard_along_batch_dim(x): | ||
"""Shards a tensor across all devices.""" | ||
mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) | ||
return jax.tree.map( | ||
lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x) | ||
|
||
|
||
def replicate(x): | ||
"""Replicates tensor across all devices.""" | ||
mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) | ||
return jax.tree.map( | ||
lambda x: jax.device_put(x, NamedSharding(mesh, P())), x) | ||
|
||
|
||
def display_shard_info(x: jax.Array): | ||
"""Displays shard info of a jax array.""" | ||
for shard in x.addressable_shards: | ||
print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" | ||
f" {shard.replica_id}.\n") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did we swap out the eval_rngs with the model rng?