Skip to content

Commit 1c0d1dc

Browse files
authored
Merge pull request #9 from adujardin/fix_cudnn_error
Fix possible CUDNN_STATUS_INTERNAL_ERROR on RTX cards
2 parents d0488dd + e816702 commit 1c0d1dc

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

wrappers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def __init__(self, model_path):
2323
self.graph = tf.Graph()
2424
with self.graph.as_default():
2525
with tf.variable_scope('prior_based_hand'):
26-
self.sess = tf.Session()
26+
config = tf.ConfigProto()
27+
config.gpu_options.allow_growth = True
28+
self.sess = tf.Session(config=config)
2729
self.input_ph = tf.placeholder(tf.uint8, [128, 128, 3])
2830
self.feed_img = \
2931
tf.cast(tf.expand_dims(self.input_ph, 0), tf.float32) / 255
@@ -93,7 +95,9 @@ def __init__(self, input_size, network_fn, model_path, net_depth, net_width):
9395
with tf.name_scope('network'):
9496
self.theta = \
9597
network_fn(self.input_ph, net_depth, net_width, training=False)[0]
96-
self.sess = tf.Session()
98+
config = tf.ConfigProto()
99+
config.gpu_options.allow_growth = True
100+
self.sess = tf.Session(config=config)
97101
tf.train.Saver().restore(self.sess, model_path)
98102

99103
def process(self, joints):

0 commit comments

Comments
 (0)