Skip to content

Commit

Permalink
fix 'Unsupported' object has no attribute 'info'
Browse files Browse the repository at this point in the history
  • Loading branch information
sshane committed May 23, 2020
1 parent accbac4 commit 2d028c8
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 98 deletions.
34 changes: 20 additions & 14 deletions examples/build_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,32 @@
from tensorflow.keras.optimizers import Adam
from utils.BASEDIR import BASEDIR


def one_hot(idx):
x = [0 for _ in range(3)]
x[idx] = 1
return x


samples = 10000
x_train = (np.random.rand(samples, 1) * 10)
# y_train = x_train.take(axis=1, indices=1) * 2
y_train = ((x_train * 1.5) + 2.5) / 2
x_train = (np.random.rand(samples, 3) * 10)
y_train = np.array([one_hot(np.argmax(sample)) for sample in x_train])

model = Sequential()
model.add(Dense(256, activation='relu', input_shape=x_train.shape[1:]))
model.add(Dense(32, activation='relu', input_shape=x_train.shape[1:]))
model.add(Dropout(0.2))
model.add(BatchNormalization())
model.add(Dense(128, activation='relu'))
# model.add(BatchNormalization())
model.add(Dense(64, activation='relu'))
# model.add(BatchNormalization())
model.add(Dense(1, activation='linear'))

model.compile(optimizer=Adam(lr=0.001, amsgrad=True), loss='mse')
model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=1, validation_split=0.2)
model.add(Dense(16, activation='relu'))
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(Dense(3, activation='softmax'))

model.compile(optimizer=Adam(lr=0.003, amsgrad=True), loss='categorical_crossentropy')
model.fit(x_train, y_train, batch_size=64, epochs=20, verbose=1, validation_split=0.2)

model.save('{}/examples/batch_norm.h5'.format(BASEDIR))
print(model.predict([[4.5]]))
print(model.predict([[4.5, 4.5, 9]]).tolist())
print('Saved!')
print(model.layers[0].get_weights()[0].shape)
print(model.layers[1].get_weights()[0].shape)
# exit()
2 changes: 1 addition & 1 deletion examples/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from utils.BASEDIR import BASEDIR

model = load_model('{}/examples/batch_norm.h5'.format(BASEDIR))
print(model.predict([[0.5]]))
print(model.predict([[[4.5, 4.5]]]).tolist())


# exit()
10 changes: 5 additions & 5 deletions konverter/utils/konverter_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,26 @@ def attr_map(self, classes, attr):
def get_model_info(self, model):
name = getattr(model, '_keras_api_names_v1')[0]
model_class = self.get_class_from_name(name, 'models')
model_class.info = BaseModelInfo()
if not model_class:
model_class = Models.Unsupported()
model_class.name = name
return model_class
else:
model_class.info.supported = True
model_class.info.input_shape = model.input_shape

model_class.info = BaseModelInfo()
model_class.info.input_shape = model.input_shape
model_class.info.supported = True
return model_class

def get_layer_info(self, layer):
name = getattr(layer, '_keras_api_names_v1')
if not len(name):
name = getattr(layer, '_keras_api_names')
layer_class = self.get_class_from_name(name[0], 'layers') # assume only one name
layer_class.info = BaseLayerInfo()
if not layer_class:
layer_class = Layers.Unsupported() # add activation below to raise exception with
layer_class.name = name

layer_class.info = BaseLayerInfo()
layer_class.info.is_ignored = layer_class.name in self.ignored_layers

is_linear = False
Expand Down
4 changes: 2 additions & 2 deletions misc/old/konverter/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

os.chdir(BASEDIR)

model = load_model('examples/batch_norm.h5')
model = load_model('examples/latest_maybe_good.h5')
kon = Konverter()
kon.konvert(model, 'examples/batch_norm.py', 2, verbose=True)
kon.konvert(model, 'examples/latest_maybe_good.py', 2, verbose=True)
Loading

0 comments on commit 2d028c8

Please sign in to comment.