Skip to content

Commit 03b53ec

Browse files
committed
add freeze function
1 parent ad7a0e7 commit 03b53ec

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

Diff for: torch_deform_conv/cnn.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,19 @@ def forward(self, x):
9393
x = F.softmax(x)
9494
return x
9595

96+
def freeze(self, skip=ConvOffset2D):
97+
for k, m in self._modules.items():
98+
if skip is None or not isinstance(m, skip):
99+
for param in m.parameters():
100+
param.requires_grad = False
101+
102+
def parameters(self):
103+
return filter(lambda p: p.requires_grad, super(DeformConvNet, self).parameters())
96104

97105
def get_cnn():
98106
return ConvNet()
99107

100-
101108
def get_deform_cnn(trainable=True):
102109
model = DeformConvNet()
103-
if trainable:
104-
return model
105-
else:
106-
for k, m in model._modules.items():
107-
if not isinstance(m, ConvOffset2D):
108-
for param in m.parameters():
109-
param.requires_grad = False
110-
return model
110+
model.freeze(skip=ConvOffset2D)
111+
return model

0 commit comments

Comments
 (0)