11import torch
22from torch import nn , Tensor
3- from jaxtyping import Float , Int64
3+ from jaxtyping import Float , Int64 , Int
44from typing import Literal
55
class DiffusionModel(nn.Module):
    """DDPM training wrapper: pairs a noise-prediction backbone with a
    forward-diffusion process and returns (noise_pred, noise) pairs for
    the denoising regression loss.
    """

    def __init__(
        self,
        backbone: nn.Module,
        timesteps: int,
        t_start: float = 0.0001,
        t_end: float = 0.02,
        schedule_type: Literal["linear", "cosine"] = "linear",
        time_dim: int = 256,
    ) -> None:
        """Build the wrapper.

        Parameters
        ----------
        backbone
            model predicting noise from (noisy image, time encoding)
        timesteps
            number of diffusion steps in the noise schedule
        t_start, t_end
            schedule endpoints forwarded to ForwardDiffusion
        schedule_type
            which beta schedule to use
        time_dim
            width of the sinusoidal time encoding fed to the backbone
        """
        super().__init__()
        self.model = backbone
        # Fix: forward() reads self.time_dim, which was never initialised.
        self.time_dim = time_dim
        self.fwd_diff = ForwardDiffusion(timesteps, t_start, t_end, schedule_type)

    def forward(self, x: Tensor) -> tuple:
        # One integer timestep per batch element, on the input's device.
        t = self._sample_timestep(x.shape[0]).to(x.device)
        # Fix: diffuse with the raw integer timesteps (they index the noise
        # schedule); the sinusoidal encoding only conditions the backbone.
        x_t, noise = self.fwd_diff(x, t)
        t_enc = self._pos_encoding(t.unsqueeze(-1).type(torch.float), self.time_dim)
        noise_pred = self.model(x_t, t_enc)
        return noise_pred, noise

    def _pos_encoding(self, t: Tensor, channels: int) -> Tensor:
        """Sinusoidal encoding of shape (batch, channels) for t of shape (batch, 1)."""
        # Fix: derive the device from the input instead of undefined self.device.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        return torch.cat([pos_enc_a, pos_enc_b], dim=-1)

    def _sample_timestep(self, batch_size: int) -> Tensor:
        """Uniform integer timesteps in [1, fwd_diff.timesteps), shape (batch,)."""
        # Fix: ForwardDiffusion exposes `timesteps`, not `noise_steps`.
        return torch.randint(low=1, high=self.fwd_diff.timesteps, size=(batch_size,))
37-
386class ForwardDiffusion (nn .Module ):
397 """Class for forward diffusion process in DDPMs (denoising diffusion probabilistic models).
408
@@ -82,7 +50,11 @@ def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: L
8250
8351 self .register_buffer ("noise_normal" , torch .empty ((1 )), persistent = False )
8452
85- def forward (self , x_0 : Float [Tensor , "batch channels height width" ], t : int ) -> Float [Tensor , "batch channels height width" ]:
53+ def forward (
54+ self ,
55+ x_0 : Float [Tensor , "batch channels height width" ],
56+ t : Int [Tensor , "batch" ]
57+ ) -> Float [Tensor , "batch channels height width" ]:
8658 """Forward method of ForwardDiffusion class.
8759
8860 Parameters
@@ -98,15 +70,65 @@ def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) ->
9870 tensor with applied noise according to schedule and chosen timestep
9971 """
10072 self .noise_normal = torch .randn_like (x_0 )
101- if t > self .timesteps - 1 :
73+ if True in torch . gt ( t , self .timesteps - 1 ) :
10274 raise IndexError ("t ({}) chosen larger than max. available t ({})" .format (t , self .timesteps - 1 ))
10375 sqrt_alpha_dash_t = self .sqrt_alphas_dash [t ]
10476 sqrt_one_minus_alpha_dash_t = self .sqrt_one_minus_alpha_dash [t ]
105- x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * self .noise_normal
106- return x_t
77+ x_t = sqrt_alpha_dash_t .view (- 1 , 1 , 1 , 1 ) * x_0
78+ x_t += sqrt_one_minus_alpha_dash_t .view (- 1 , 1 , 1 , 1 ) * self .noise_normal
79+ return x_t , self .noise_normal
10780
10881 def _linear_scheduler (self , timesteps , start , end ):
10982 return torch .linspace (start , end , timesteps )
11083
11184 def _cosine_scheduler (self , timesteps , start , end ):
112- raise NotImplementedError ("Cosine scheduler not implemented yet." )
85+ raise NotImplementedError ("Cosine scheduler not implemented yet." )
86+
class DiffusionModel(nn.Module):
    """Training-time wrapper around a denoising backbone.

    Each forward pass draws random diffusion timesteps, noises the input
    batch through `fwd_diff`, and asks the backbone to predict the injected
    noise, conditioned on a sinusoidal time encoding.
    """

    def __init__(
        self,
        backbone: nn.Module,
        fwd_diff: ForwardDiffusion,
        time_enc_dim: int = 256
    ) -> None:
        super().__init__()
        self.model = backbone
        self.fwd_diff = fwd_diff
        self.time_enc_dim = time_enc_dim

        # Non-persistent buffers: keep the most recently drawn timesteps and
        # encodings on the module (and its device) without entering state_dict.
        self.register_buffer("timesteps", torch.empty((1)), persistent=False)
        self.register_buffer("time_enc", torch.empty((1)), persistent=False)

    def forward(self, x):
        # Draw one diffusion step per sample, then its sinusoidal encoding.
        self.timesteps = self._sample_timesteps(x.shape[0])
        self.time_enc = self._time_encoding(self.timesteps, self.time_enc_dim)
        # Noise the clean batch at the drawn steps ...
        x_t, noise = self.fwd_diff(x, self.timesteps)
        # ... and let the backbone predict that noise, conditioned on time.
        noise_pred = self.model(x_t, self.time_enc)
        return noise_pred, noise

    def sample(self, n):
        """Sample a batch of images (reverse process) — not implemented yet."""
        pass

    def _time_encoding(
        self,
        t: Int[Tensor, "batch"],
        channels: int
    ) -> Float[Tensor, "batch time_enc_dim"]:
        """Sinusoidal positional encoding of integer timesteps."""
        t = t.unsqueeze(-1).type(torch.float)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        inv_freq = inv_freq.to(t.device)
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def _sample_timesteps(self, batch_size: int) -> Int64[Tensor, "batch"]:
        """Uniform integer timesteps in [1, fwd_diff.timesteps)."""
        return torch.randint(1, self.fwd_diff.timesteps, (batch_size,))
0 commit comments