File tree 1 file changed +10
-9
lines changed
1 file changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -93,18 +93,19 @@ def forward(self, x):
93
93
x = F .softmax (x )
94
94
return x
95
95
96
+ def freeze (self , skip = ConvOffset2D ):
97
+ for k , m in self ._modules .items ():
98
+ if skip is None or not isinstance (m , skip ):
99
+ for param in m .parameters ():
100
+ param .requires_grad = False
101
+
102
+ def parameters (self ):
103
+ return filter (lambda p : p .requires_grad , super (DeformConvNet , self ).parameters ())
96
104
97
105
def get_cnn ():
98
106
return ConvNet ()
99
107
100
-
101
108
def get_deform_cnn (trainable = True ):
102
109
model = DeformConvNet ()
103
- if trainable :
104
- return model
105
- else :
106
- for k , m in model ._modules .items ():
107
- if not isinstance (m , ConvOffset2D ):
108
- for param in m .parameters ():
109
- param .requires_grad = False
110
- return model
110
+ model .freeze (skip = ConvOffset2D )
111
+ return model
You can’t perform that action at this time.
0 commit comments