Skip to content

Commit f685be0

Browse files
committed
fixed converting boolean, keras iterator, pytorch gpu training
1 parent 5ddfa02 commit f685be0

File tree

6 files changed

+12
-6
lines changed

6 files changed

+12
-6
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ zip:
33
zip deepkit.zip deepkit/*.py deepkit/utils/*.py README.md setup.cfg setup.py
44

55
publish:
6-
rm -r dist/*
6+
rm -rf dist/*
77
python3 setup.py sdist bdist_wheel
88
python3 -m twine upload --repository-url https://upload.pypi.org/legacy/ dist/*

deepkit/experiment.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,11 @@ def floatconfig(self, path, default=None):
424424

425425
def boolconfig(self, path, default=None):
426426
v = self.get_config(path, default)
427-
return bool(v) if v is not None else default
427+
if v is None:
428+
return default
429+
if not v or v is 'false' or v is 0 or v is '0':
430+
return False
431+
return True
428432

429433
def config(self, path, default=None):
430434
v = self.get_config(path, default)
@@ -449,7 +453,7 @@ def watch_keras_model(self, model, model_input=None, name=None, is_batch=True):
449453

450454
def fit_generator(generator, *args, **kwargs):
451455
if debugger.model_input is None:
452-
debugger.set_input(generator)
456+
debugger.set_input(next(iter(generator)))
453457
return ori_fit_generator(generator, *args, **kwargs)
454458

455459
model.fit_generator = fit_generator

deepkit/keras_tf.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import PIL.Image
88
import numpy as np
9+
910
if 'keras' in sys.modules:
1011
import keras
1112
else:

deepkit/pytorch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def hook(module, input, output):
419419
module.register_forward_hook(hook)
420420

421421
def get_histogram(self, x, tensor):
422-
h = np.histogram(tensor.detach().numpy(), bins=20)
422+
h = np.histogram(tensor.cpu().detach().numpy(), bins=20)
423423
# <version><x><bins><...x><...y>, little endian
424424
# uint8|Uint32|Uint16|...Float32|...Uint32
425425
# B|L|H|...f|...L
@@ -436,7 +436,7 @@ def get_debug_data(self, x, module, output):
436436

437437
if len(output.shape) > 1:
438438
# outputs come in batch usually, so pick first
439-
sample = output[0].detach().numpy()
439+
sample = output[0].cpu().detach().numpy()
440440
if len(sample.shape) == 3:
441441
if sample.shape[0] == 3:
442442
image = PIL.Image.fromarray(get_layer_vis_square(sample))

examples/keras-cifar10/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
model.fit_generator(datagen.flow(x_train, y_train,
158158
batch_size=batch_size),
159159
epochs=epochs,
160+
steps_per_epoch=len(x_train)/batch_size,
160161
verbose=0,
161162
callbacks=callbacks,
162163
validation_data=(x_test, y_test),

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22
from setuptools import find_packages
33

4-
__version__ = '1.0.0'
4+
__version__ = '1.0.1'
55

66
setup(name='deepkit',
77
version=__version__,

0 commit comments

Comments
 (0)