
Commit 5f27288

added repaint unet implementation
1 parent 9f93d91 commit 5f27288

14 files changed: +1745 −42 lines changed

diffusion_models/models/diffusion.py (+4)

@@ -193,7 +193,11 @@ def forward(
         tuple of noise predictions and noise for random timesteps in the denoising process
         """
         timesteps = self._sample_timesteps(x.shape[0], device=x.device)
+        if timesteps.dim() != 1:
+            raise ValueError("Timesteps should only have batch dimension.", timesteps.shape)
         time_enc = self.time_encoder.get_pos_encoding(timesteps)
+        if time_enc.dim() != 2:
+            raise ValueError("Time Encoding should be 2 dimensional.", time_enc.shape)
         # make (partially) noisy versions of batch, returns noisy version + applied noise
         x_t, noise = self.fwd_diff(x, timesteps)
         # predict the applied noise from the noisy version

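The two new checks formalize a shape contract for the denoising forward pass: timesteps must be a 1-D [batch] tensor and the time encoding must be 2-D [batch, features]. The fwd_diff call that follows corresponds to the forward-diffusion (noising) step; a minimal sketch of that step in the standard DDPM closed form (not the repo's exact fwd_diff implementation; the beta schedule below is an assumption for illustration) looks like this:

import torch

# Standard DDPM forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
# The linear beta schedule here is hypothetical, chosen only for the example.
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

def fwd_diff_sketch(x0, t):
    noise = torch.randn_like(x0)
    ab = alpha_bar[t].view(-1, 1, 1, 1)               # broadcast over channel/spatial dims
    x_t = ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise
    return x_t, noise                                 # noisy batch plus the noise that was applied

x0 = torch.randn(8, 3, 32, 32)                        # dummy image batch
t = torch.randint(0, 1000, (8,))                      # one timestep per batch element
x_t, noise = fwd_diff_sketch(x0, t)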
diffusion_models/models/positional_encoding.py (+2)

@@ -40,6 +40,8 @@ def get_pos_encoding(self, t: Int64[Tensor, "batch"]) -> Float[Tensor, "batch fe
         out
             positional encodings for batch
         """
+        if (t.dim() != 1) or (t.shape[0] == 1):
+            raise ValueError("Timesteps not the right size.", t.shape)
         x = self.pe[t]
         return x.squeeze()

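The guard rejects a batch of size 1 as well as non-1-D inputs because of the squeeze() on the return path: for a single timestep, squeeze() also drops the batch dimension, and the result is no longer the 2-D [batch, features] tensor that the new check in diffusion.py expects. A minimal illustration (the table size and feature width below are made up):

import torch

pe = torch.randn(1000, 256)        # hypothetical precomputed encoding table, standing in for self.pe
t = torch.tensor([17])             # a batch containing a single timestep
print(pe[t].shape)                 # torch.Size([1, 256])
print(pe[t].squeeze().shape)       # torch.Size([256]) -- the batch dimension is gone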
New file (+245)

@@ -0,0 +1,245 @@
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license

"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            zero_master_grads(self.master_params)
            return False

        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)

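The new file ends with MixedPrecisionTrainer, which wraps the zero_grad / backward / optimize cycle and, when use_fp16 is on, keeps fp32 master copies of the fp16 model parameters and applies dynamic loss scaling (lg_loss_scale drops by 1 on overflow and grows by fp16_scale_growth after each successful step). A minimal usage sketch in the plain fp32 path (the model and data below are stand-ins; with use_fp16=True the model would additionally need to expose convert_to_fp16()):

import torch as th
import torch.nn as nn

# Assumes MixedPrecisionTrainer from the new helpers file above is importable.
model = nn.Linear(16, 1)
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

x, y = th.randn(8, 16), th.randn(8, 1)
trainer.zero_grad()
loss = (model(x) - y).pow(2).mean()
trainer.backward(loss)     # plain loss.backward() here; scaled by 2 ** lg_loss_scale when use_fp16=True
trainer.optimize(opt)      # computes grad/param norms, then opt.step()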
0 commit comments
