Skip to content

Commit 176afc7

Browse files
committed
1 parent c51da2d commit 176afc7

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

autokeras/utils/utils.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,13 @@ def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
100100
try:
101101
history = func(x=x, validation_data=validation_data, **fit_kwargs)
102102
break
103-
except tf.errors.ResourceExhaustedError as e:
103+
except tf.errors.ResourceExhaustedError:
104104
if batch_size == 1:
105-
raise e
105+
print(
106+
"Not enough memory, reduced batch size is already set to 1. "
107+
"Current model will be skipped."
108+
)
109+
break
106110
batch_size //= 2
107111
print(
108112
"Not enough memory, reduce batch size to {batch_size}.".format(

autokeras/utils/utils_test.py

+14-10
Original file line number | Diff line number | Diff line change
@@ -53,19 +53,23 @@ def test_check_kt_version_error():
5353
)
5454

5555

56-
def test_run_with_adaptive_batch_size_raise_error():
56+
def test_run_with_adaptive_batch_size_raise_error(capfd):
5757
def func(**kwargs):
5858
raise tf.errors.ResourceExhaustedError(0, "", None)
5959

60-
with pytest.raises(tf.errors.ResourceExhaustedError):
61-
utils.run_with_adaptive_batch_size(
62-
batch_size=64,
63-
func=func,
64-
x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
65-
validation_data=tf.data.Dataset.from_tensor_slices(
66-
np.random.rand(100, 1)
67-
).batch(64),
68-
)
60+
utils.run_with_adaptive_batch_size(
61+
batch_size=64,
62+
func=func,
63+
x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
64+
validation_data=tf.data.Dataset.from_tensor_slices(
65+
np.random.rand(100, 1)
66+
).batch(64),
67+
)
68+
_, err = capfd.readouterr()
69+
assert (
70+
err == "Not enough memory, reduced batch size is already set to 1. "
71+
"Current model will be skipped."
72+
)
6973

7074

7175
def test_get_hyperparameter_with_none_return_hp():

0 commit comments

Comments (0)