-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvgg5.py
90 lines (65 loc) · 2.61 KB
/
vgg5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
from keras.layers import (
Conv2D, BatchNormalization, Activation,
MaxPooling2D, Dense, Flatten
)
from model import BaseModel
from utils import load_mnist
def vgg(input_tensor):
    """Build the VGG-style forward graph and return its softmax output.

    Usage: ``y = vgg(X)``

    Four convolutional blocks are applied in sequence (two convolutions in
    each of the first two blocks, three in each of the last two), each
    ending with a 2x2 max-pool. The result is flattened and passed through
    two 512-unit ReLU dense layers and a final 10-unit softmax layer.

    Parameters
    ----------
    input_tensor : keras.layers.Input
        Tensor the first convolution is applied to.

    Returns
    ----------
    y : softmax output tensor
        Output of the dense layer named ``predictions``.
    """

    def conv_stack_then_pool(tensor, filter_counts, name):
        # One 3x3 'same' convolution per entry in filter_counts. The
        # convolution itself has no activation: BatchNormalization is
        # applied first and ReLU afterwards as a separate layer.
        for index, filters in enumerate(filter_counts, 1):
            tensor = Conv2D(
                filters, (3, 3), activation=None, padding='same',
                name='{}_conv{}'.format(name, index),
            )(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        pool = MaxPooling2D((2, 2), strides=(2, 2), name='{}_pool'.format(name))
        return pool(tensor)

    # (block name, filters of each convolution in the block), in graph order.
    block_specs = (
        ("block1", (32, 32)),
        ("block2", (64, 64)),
        ("block3", (128, 128, 128)),
        ("block4", (256, 256, 256)),
    )

    features = input_tensor
    for block_name, filter_counts in block_specs:
        features = conv_stack_then_pool(features, filter_counts, block_name)

    # Classifier head.
    flat = Flatten()(features)
    hidden = Dense(512, activation='relu', name='fc-1')(flat)
    hidden = Dense(512, activation='relu', name='fc-2')(hidden)
    return Dense(10, activation='softmax', name='predictions')(hidden)
class VGGNet5(BaseModel):
    """``BaseModel`` subclass that pairs the name "VGG5" with the ``vgg`` function."""

    def __init__(self, model_path):
        """Initialise the base class with the fixed name, ``vgg`` and ``model_path``.

        Parameters
        ----------
        model_path :
            Passed through to ``BaseModel`` unchanged. Presumably the
            location of the saved model file -- confirm against BaseModel.
        """
        model_name = "VGG5"
        network_fn = vgg
        super(VGGNet5, self).__init__(model_name, network_fn, model_path)
def arg_parser(argv=None):
    """Parse the command-line options of this script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        falls back to ``sys.argv[1:]``, so existing ``arg_parser()`` calls
        behave exactly as before. Passing a list lets the parser be used
        programmatically.

    Returns
    -------
    (int, str)
        The positional ``epoch`` value and the ``--model_path`` value
        (``"model/vggnet5.h5"`` when the option is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("epoch", type=int, help="Epochs")
    parser.add_argument(
        "--model_path",
        default="model/vggnet5.h5",
        type=str,
        help="model path (default: model/vggnet5.h5)",
    )
    args = parser.parse_args(argv)
    return args.epoch, args.model_path
def main():
    """Script entry point: parse arguments, load data and fit a ``VGGNet5``.

    ``load_mnist`` returns three values; the third is discarded and only
    the first two are handed to ``fit`` together with the epoch count.
    """
    epochs, model_path = arg_parser()

    # Each split is indexed as (X, y).
    train_split, valid_split, _ = load_mnist(samplewise_normalize=True)

    model = VGGNet5(model_path)
    train_pair = (train_split[0], train_split[1])
    valid_pair = (valid_split[0], valid_split[1])
    model.fit(train_pair, valid_pair, epochs)


# Run main() only when this file is executed as a script.
if __name__ == '__main__':
    main()