
The default setting of param_scan_axis=1 hurts performance and memory consumption on GPUs #1382

Open
jaro-sevcik opened this issue Mar 12, 2025 · 1 comment · May be fixed by #1394

jaro-sevcik commented Mar 12, 2025

The default setting of param_scan_axis=1 causes Flax to transpose the model parameters for the scan, while also keeping the untransposed version for the optimizer state update.
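A minimal JAX sketch of why a non-zero scan axis can cost an extra copy. It does not use Flax or MaxText. It stacks toy per-layer weights and scans over them with `jax.lax.scan`, which always iterates over the leading axis. With the layer axis stored in position 1 (treated here as the analogue of param_scan_axis=1; that mapping is an assumption based on the option name), the stack must first be moved to the front with `jnp.moveaxis`. Under jit this can become a full transposed copy, while the gradient comes back in the original stored layout that the optimizer updates. `NUM_LAYERS`, `D`, `layer`, `forward` and `loss` are hypothetical toy names.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, D = 4, 8  # toy sizes, not Llama2-7b's


def layer(x, w):
    # One toy "layer": a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)


def forward(stacked_w, x, scan_axis):
    # lax.scan iterates over the leading axis of its xs argument, so a layer
    # axis stored anywhere else must be moved to the front first. Under jit,
    # XLA may materialize this as a transposed copy of the whole stack.
    if scan_axis != 0:
        stacked_w = jnp.moveaxis(stacked_w, scan_axis, 0)

    def body(carry, w):
        return layer(carry, w), None

    out, _ = jax.lax.scan(body, x, stacked_w)
    return out


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# Layer axis first: shape (NUM_LAYERS, D, D).
w_axis0 = jax.random.normal(key_w, (NUM_LAYERS, D, D)) / jnp.sqrt(D)
# Same weights with the layer axis in position 1: shape (D, NUM_LAYERS, D).
w_axis1 = jnp.moveaxis(w_axis0, 0, 1)
x = jax.random.normal(key_x, (2, D))

fwd = jax.jit(forward, static_argnums=2)
y0 = fwd(w_axis0, x, 0)
y1 = fwd(w_axis1, x, 1)


def loss(stacked_w, scan_axis):
    return forward(stacked_w, x, scan_axis).sum()


# Gradients come back in each stored layout, so for axis 1 the optimizer
# updates the (D, NUM_LAYERS, D) array while the scan used a transposed copy.
g0 = jax.grad(lambda w: loss(w, 0))(w_axis0)
g1 = jax.grad(lambda w: loss(w, 1))(w_axis1)
```

Both layouts compute the same function; the difference is only in which copies of the parameters exist while the step runs.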

Compared to param_scan_axis=0 on Llama2-7b, the default uses ~13GB of extra memory (~81GB vs ~68GB in total) and is ~3% slower.

Here is a log from a run with the default param_scan_axis=1:

completed step: 1, seconds: 0.637, TFLOP/s/device: 148.271, Tokens/s/device: 6434.697, total_weights: 8192, loss: 10.863
completed step: 2, seconds: 0.450, TFLOP/s/device: 209.671, Tokens/s/device: 9099.331, total_weights: 8192, loss: 9.510
completed step: 3, seconds: 0.449, TFLOP/s/device: 210.105, Tokens/s/device: 9118.189, total_weights: 8192, loss: 7.975
completed step: 4, seconds: 0.450, TFLOP/s/device: 209.736, Tokens/s/device: 9102.141, total_weights: 8192, loss: 6.078
completed step: 5, seconds: 0.450, TFLOP/s/device: 209.950, Tokens/s/device: 9111.435, total_weights: 8192, loss: 4.459
completed step: 6, seconds: 0.450, TFLOP/s/device: 209.673, Tokens/s/device: 9099.412, total_weights: 8192, loss: 3.240
completed step: 7, seconds: 0.450, TFLOP/s/device: 209.882, Tokens/s/device: 9108.497, total_weights: 8192, loss: 2.379
completed step: 8, seconds: 0.450, TFLOP/s/device: 209.835, Tokens/s/device: 9106.431, total_weights: 8192, loss: 1.819
completed step: 9, seconds: 0.450, TFLOP/s/device: 209.915, Tokens/s/device: 9109.915, total_weights: 8192, loss: 1.484
Output size: 40430494092, temp size: 40316314088, argument size: 40430575628, host temp size: 0, in bytes.
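The size line above appears to mirror the fields of the statistics JAX exposes through `Compiled.memory_analysis()` on an ahead-of-time compiled function. A minimal sketch of printing the same fields for an arbitrary jitted function, assuming a JAX version where `memory_analysis()` returns an object with these attributes (it may return `None` on some backends). The helper name `report_memory` and the toy function are hypothetical; this is not MaxText's own code.

```python
import jax
import jax.numpy as jnp


def report_memory(fn, *args):
    """Compile `fn` for `args` and print XLA's buffer-size statistics."""
    compiled = jax.jit(fn).lower(*args).compile()
    stats = compiled.memory_analysis()  # may be None on some backends
    if stats is None:
        print("memory analysis not available on this backend")
        return None
    # getattr in case this jaxlib version does not expose the host_* fields.
    host_temp = getattr(stats, "host_temp_size_in_bytes", None)
    print(
        f"Output size: {stats.output_size_in_bytes}, "
        f"temp size: {stats.temp_size_in_bytes}, "
        f"argument size: {stats.argument_size_in_bytes}, "
        f"host temp size: {host_temp}, in bytes."
    )
    return stats


# Toy usage on a small jitted function.
def toy(a):
    return jnp.tanh(a @ a) @ a


stats = report_memory(toy, jnp.ones((256, 256), jnp.float32))
```

Running the same helper on a train step compiled once with each param_scan_axis value is one way to compare the temp sizes of the two configurations.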

Here is a log with param_scan_axis=0:

completed step: 1, seconds: 0.620, TFLOP/s/device: 152.335, Tokens/s/device: 6611.047, total_weights: 8192, loss: 10.863
completed step: 2, seconds: 0.435, TFLOP/s/device: 217.068, Tokens/s/device: 9420.337, total_weights: 8192, loss: 9.510
completed step: 3, seconds: 0.434, TFLOP/s/device: 217.292, Tokens/s/device: 9430.074, total_weights: 8192, loss: 7.975
completed step: 4, seconds: 0.434, TFLOP/s/device: 217.400, Tokens/s/device: 9434.745, total_weights: 8192, loss: 6.079
completed step: 5, seconds: 0.434, TFLOP/s/device: 217.451, Tokens/s/device: 9436.983, total_weights: 8192, loss: 4.461
completed step: 6, seconds: 0.435, TFLOP/s/device: 217.040, Tokens/s/device: 9419.145, total_weights: 8192, loss: 3.241
completed step: 7, seconds: 0.434, TFLOP/s/device: 217.333, Tokens/s/device: 9431.833, total_weights: 8192, loss: 2.380
completed step: 8, seconds: 0.434, TFLOP/s/device: 217.454, Tokens/s/device: 9437.114, total_weights: 8192, loss: 1.820
completed step: 9, seconds: 0.434, TFLOP/s/device: 217.345, Tokens/s/device: 9432.376, total_weights: 8192, loss: 1.485
Output size: 40430494092, temp size: 27134419432, argument size: 40430575628, host temp size: 0, in bytes.
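The headline numbers can be re-derived from the two logs. A small Python check, assuming the quoted totals are output size plus temp size (the numbers match) and treating step 1 as warm-up when averaging TFLOP/s/device:

```python
# Byte counts copied from the "Output size ..." lines of the two logs.
OUTPUT_BYTES = 40_430_494_092  # identical in both runs
TEMP_BYTES = {1: 40_316_314_088, 0: 27_134_419_432}  # keyed by param_scan_axis

# TFLOP/s/device for steps 2-9 (step 1 is treated as warm-up and skipped).
TFLOPS = {
    1: [209.671, 210.105, 209.736, 209.950, 209.673, 209.882, 209.835, 209.915],
    0: [217.068, 217.292, 217.400, 217.451, 217.040, 217.333, 217.454, 217.345],
}


def gb(n_bytes):
    return n_bytes / 1e9  # decimal gigabytes


total = {axis: gb(OUTPUT_BYTES + TEMP_BYTES[axis]) for axis in (0, 1)}
extra_gb = gb(TEMP_BYTES[1] - TEMP_BYTES[0])
mean_tflops = {axis: sum(v) / len(v) for axis, v in TFLOPS.items()}
slowdown = 1 - mean_tflops[1] / mean_tflops[0]

print(f"total: {total[1]:.1f} GB (axis 1) vs {total[0]:.1f} GB (axis 0)")
print(f"extra temp memory: {extra_gb:.1f} GB")
print(f"throughput drop: {slowdown:.1%}")
```

This gives ~80.7GB vs ~67.6GB in total, ~13.2GB of extra temp memory, and a ~3.4% drop in TFLOP/s/device, consistent with the rounded figures in the description.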

shralex (Collaborator) commented May 1, 2025

@khatwanimohit can we merge your PR and close this?


4 participants