@@ -14,7 +14,7 @@
 
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
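This hunk adopts the built-in generics standardized in PEP 585, which are subscriptable at runtime on Python 3.9+, so only `Optional` and `Union` still need the `typing` import. A minimal sketch of the equivalence (the function is illustrative, not part of the patch):

```python
from typing import Optional


# On Python 3.9+, tuple/list/dict can be parameterized directly;
# typing.Tuple, typing.List and typing.Dict are redundant aliases.
def first_channel(block_out_channels: tuple[int, ...], default: Optional[int] = None) -> Optional[int]:
    return block_out_channels[0] if block_out_channels else default


print(first_channel((16, 32, 96, 256)))  # 16
```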
@@ -34,7 +34,7 @@
 
 @dataclass
 class ControlNetOutput(BaseOutput):
-    down_block_res_samples: Tuple[torch.Tensor]
+    down_block_res_samples: tuple[torch.Tensor]
     mid_block_res_sample: torch.Tensor
 
 
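One reviewer-side caveat: to a type checker, `tuple[torch.Tensor]`, exactly like the old `Tuple[torch.Tensor]`, denotes a tuple of exactly one tensor; the variable-length form is `tuple[torch.Tensor, ...]`. The patch deliberately preserves the annotation as it was, but the distinction in a sketch:

```python
import torch

t = torch.zeros(1)

one: tuple[torch.Tensor] = (t,)             # fixed length: exactly one tensor
many: tuple[torch.Tensor, ...] = (t, t, t)  # variable length: any number of tensors
```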
@@ -52,7 +52,7 @@ def __init__(
         self,
         conditioning_embedding_channels: int,
         conditioning_channels: int = 3,
-        block_out_channels: Tuple[int] = (16, 32, 96, 256),
+        block_out_channels: tuple[int] = (16, 32, 96, 256),
     ):
         super().__init__()
 
@@ -92,7 +92,7 @@ def __init__(
         in_channels: int = 4,
         out_channels: int = 320,
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[tuple[int]] = (16, 32, 96, 256),
     ):
         super().__init__()
 
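`Optional[tuple[int]]` now mixes the `typing` and built-in styles, which is perfectly legal; if the project's floor were Python 3.10, the remaining `typing` imports could also go via PEP 604 unions. A hypothetical 3.10+ spelling, not part of this patch:

```python
# Python 3.10+ only: PEP 604 union syntax replaces Optional[...] entirely.
def channels_or_empty(out_channels: tuple[int, ...] | None) -> tuple[int, ...]:
    return out_channels if out_channels is not None else ()


print(channels_or_empty((16, 32, 96, 256)))  # (16, 32, 96, 256)
```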
@@ -104,7 +104,7 @@ def __init__(
 
     @property
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+    def attn_processors(self) -> dict[str, AttentionProcessor]:
         r"""
         Returns:
             `dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -113,7 +113,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
         # set recursively
         processors = {}
 
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
             if hasattr(module, "set_processor"):
                 processors[f"{name}.processor"] = module.processor
 
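Beyond the annotation, the pattern here is a depth-first walk over `named_children()` that accumulates into one shared dict keyed by dotted module paths. A self-contained sketch of the same traversal (the function name is illustrative):

```python
import torch


def collect_processors(name: str, module: torch.nn.Module, processors: dict[str, object]) -> dict[str, object]:
    # Modules that expose set_processor get their current processor
    # recorded under a dotted-path key, mirroring the code above.
    if hasattr(module, "set_processor"):
        processors[f"{name}.processor"] = module.processor
    # Depth-first recursion over direct children, extending the path.
    for child_name, child in module.named_children():
        collect_processors(f"{name}.{child_name}", child, processors)
    return processors
```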
@@ -128,7 +128,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, dict[str, AttentionProcessor]]):
         r"""
         Parameters:
             `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
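For context, the `Union` in this signature is what lets callers pass either a single processor (applied to every attention module) or a dict keyed by the paths that `attn_processors` returns. A usage sketch, assuming a loaded `controlnet` model instance; the `AttnProcessor` import path matches recent diffusers releases but should be checked against the installed version:

```python
from diffusers.models.attention_processor import AttnProcessor

# One processor instance applied to every attention module in the model...
controlnet.set_attn_processor(AttnProcessor())

# ...or a per-module dict reusing the keys exposed by attn_processors.
controlnet.set_attn_processor({name: AttnProcessor() for name in controlnet.attn_processors})
```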
@@ -220,7 +220,7 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
         # Recursively walk through all the children.
         # Any children which exposes the set_attention_slice method
         # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
             if hasattr(module, "set_attention_slice"):
                 module.set_attention_slice(slice_size.pop())
 
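The `slice_size.pop()` idiom deserves a note: the caller prepares one slice size per sliceable layer and the recursion consumes the list from its tail, so the list must be built in reverse traversal order. A minimal sketch of the same consume-from-the-end pattern (the function name is illustrative):

```python
import torch


def set_slices(module: torch.nn.Module, slice_sizes: list[int]) -> None:
    # Leaves consume from the end of the list, so the caller must supply
    # slice_sizes in reverse traversal order, one entry per sliceable layer.
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(slice_sizes.pop())
    for child in module.children():
        set_slices(child, slice_sizes)
```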
@@ -238,7 +238,7 @@ def _set_gradient_checkpointing(self, module, value=False):
     def forward(
         self,
         controlnet_cond: torch.FloatTensor,
-    ) -> Union[ControlNetOutput, Tuple]:
+    ) -> Union[ControlNetOutput, tuple]:
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
 
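The `Union[ControlNetOutput, tuple]` return type presumably follows the usual diffusers `return_dict` convention (an assumption here, since the hunk cuts off before the return statements): with `return_dict=False` the residuals come back as a plain tuple, otherwise as the `ControlNetOutput` dataclass. A self-contained sketch of that Union-returning pattern, with placeholder names and computation:

```python
from dataclasses import dataclass

import torch


@dataclass
class Output:  # stand-in for ControlNetOutput
    down_block_res_samples: tuple[torch.Tensor, ...]
    mid_block_res_sample: torch.Tensor


def forward(cond: torch.Tensor, return_dict: bool = True):
    down, mid = (cond * 0.5,), cond * 2.0  # placeholder computation
    if not return_dict:
        return (down, mid)  # the plain-tuple branch of the Union
    return Output(down_block_res_samples=down, mid_block_res_sample=mid)


print(type(forward(torch.ones(1))).__name__)       # Output
print(type(forward(torch.ones(1), False)).__name__)  # tuple
```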