Skip to content

Commit feec510

Browse files
YunYang1994YunYang1994
authored andcommitted
I hate tensorflow
1 parent 72a512e commit feec510

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

convert_weight.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
3434
var_name = var.op.name
3535
var_name_mess = str(var_name).split('/')
3636
var_shape = var.shape
37-
if (var_name_mess[-1] not in ['weights', 'gamma', 'beta', 'moving_mean', 'moving_variance']) or \
38-
(var_name_mess[1] == 'yolo-v3' and (var_name_mess[-2] in preserve_org_names)): continue
37+
if flag.train_from_coco:
38+
if (var_name_mess[-1] not in ['weights', 'gamma', 'beta', 'moving_mean', 'moving_variance']) or \
39+
(var_name_mess[1] == 'yolo-v3' and (var_name_mess[-2] in preserve_org_names)): continue
3940
org_weights_mess.append([var_name, var_shape])
4041
print("=> " + str(var_name).ljust(50), var_shape)
4142
print()
@@ -52,7 +53,8 @@
5253
var_name_mess = str(var_name).split('/')
5354
var_shape = var.shape
5455
print(var_name_mess[0])
55-
if var_name_mess[0] in preserve_cur_names: continue
56+
if flag.train_from_coco:
57+
if var_name_mess[0] in preserve_cur_names: continue
5658
cur_weights_mess.append([var_name, var_shape])
5759
print("=> " + str(var_name).ljust(50), var_shape)
5860

0 commit comments

Comments
 (0)