@@ -231,7 +231,10 @@ def __init__(
         dropout: float = 0.5,
         activation: nn.Module = nn.SiLU,
         verbose: bool = False,
-        init_channels: int = 64
+        init_channels: int = 64,
+        attention: bool = True,
+        attention_heads: int = 4,
+        attention_ff_dim: int = None
     ) -> None:
         """Constructor of UNet.

@@ -251,6 +254,14 @@ def __init__(
             activation function to be used
         verbose
             verbose printing of tensor shapes for debugging
+        init_channels
+            number of channels the input is initially transformed to (usually 64, 128, ...)
+        attention
+            whether to use self-attention layers
+        attention_heads
+            number of attention heads to be used
+        attention_ff_dim
+            hidden dimension of the feedforward layer in the self-attention module; None defaults to the input dimension
         """
         super().__init__()
         self.num_layers = num_encoding_blocks
@@ -263,6 +274,9 @@ def __init__(
         self.activation = activation
         self.verbose = verbose
         self.init_channels = init_channels
+        self.attention = attention
+        self.attention_heads = attention_heads
+        self.attention_ff_dim = attention_ff_dim

         self.encoding_channels, self.decoding_channels = self._get_channel_lists(init_channels, num_encoding_blocks)

@@ -273,7 +287,10 @@ def __init__(
             nn.Dropout(self.dropout)
         )

-        self.encoder = nn.ModuleList([EncodingBlock(self.encoding_channels[i], self.encoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
+        if attention:
+            self.encoder = nn.ModuleList([AttentionEncodingBlock(self.encoding_channels[i], self.encoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose, attention_heads, attention_ff_dim) for i in range(len(self.encoding_channels[:-1]))])
+        else:
+            self.encoder = nn.ModuleList([EncodingBlock(self.encoding_channels[i], self.encoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])

         self.bottleneck = nn.Sequential(
             nn.Conv2d(self.encoding_channels[-1], self.encoding_channels[-1] * 2, kernel_size=self.kernel_size, padding="same"),
@@ -286,8 +303,11 @@ def __init__(
             nn.Dropout(self.dropout)
         )

-        self.decoder = nn.ModuleList([DecodingBlock(self.decoding_channels[i], self.decoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
-
+        if attention:
+            self.decoder = nn.ModuleList([AttentionDecodingBlock(self.decoding_channels[i], self.decoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose, attention_heads, attention_ff_dim) for i in range(len(self.encoding_channels[:-1]))])
+        else:
+            self.decoder = nn.ModuleList([DecodingBlock(self.decoding_channels[i], self.decoding_channels[i + 1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
+
         self.out_conv = nn.Conv2d(init_channels, in_channels, kernel_size=kernel_size, padding="same")

     def _get_channel_lists(self, start_channels, num_layers):
@@ -367,4 +387,107 @@ def _check_sizes(self, x):
         heights = [(elem.is_integer() and (elem % 2 == 0)) for elem in heights]
         if (False in widths) or (False in heights):
             return False
-        return True
+        return True
+
+class SelfAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        dropout: float,
+        dim_feedforward: int = None,
+        activation: nn.Module = nn.SiLU
+    ) -> None:
+        """Constructor of SelfAttention module.
+
+        Implementation of a self-attention layer for image data.
+
+        Parameters
+        ----------
+        channels
+            number of input channels
+        num_heads
+            number of desired attention heads
+        dropout
+            dropout probability value
+        dim_feedforward
+            dimension of the hidden layer in the feedforward NN, defaults to the number of input channels
+        activation
+            activation function to be used, as uninstantiated nn.Module
+        """
+        super().__init__()
+        self.channels = channels
+        self.num_heads = num_heads
+        self.dropout = dropout
+        if dim_feedforward is not None:
+            self.dim_feedforward = dim_feedforward
+        else:
+            self.dim_feedforward = channels
+        self.activation = activation()
+        self.attention_layer = nn.TransformerEncoderLayer(
+            channels,
+            num_heads,
+            self.dim_feedforward,
+            dropout,
+            self.activation,
+            batch_first=True
+        )
+
+    def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
+        """Forward method of SelfAttention module.
+
+        Parameters
+        ----------
+        x
+            input tensor
+
+        Returns
+        -------
+        out
+            output tensor
+        """
+        # flatten the spatial grid into a token sequence and put the feature dimension (channels) last
+        orig_size = x.size()
+        x = x.view(-1, x.shape[1], x.shape[2] * x.shape[3]).swapaxes(1, 2)
+        x = self.attention_layer(x)
+        return x.swapaxes(1, 2).view(*orig_size)
+
+class AttentionEncodingBlock(EncodingBlock):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        time_embedding_size: int,
+        kernel_size: int = 3,
+        dropout: float = 0.5,
+        activation: nn.Module = nn.SiLU,
+        verbose: bool = False,
+        attention_heads: int = 4,
+        attention_ff_dim: int = None
+    ) -> None:
+        super().__init__(in_channels, out_channels, time_embedding_size, kernel_size, dropout, activation, verbose)
+        self.sa = SelfAttention(out_channels, attention_heads, dropout, attention_ff_dim, activation)
+
+    def forward(self, x: Tensor, time_embedding: Tensor) -> Tuple[Tensor, Tensor]:
+        out, skip = super().forward(x, time_embedding)
+        return self.sa(out), skip
+
+class AttentionDecodingBlock(DecodingBlock):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        time_embedding_size: int,
+        kernel_size: int = 3,
+        dropout: float = 0.5,
+        activation: nn.Module = nn.SiLU,
+        verbose: bool = False,
+        attention_heads: int = 4,
+        attention_ff_dim: int = None
+    ) -> None:
+        super().__init__(in_channels, out_channels, time_embedding_size, kernel_size, dropout, activation, verbose)
+        self.sa = SelfAttention(out_channels, attention_heads, dropout, attention_ff_dim, activation)
+
+    def forward(self, x: Tensor, skip: Tensor, time_embedding: Tensor = None) -> Tensor:
+        out = super().forward(x, skip, time_embedding)
+        return self.sa(out)
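
Note on the reshaping in SelfAttention.forward: each spatial position is treated as a token, so a (batch, channels, height, width) feature map is flattened to a (batch, height*width, channels) sequence for nn.TransformerEncoderLayer (hence batch_first=True) and folded back afterwards. Below is a minimal sketch of that round-trip with a standalone layer; the tensor shapes and hyperparameters are illustrative assumptions, not values from this PR.

import torch
from torch import nn

# dummy feature map: batch=2, channels=64, 16x16 spatial grid (illustrative only)
x = torch.randn(2, 64, 16, 16)

# same layer configuration SelfAttention builds internally:
# d_model=channels, nhead=num_heads, dim_feedforward defaulting to channels
layer = nn.TransformerEncoderLayer(64, 4, 64, dropout=0.1, activation=nn.SiLU(), batch_first=True)

orig_size = x.size()
tokens = x.view(-1, x.shape[1], x.shape[2] * x.shape[3]).swapaxes(1, 2)  # (2, 256, 64): one token per pixel
out = layer(tokens).swapaxes(1, 2).view(*orig_size)                      # back to (2, 64, 16, 16)
assert out.shape == x.shape

Since nn.MultiheadAttention splits the embedding across heads, channels must be divisible by attention_heads (64 and 4 above).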
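
At the model level, the new flags are plumbed through the UNet constructor. A hedged construction sketch follows; the other arguments (in_channels, time_emb_size, num_encoding_blocks, kernel_size) only appear as context in this diff, so the keyword usage and the values chosen below are assumptions, not the project's defaults.

# hypothetical instantiation; values are illustrative
model = UNet(
    in_channels=3,
    time_emb_size=256,
    num_encoding_blocks=4,
    init_channels=64,
    attention=True,         # use AttentionEncodingBlock / AttentionDecodingBlock
    attention_heads=4,
    attention_ff_dim=None,  # feedforward width falls back to each block's channel count
)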