The default setting of `param_scan_axis=1` causes Flax to transpose the model parameters for the scan while also keeping the untransposed copy for the optimizer state update.
Compared to `param_scan_axis=0` on Llama2-7b, this costs ~13GB of extra memory (~81GB vs ~68GB) and ~3% in performance.
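To illustrate the mechanism described above, here is a minimal sketch in plain JAX. It is not the actual Flax or model code: the toy layer, the sizes, and the names `run_layers`, `w_axis0`, `w_axis1`, and `w_for_scan` are hypothetical. It only shows that `jax.lax.scan` iterates over the leading axis, so weights stacked along axis 1 have to be moved to axis 0 before the scan, which produces a second array of the same size next to the stored one.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes, not Llama2-7b's real dimensions.
NUM_LAYERS, D_MODEL = 4, 8


def layer(x, w):
    # One toy "layer": a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)


def run_layers(x, stacked_w):
    # jax.lax.scan iterates over the leading axis of `stacked_w`.
    def step(carry, w):
        return layer(carry, w), None

    out, _ = jax.lax.scan(step, x, stacked_w)
    return out


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, D_MODEL))

# Layer axis first (analogous to param_scan_axis=0): scan consumes it directly.
w_axis0 = jax.random.normal(key_w, (NUM_LAYERS, D_MODEL, D_MODEL))
out0 = run_layers(x, w_axis0)

# Layer axis second (analogous to param_scan_axis=1): the stored layout is
# (D_MODEL, NUM_LAYERS, D_MODEL), so the layer axis must be moved to the front.
w_axis1 = jnp.moveaxis(w_axis0, 0, 1)     # layout kept for the optimizer update
w_for_scan = jnp.moveaxis(w_axis1, 1, 0)  # second array, same number of elements
out1 = run_layers(x, w_for_scan)
```

Both layouts give the same result; the difference is only the extra transposed array. Whether the compiler actually materializes that copy in a real training step depends on XLA, so this sketch shows where the overhead can come from rather than reproducing the numbers reported here.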
Here is a log from a run with the default `param_scan_axis=1`:
Here is a log with `param_scan_axis=0`: