Lowering of normalization calls #3706
pratnali-aws asked this question in Q&A
Hello all,
This is a newbie question:
I have the following simple model that uses BatchNormalization:
and it is lowered to
I see that `nn.Conv` gets lowered to a single op, `convolution.23 = f32[1,64,64,32]{3,2,1,0} convolution(Arg_6.20, Arg_5.19), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f`, but `nn.BatchNorm` is lowered into a sequence of ops. My impression is that this is a direct translation of the implementation here: https://flax.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#BatchNorm.
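For reference, here is a minimal sketch of what training-mode batch normalization computes, written in plain `jax.numpy`. It is not Flax's code. The function name `batchnorm_reference`, the `eps=1e-5` default, and the NHWC layout (features last, matching `b01f` in the HLO above) are assumptions for illustration. Each step (mean, variance, normalize, scale, shift) is expressed with generic reductions and elementwise ops, so a sequence of such ops in the lowered module is what one would expect from code written this way.

```python
import jax
import jax.numpy as jnp

def batchnorm_reference(x, scale, bias, eps=1e-5):
    """Hypothetical training-mode batch norm for NHWC input (features last)."""
    # Per-feature statistics over the batch and spatial axes (N, H, W).
    axes = (0, 1, 2)
    mean = jnp.mean(x, axis=axes)
    var = jnp.var(x, axis=axes)  # biased (population) variance
    # Normalize, then apply the per-feature scale and shift.
    y = (x - mean) * jax.lax.rsqrt(var + eps)
    return y * scale + bias, mean, var

x = jax.random.normal(jax.random.PRNGKey(0), (1, 64, 64, 32))
scale = jnp.ones((32,))
bias = jnp.zeros((32,))
y, mean, var = batchnorm_reference(x, scale, bias)

# Lower the function and print the module (StableHLO text by default in
# recent JAX versions): the statistics show up as reductions followed by
# elementwise ops rather than as one batch-norm instruction.
hlo_text = jax.jit(batchnorm_reference).lower(x, scale, bias).as_text()
print(hlo_text)
```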
Is there some way to avoid this, especially since HLO has a native `batch_norm_training` op?
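One way to see which ops a piece of JAX code produces is to lower it with `jax.jit(...).lower(...)` and search the resulting text. The sketch below does this for a 3x3 convolution written with `jax.lax.conv_general_dilated`, using NHWC/HWIO layouts and padding of 1 to mirror `dim_labels=b01f_01io->b01f` and `pad=1_1x1_1` from the HLO above. The function name `conv3x3` and the 3 input channels are assumptions, since the original model is not shown. It only shows how to inspect the output; it does not attempt to emit `batch_norm_training`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def conv3x3(x, kernel):
    # NHWC input, HWIO kernel, NHWC output: the layout HLO prints as b01f_01io->b01f.
    return lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding=((1, 1), (1, 1)),  # corresponds to pad=1_1x1_1
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

x = jnp.ones((1, 64, 64, 3))       # assumed: 3 input channels
kernel = jnp.ones((3, 3, 3, 32))   # 3x3 window, 32 output features
out = conv3x3(x, kernel)

# Lower and search the text for the ops of interest.
ir_text = jax.jit(conv3x3).lower(x, kernel).as_text()
conv_lines = [line for line in ir_text.splitlines() if "convolution" in line]
print(len(conv_lines), "line(s) mention convolution")
print("batch_norm_training present:", "batch_norm_training" in ir_text)
```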