-
Notifications
You must be signed in to change notification settings - Fork 27
/
matrix_functions_types.py
107 lines (70 loc) · 3.59 KB
/
matrix_functions_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from dataclasses import dataclass
from distributed_shampoo.shampoo_types import AbstractDataclass
@dataclass
class PreconditionerComputationConfig(AbstractDataclass):
    """Common base for every preconditioner-computation configuration used by Shampoo.

    Derives from AbstractDataclass and declares no fields of its own; concrete
    configuration dataclasses subclass it (directly or through an intermediate base).
    """
@dataclass
class RootInvConfig(PreconditionerComputationConfig):
    """Shared base for configurations of Shampoo's matrix root-inverse methods.

    Adds no fields; each concrete root-inverse method subclasses this and
    declares its own options.
    """
@dataclass(kw_only=True)
class EigenConfig(RootInvConfig):
    """Configuration for computing the matrix root inverse via eigendecomposition in Shampoo.

    Args:
        make_positive_semidefinite (bool): Perturb the matrix eigenvalues so that the matrix is numerically
            positive semi-definite. (Default: True)
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a cuSOLVER failure. (Default: True)
        exponent_multiplier (float): Number multiplied into the numerator of the inverse-root exponent, i.e., eta
            where the exponent is -eta / (2 * p). (Default: 1.0)
    """

    make_positive_semidefinite: bool = True
    retry_double_precision: bool = True
    exponent_multiplier: float = 1.0
# Module-level default instance built with all-default settings. EigenConfig is not
# declared frozen, so treat this shared object as read-only and construct a new
# EigenConfig(...) when different settings are needed.
DefaultEigenConfig: EigenConfig = EigenConfig()
@dataclass(kw_only=True)
class CoupledNewtonConfig(RootInvConfig):
    """Settings for the coupled Newton iteration used to compute Shampoo's root inverse.

    Args:
        max_iterations (int): Upper bound on the number of coupled Newton iterations. (Default: 100)
        tolerance (float): Tolerance used when computing the root inverse with the coupled Newton
            iteration. (Default: 1e-6)
    """

    # Field order matters: it fixes the generated __init__/__repr__/__eq__ ordering.
    max_iterations: int = 100
    tolerance: float = 1e-6
@dataclass(kw_only=True)
class CoupledHigherOrderConfig(RootInvConfig):
    """Configuration for the coupled higher-order method in Shampoo.

    Args:
        rel_epsilon (float): Relative epsilon for the coupled higher-order method. Adds epsilon * lambda_max * I
            to the matrix before taking the matrix root, where lambda_max is an upper bound on the maximum
            eigenvalue. (Default: 0.0)
        max_iterations (int): Maximum number of iterations for the coupled higher-order method. (Default: 100)
        tolerance (float): Tolerance for computing the root inverse using the coupled higher-order method.
            (Default: 1e-8)
        order (int): Order of the method; must be >= 2. Higher-order methods accelerate convergence (fewer
            iterations) but can take more matmuls per iteration. order=2 represents Newton's method. (Default: 3)
        disable_tf32 (bool): Whether to disable tf32 matmuls internally. Keeping this True is highly
            recommended, since tf32 is numerically challenging here. (Default: True)

    Raises:
        ValueError: If ``order`` is less than 2.
    """

    rel_epsilon: float = 0.0
    max_iterations: int = 100
    tolerance: float = 1e-8
    order: int = 3
    disable_tf32: bool = True

    def __post_init__(self) -> None:
        # Enforce the documented "order >= 2" constraint when the config is built, so an
        # invalid value is rejected up front instead of being handed to whatever consumes it.
        if self.order < 2:
            raise ValueError(f"Invalid order {self.order}! Order must be >= 2.")
@dataclass
class EigenvalueCorrectionConfig(PreconditionerComputationConfig):
    """Shared base for configurations of the matrix eigenvector methods used by eigenvalue-corrected Shampoo.

    Declares no fields of its own; concrete eigenvector methods subclass it and add their options.
    """
@dataclass(kw_only=True)
class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig):
    """Configuration for the eigendecomposition method used in eigenvalue-corrected Shampoo.

    Args:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a cuSOLVER failure. (Default: True)
    """

    retry_double_precision: bool = True
# Module-level default instance built with all-default settings. The dataclass is not
# declared frozen, so treat this shared object as read-only and construct a new
# EighEigenvalueCorrectionConfig(...) when different settings are needed.
DefaultEighEigenvalueCorrectionConfig: EighEigenvalueCorrectionConfig = EighEigenvalueCorrectionConfig()