Description
FID estimation uses the Newton-Schulz iterative method to compute the square root of a matrix, but the iteration stops too early.
To determine when to stop computing sqrt(M), the algorithm computes the ratio of the norms of the square of the current approximation and M and stops when it drops below 1e-5.
However, this is not the quantity that FID uses: what matters is the trace of the resulting square root.
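For reference, here is a minimal NumPy sketch of the iteration and the stopping rule described above. It is not the library's code: the function name, the `history` list, and the exact error formula (relative Frobenius norm of `M - S @ S`) are assumptions made for illustration.

```python
import numpy as np


def sqrtm_newton_schulz(matrix, num_iters=100, atol=1e-5):
    """Approximate the square root of a symmetric positive definite matrix.

    Returns the approximation and a per-iteration history of
    (iteration, error, trace) tuples.
    """
    norm = np.linalg.norm(matrix)  # Frobenius norm of M
    identity = np.eye(matrix.shape[0])
    y = matrix / norm  # scale so that the iteration converges
    z = identity.copy()
    s_matrix = y * np.sqrt(norm)
    history = []
    for i in range(num_iters):
        t = 0.5 * (3.0 * identity - z @ y)
        y = y @ t
        z = t @ z
        s_matrix = y * np.sqrt(norm)  # undo the scaling
        # Assumed error measure: relative residual of S @ S against M.
        error = np.linalg.norm(matrix - s_matrix @ s_matrix) / norm
        history.append((i, error, np.trace(s_matrix)))
        if error < atol:  # residual-only stopping rule
            break
    return s_matrix, history
```

Logging `history` this way gives one row of error and trace per iteration, which makes it easy to see whether the trace has settled by the time the residual threshold is reached.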
For instance, using a real dataset, I modified the code to run until the error above drops below 1e-7 instead. Here are the error and the trace of s_matrix at each iteration:
| iteration | error | trace |
|---|---|---|
| 0 | 0.3495 | 495.6085 |
| 1 | 0.2281 | 650.9324 |
| 2 | 0.1414 | 825.6144 |
| 3 | 0.0814 | 1010.0005 |
| 4 | 0.0444 | 1190.7549 |
| 5 | 0.0230 | 1355.3980 |
| 6 | 0.0115 | 1496.0165 |
| 7 | 0.0057 | 1610.0487 |
| 8 | 0.0027 | 1698.8550 |
| 9 | 0.0013 | 1764.7837 |
| 10 | 0.0006 | 1811.5501 |
| 11 | 0.0003 | 1843.3400 |
| 12 | 0.0001 | 1864.3024 |
| 13 | 5.0639e-05 | 1877.7849 |
| 14 | 2.2064e-05 | 1886.2064 |
| 15 | 9.5923e-06 | 1891.3386 |
| 16 | 4.1472e-06 | 1894.3981 |
| 17 | 1.7905e-06 | 1896.1724 |
| 18 | 7.7794e-07 | 1897.1686 |
| 19 | 3.4087e-07 | 1897.7101 |
| 20 | 1.4866e-07 | 1897.9915 |
| 21 | 6.5038e-08 | 1898.1279 |
With the 1e-5 threshold, the loop stops at iteration 15 and uses 1891.34 as the trace, instead of the actual value, which should be 1898.22 in this case (computed with scipy's linalg.sqrtm). This results in a FID of 90.74 instead of 76.98 (a significant difference!).
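To see why the trace error matters so much, recall the standard FID definition (the symbols $\mu_i$ and $\Sigma_i$ for the feature means and covariances are notation introduced here, and I am assuming s_matrix approximates $(\Sigma_1 \Sigma_2)^{1/2}$):

$$
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2) - 2\,\operatorname{Tr}\!\left((\Sigma_1 \Sigma_2)^{1/2}\right)
$$

Underestimating the last trace by $\delta$ therefore inflates FID by $2\delta$. With the numbers above, $2 \times (1898.22 - 1891.34) = 13.76$, which matches the observed gap $90.74 - 76.98 = 13.76$.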
This issue was introduced in this commit (@denproc).
A potential fix could be to use the absolute difference of the trace between two iterations as an additional stopping criterion, although there could still theoretically be cases where this would not be enough. I can submit a PR with that change if you want.
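A rough sketch of what that extra criterion could look like, written in NumPy rather than against the library's code. The function name, the `trace_atol` parameter and its default, and the error formula are all hypothetical choices for illustration.

```python
import numpy as np


def sqrtm_newton_schulz_trace_stop(matrix, num_iters=100, atol=1e-5,
                                   trace_atol=1e-4):
    """Newton-Schulz square root that also waits for the trace to settle.

    Stops only when the relative residual is below `atol` AND the trace
    changed by less than `trace_atol` since the previous iteration.
    Returns the approximation and the number of iterations run.
    """
    norm = np.linalg.norm(matrix)  # Frobenius norm of M
    identity = np.eye(matrix.shape[0])
    y = matrix / norm
    z = identity.copy()
    s_matrix = y * np.sqrt(norm)
    prev_trace = None
    iters_run = 0
    for iters_run in range(1, num_iters + 1):
        t = 0.5 * (3.0 * identity - z @ y)
        y = y @ t
        z = t @ z
        s_matrix = y * np.sqrt(norm)
        # Assumed error measure: relative residual of S @ S against M.
        error = np.linalg.norm(matrix - s_matrix @ s_matrix) / norm
        trace = np.trace(s_matrix)
        # Hypothetical extra criterion: absolute change of the trace.
        trace_settled = (prev_trace is not None
                         and abs(trace - prev_trace) < trace_atol)
        if error < atol and trace_settled:
            break
        prev_trace = trace
    return s_matrix, iters_run
```

As noted, a small change between two consecutive iterations does not bound the remaining error when convergence is slow, so `trace_atol` would need to be chosen conservatively relative to the precision expected of the final FID.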