Description
This is more of a question inspired by the Wan2.1-14B training code than a bug report.
Observation
First, thank you for your incredible work! After thoroughly reading the Wan technical report and studying the codebases of Wan 2.1 and DiffSynth, I’ve learned a great deal. I noticed many carefully designed tricks to improve numerical precision in various operations, such as:
- 64-bit RoPE embeddings
- Unfused FP32 normalization (pre-{sa, ca, ffn}-norm and QK norm)
- FP32 timestep embeddings
- FP32 modulation inside every DiT block
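To make the RoPE point concrete, here is a small, self-contained illustration (my own sketch in numpy, not Wan's actual code) of why the angle computation benefits from higher precision: the RoPE angle for position `p` and frequency `f` is `p * f`, and for long token sequences the float32 rounding of that product visibly shifts the resulting sin/cos rotations relative to a float64 reference. The function name, sequence length, and head dimension below are illustrative choices, not values taken from the Wan code.

```python
import numpy as np

def rope_angles(positions, dim, base=10000.0, dtype=np.float64):
    # Standard RoPE frequency schedule: inv_freq[i] = base^(-2i/dim).
    inv_freq = base ** (-np.arange(0, dim, 2, dtype=dtype) / dim)
    # Angle matrix theta[p, i] = position_p * inv_freq_i, computed
    # entirely in `dtype` so rounding happens at the chosen precision.
    return np.outer(positions.astype(dtype), inv_freq)

positions = np.arange(0, 32768)  # long sequence, as in video token grids
hi = rope_angles(positions, 128, dtype=np.float64)
lo = rope_angles(positions, 128, dtype=np.float32).astype(np.float64)

# Compare the rotations rather than the raw angles: at large positions a
# float32 angle has an absolute rounding error of roughly ulp(theta),
# which propagates directly into sin/cos and hence into the rotated Q/K.
err = np.abs(np.sin(hi) - np.sin(lo)).max()
print(f"max |sin(theta_fp64) - sin(theta_fp32)| = {err:.2e}")
```

The drift grows with position because float32 spacing widens as the angle magnitude grows, which is consistent with precision-related artifacts appearing only in longer or later-position tokens.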
Given that your reply in this issue and the README both emphasize Wan2.1-14B's sensitivity to numerical precision—requiring high precision in many intermediate results to avoid artifacts—I suspect these optimizations are directly related. This leads me to a few questions:
- When you mention Wan2.1-14B’s sensitivity to precision, does this apply only during inference, or does the forward pass during training also require upcasting intermediate results?
- If these precision-enhancing tricks are disabled during training, what kinds of artifacts manifest in the trained model?
- How did you determine that these artifacts were caused by precision issues, and how did you pinpoint which operations (RoPE, norm, modulation) were the bottlenecks before applying the precision improvements?
Motivation
This question stems from my own experience. Earlier this year, I was pretraining an ~11B-parameter T2V/I2V model using MCore. At later stages of training, as the output quality became more refined, the generated videos started developing random spatiotemporal black spot-like artifacts, which worsened over time. Despite thorough code review, we couldn't identify the root cause.
Your insights into Wan2.1-14B's precision management could provide valuable clues for diagnosing similar issues in my work. I would greatly appreciate any technical details or suggestions you could share.