# -*- encoding: utf-8 -*-
'''
@File : unet.py
@Time : 2020/08/02 10:19:44
@Author : AngYi

@Department : QDKD shuli
@description : FCN semantic segmentation models (FCN32s/FCN16s/FCN8s on VGG16, FCN8x on ResNet-34)
'''

# here put the import lib
import torch
from torch import nn
from torchvision import models
import numpy as np

pretrained_model = models.vgg16(pretrained=True)   # backbone for FCN32s, FCN16s, FCN8s
pretrained_net = models.resnet34(pretrained=True)  # backbone for FCN8x

def bilinear_kernel(in_channels, out_channels, kernel_size):
    '''
    Return a bilinear filter tensor.
    Bilinear kernel used to initialize the transposed (de)convolution layers.
    '''
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
    # fill the "diagonal" filters (in_channels is expected to equal out_channels here)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight)
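
# Quick sanity note (added illustration, not from the original file):
# bilinear_kernel(1, 1, 4)[0, 0] is the outer product of [0.25, 0.75, 0.75, 0.25]
# with itself, i.e. the standard 4x4 bilinear upsampling filter.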

class FCN32s(nn.Module):
    def __init__(self, num_classes):
        super(FCN32s, self).__init__()

        self.feature = pretrained_model.features

        self.conv = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0)
        # five stride-2 transposed convolutions, each doubling the spatial size (32x in total)
        self.upsample32x = nn.Sequential(
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
        )

        # initialize the transposed convolutions with bilinear upsampling weights
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.copy_(bilinear_kernel(int(m.in_channels), int(m.out_channels), m.kernel_size[0]))

    def forward(self, x):
        x = self.feature(x)  # 1/32
        x = self.conv(x)
        x = self.upsample32x(x)
        return x
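
# Output-size note (added for clarity): for ConvTranspose2d with kernel_size=3, stride=2,
# padding=1, output_padding=1, H_out = (H_in - 1)*2 - 2*1 + (3 - 1) + 1 + 1 = 2*H_in,
# so each layer exactly doubles the feature-map resolution.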

class FCN16s(nn.Module):
    def __init__(self, num_classes):
        super(FCN16s, self).__init__()

        # VGG16 features split at the last pooling stage: feature_1 outputs 1/16, feature_2 outputs 1/32
        self.feature_1 = nn.Sequential(*list(pretrained_model.features.children())[:24])
        self.feature_2 = nn.Sequential(*list(pretrained_model.features.children())[24:])

        self.conv_1 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0)
        self.conv_2 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0)

        self.upsample2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1)
        self.upsample16x = nn.Sequential(
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
        )

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.copy_(bilinear_kernel(m.in_channels, m.out_channels, m.kernel_size[0]))

    def forward(self, x):
        x1 = self.feature_1(x)   # 1/16
        x2 = self.feature_2(x1)  # 1/32

        x1 = self.conv_1(x1)
        x2 = self.conv_2(x2)
        x2 = self.upsample2x(x2)  # 1/32 -> 1/16
        x2 += x1                  # fuse the 1/16 skip connection

        x2 = self.upsample16x(x2)
        return x2


class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super(FCN8s, self).__init__()

        # VGG16 features split into three stages: 1/8 (256 ch), 1/16 (512 ch), 1/32 (512 ch)
        self.feature_1 = nn.Sequential(*list(pretrained_model.features.children())[:17])
        self.feature_2 = nn.Sequential(*list(pretrained_model.features.children())[17:24])
        self.feature_3 = nn.Sequential(*list(pretrained_model.features.children())[24:])

        self.conv_1 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0)
        self.conv_2 = nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0)
        self.conv_3 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0)

        self.upsample2x_1 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1)
        self.upsample2x_2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1)
        self.upsample8x = nn.Sequential(
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
            nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1),
        )

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data = bilinear_kernel(m.in_channels, m.out_channels, m.kernel_size[0])

    def forward(self, x):
        x1 = self.feature_1(x)   # 1/8
        x2 = self.feature_2(x1)  # 1/16
        x3 = self.feature_3(x2)  # 1/32

        x2 = self.conv_1(x2)
        x3 = self.conv_3(x3)
        x3 = self.upsample2x_1(x3)  # 1/32 -> 1/16
        x3 += x2                    # fuse the 1/16 skip connection

        x1 = self.conv_2(x1)
        x3 = self.upsample2x_2(x3)  # 1/16 -> 1/8
        x3 += x1                    # fuse the 1/8 skip connection

        x3 = self.upsample8x(x3)
        return x3

class FCN8x(nn.Module):
    def __init__(self, num_classes):
        super(FCN8x, self).__init__()

        self.stage1 = nn.Sequential(*list(pretrained_net.children())[:-4])  # stage 1: up to layer2, 1/8, 128 ch
        self.stage2 = list(pretrained_net.children())[-4]                   # stage 2: layer3, 1/16, 256 ch
        self.stage3 = list(pretrained_net.children())[-3]                   # stage 3: layer4, 1/32, 512 ch

        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(256, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)

        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)
        self.upsample_8x.weight.data = bilinear_kernel(num_classes, num_classes, 16)  # bilinear kernel init

        self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)
        self.upsample_4x.weight.data = bilinear_kernel(num_classes, num_classes, 4)   # bilinear kernel init

        self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)
        self.upsample_2x.weight.data = bilinear_kernel(num_classes, num_classes, 4)   # bilinear kernel init

    def forward(self, x):
        x = self.stage1(x)
        s1 = x  # 1/8

        x = self.stage2(x)
        s2 = x  # 1/16

        x = self.stage3(x)
        s3 = x  # 1/32

        s3 = self.scores1(s3)
        s3 = self.upsample_2x(s3)  # 1/32 -> 1/16
        s2 = self.scores2(s2)
        s2 = s2 + s3

        s1 = self.scores3(s1)
        s2 = self.upsample_4x(s2)  # 1/16 -> 1/8
        s = s1 + s2                # fuse the 1/8 skip connection

        s = self.upsample_8x(s)    # 1/8 -> full resolution
        return s
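
# Shape note (added for clarity): for the final ConvTranspose2d(kernel_size=16, stride=8,
# padding=4), H_out = (H_in - 1)*8 - 2*4 + 16 = 8*H_in, so the 1/8-resolution score map is
# restored to the input resolution; the kernel_size=4, stride=2, padding=1 layers double
# the resolution by the same arithmetic.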


if __name__ == "__main__":
    from torchsummary import summary
    fcn = FCN32s(3)
    # fcn.cuda(1)
    # summary(fcn,(3,128,128))
    # pretrained_model.cuda(1)
    # summary(pretrained_model,(3,128,128))
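
    # A minimal shape-check sketch (added illustration, not from the original file);
    # 21 classes and a 224x224 input are arbitrary choices for this check.
    dummy = torch.randn(1, 3, 224, 224)
    for net_cls in (FCN32s, FCN16s, FCN8s, FCN8x):
        net = net_cls(21).eval()
        with torch.no_grad():
            out = net(dummy)
        print(net_cls.__name__, tuple(out.shape))  # expect (1, 21, 224, 224)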