Confused about dtype and precision #3987
davidshen84 asked this question in Q&A
Hi,
I am a bit confused about the `dtype`, `param_dtype` and `precision` parameters in some of the `flax.linen` modules.

According to the documentation for `Conv`, it has these parameters to control precision:

- `dtype`: can be inferred from the input
- `param_dtype`: defaults to `float32`
- `precision`: defaults to `None`; I guess it resolves to `DEFAULT`? https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision

If I want to use `bfloat16` for my model, which parameter should I use?

Also, I found that the dtype of the value returned by `nn.Conv.apply` is controlled not by `dtype` or `precision` but by `param_dtype`. For example, if I create a simple 2-layer conv net and do not set any of these parameters, the first `Conv` layer's precision can be controlled by the input dtype, but the second layer's precision is determined by the output dtype of the first layer, which is always `float32`.

Should I explicitly set `param_dtype` on every layer?

Is there a way to control the precision globally? I guess it would cause trouble for some layers, like `BatchNorm`, which prefer higher precision.

Do we have official guidelines on controlling model precision and utilising hardware features?
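To make the three knobs concrete, here is a minimal pure-JAX sketch of how a convolution layer could combine them. It is modelled on, not copied from, the behaviour the Flax docs describe: parameters live in `param_dtype`, inputs and parameters are cast to `dtype` before the convolution (or promoted to a common dtype when `dtype` is `None`), and `precision` is passed straight through to `jax.lax`. `conv_forward` is a hypothetical helper, not a Flax API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv_forward(x, kernel, dtype=None, precision=None):
    """Hypothetical helper (not a Flax API) combining the three knobs.

    x:         NHWC input
    kernel:    HWIO kernel, already stored in its param_dtype
    dtype:     compute dtype; None means "promote x and kernel to a common dtype"
    precision: forwarded to lax; None leaves the choice to JAX's default
    """
    if dtype is None:
        # bfloat16 input + float32 kernel promotes to float32
        dtype = jnp.result_type(x, kernel)
    x = x.astype(dtype)
    kernel = kernel.astype(dtype)
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=precision,
    )


key = jax.random.PRNGKey(0)
x = jnp.ones((1, 8, 8, 3), jnp.bfloat16)                 # bfloat16 activations
k32 = jax.random.normal(key, (3, 3, 3, 4), jnp.float32)  # param_dtype=float32

print(conv_forward(x, k32).dtype)                      # float32
print(conv_forward(x, k32, dtype=jnp.bfloat16).dtype)  # bfloat16
```

In this sketch `dtype` decides the compute and output dtype, while `precision` only selects how the backend evaluates the convolution and never changes the dtype.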
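On the "global" question, JAX itself has a `jax.default_matmul_precision` context manager (backed by the `jax_default_matmul_precision` config option). As far as I understand, it sets the precision used by matmuls and convolutions whose `precision` argument is left as `None`, such as Flax layers at their default. It does not change any dtype, so it is not a substitute for `dtype`/`param_dtype`, and on CPU it typically makes no numerical difference. A minimal sketch:

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 4), jnp.float32)
b = jnp.ones((4, 4), jnp.float32)

# Inside this block, matmuls/convs that leave precision=None
# are computed at "float32" (i.e. HIGHEST) precision.
with jax.default_matmul_precision("float32"):
    c = jnp.dot(a, b)

print(c.dtype)  # float32: the context manager does not change dtypes
```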
Thanks!