Skip to content

Commit cbb02be

Browse files
committed
Tweaking Newton Schulz
1 parent ca00d5f commit cbb02be

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/rules.jl

+13-13
Original file line numberDiff line numberDiff line change
@@ -654,22 +654,22 @@ function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T
654654
end
655655
end
656656

657+
function _inner_newton_schulz5(X::AbstractMatrix{T}) where T  # applies five fixed-coefficient polynomial update steps to X and returns the result
658+
a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0))  # coefficients converted to the element type T; NOTE(review): their derivation is not shown here
659+
for _ in 1:5  # fixed iteration count; the loop index is unused
660+
A = X * X'  # X times its adjoint: a square matrix with size(X, 1) rows and columns
661+
B = b * A + c * A * A  # b*A + c*A^2
662+
X = a * X + B * X  # same as (a*I + b*A + c*A^2) * X, i.e. an odd degree-5 polynomial in X
663+
end
664+
X  # the last expression is the return value; presumably an approximately orthogonalized X - TODO confirm
665+
end
657666
function _newton_schulz5(G::AbstractMatrix{T}) where T  # matrix method: normalizes G, then delegates the iteration to _inner_newton_schulz5
658-
a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0))
659667
X = G / (norm(G) + eps(T))  # scale G by its norm (presumably LinearAlgebra.norm); adding eps(T) keeps the divisor nonzero for an all-zero G
660-
transposed = size(G, 1) > size(G, 2)
661-
if transposed
662-
X = X'
663-
end
664-
for _ in 1:5
665-
A = X * X'
666-
B = b * A + c * A * A
667-
X = a * X + B * X
668-
end
669-
if transposed
670-
X = X'
668+
if size(G, 1) > size(G, 2)  # more rows than columns
669+
transpose(_inner_newton_schulz5(transpose(X)))  # iterate on the transpose so the helper's X * X' is size(G, 2)-square, then transpose back
670+
else
671+
_inner_newton_schulz5(X)  # at least as many columns as rows: iterate on X directly
671672
end
672-
X
673673
end
674674
_newton_schulz5(G::AbstractArray) = reshape(_newton_schulz5(reshape(G, size(G,1), :)), size(G))  # general arrays: collapse all trailing dims into columns, run the matrix method, restore the original shape
675675

0 commit comments

Comments
 (0)