"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn

# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)

class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

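# Illustrative usage of conv_nd (not part of the original module): the
# remaining arguments are forwarded unchanged to the matching nn.Conv*d
# constructor, e.g.
#
#     conv = conv_nd(2, 64, 128, 3, padding=1)
#     # equivalent to nn.Conv2d(64, 128, 3, padding=1)
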
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)

def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)

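# Illustrative usage of update_ema (not part of the original module). Each
# target parameter is updated in place as targ = rate * targ + (1 - rate) * src:
#
#     ema_model = copy.deepcopy(model)  # assumes `model` is an nn.Module and `copy` is imported
#     update_ema(ema_model.parameters(), model.parameters(), rate=0.999)
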
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

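# Note (added): zero_module is commonly used to zero-initialize the final
# layer of a block so the block initially contributes nothing to its output,
# for example:
#
#     out_conv = zero_module(conv_nd(2, 128, 128, 3, padding=1))
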
def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module

def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

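# Illustrative shape example for mean_flat (not part of the original module):
#
#     x = th.randn(8, 3, 32, 32)
#     mean_flat(x).shape  # -> torch.Size([8]); averages over all but dim 0
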
def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

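# Illustrative shape example for timestep_embedding (not part of the original
# module): the first `half` output columns are cosines, the next `half` are
# sines, and a zero column is appended when `dim` is odd.
#
#     t = th.tensor([0.0, 10.0, 250.0])
#     timestep_embedding(t, 128).shape  # -> torch.Size([3, 128])
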
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)

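# Illustrative usage of checkpoint (not part of the original module); `_forward`
# and `use_checkpoint` are hypothetical names for a module's inner forward pass
# and a configuration flag:
#
#     def forward(self, x, emb):
#         return checkpoint(
#             self._forward, (x, emb), self.parameters(), self.use_checkpoint
#         )
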
class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads