Question
Why are the networks initialized with JAX's default initializer?
Would He initialization or orthogonal initialization be a more reasonable choice?
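
To make the question concrete, here is a minimal sketch of the alternatives I have in mind. It uses only `jax.nn.initializers`; the layer shape is hypothetical, and I am assuming the "default" in question is LeCun normal (as in Flax's `nn.Dense`), which may not match what this repo actually does.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)
shape = (256, 256)  # hypothetical (fan_in, fan_out) of one dense layer

# Assumed current default: LeCun normal, variance = 1 / fan_in.
w_lecun = initializers.lecun_normal()(key, shape, jnp.float32)

# He (Kaiming) normal, variance = 2 / fan_in, usually paired with ReLU.
w_he = initializers.he_normal()(key, shape, jnp.float32)

# Orthogonal init with gain sqrt(2), a common choice in RL codebases.
w_orth = initializers.orthogonal(scale=2.0 ** 0.5)(key, shape, jnp.float32)

print("lecun std:", float(jnp.std(w_lecun)))
print("he std:   ", float(jnp.std(w_he)))
print("orth std: ", float(jnp.std(w_orth)))
```

If the layers are built with Flax, any of these could presumably be passed through the `kernel_init` argument of `nn.Dense`, but I have not checked how the networks here are constructed.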
Checklist
- [x] I have read the documentation (required)
- [x] I have checked that there is no similar issue in the repo (required)