-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_utils.py
50 lines (35 loc) · 1.04 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
'''
Created on November 26, 2017
@author: optas
'''
import tensorflow as tf
import numpy as np
def expand_scope_by_name(scope, name):
    """Append ``name`` as a child component of a TF scope.

    Args:
        scope: ``None``, a scope path given as a ``str``, or an object exposing
            a ``.name`` attribute (such as a ``tf.VariableScope``).
        name: component to append after a ``'/'`` separator.

    Returns:
        ``None`` when ``scope`` is ``None``; otherwise the string
        ``<scope path> + '/' + name``.
    """
    if scope is None:
        return None
    # Strings are used as-is; anything else is read through its ``.name``.
    prefix = scope if isinstance(scope, str) else scope.name
    return prefix + '/' + name
def replicate_parameter_for_all_layers(parameter, n_layers):
    """Expand a per-layer parameter to exactly ``n_layers`` entries.

    Args:
        parameter: ``None``, a sequence with one entry per layer, or a
            one-element sequence whose single value is used for every layer.
        n_layers: number of layers the parameter must cover.

    Returns:
        ``None`` if ``parameter`` is ``None``; ``parameter`` itself (unchanged)
        if it already has ``n_layers`` entries; otherwise a new list holding
        the single value ``n_layers`` times. The same object is repeated, so
        a mutable value (e.g. a nested list) is shared between the layers.

    Raises:
        ValueError: if ``parameter`` has neither 1 nor ``n_layers`` entries.
    """
    if parameter is None or len(parameter) == n_layers:
        return parameter
    if len(parameter) != 1:
        raise ValueError(
            'Expected a parameter with 1 or %d entries, got %d.'
            % (n_layers, len(parameter)))
    (value,) = parameter
    # Repeat the element itself rather than np.array(...).repeat(...), which
    # flattens nested values such as [[3, 3]] into [3, 3, 3, 3, ...].
    if isinstance(value, np.generic):
        # Keep the previous tolist() behaviour of yielding native Python scalars.
        value = value.item()
    return [value] * n_layers
def reset_tf_graph(sess=None):
    ''' Reset the default TensorFlow graph. Useful in Jupyter.

    Args:
        sess: optional session to close before the reset. Pass a notebook's
            session explicitly: ``globals()`` inside this function is this
            module's namespace, so a ``sess`` variable defined in the
            notebook is never seen by the module-level lookup below.
    '''
    if sess is not None:
        sess.close()
    # Backward compatibility: the original only consulted a module-level
    # ``sess`` (i.e. one assigned on this module); keep honoring it, but do
    # not close the same object twice.
    module_sess = globals().get('sess')
    if module_sess and module_sess is not sess:
        module_sess.close()
    tf.reset_default_graph()
def leaky_relu(alpha):
    """Build a leaky-ReLU activation with negative-side slope ``alpha``.

    Args:
        alpha: slope applied to negative inputs; must lie strictly in (0, 1).

    Returns:
        A callable computing ``tf.maximum(alpha * x, x)``. For 0 < alpha < 1
        this equals ``x`` for x >= 0 and ``alpha * x`` for x < 0. Outside that
        range the max-trick no longer yields a leaky ReLU, hence the check.

    Raises:
        ValueError: if ``alpha`` is not strictly between 0 and 1.
    """
    if not 0 < alpha < 1:
        raise ValueError(
            'leaky_relu alpha must be strictly between 0 and 1, got %r.'
            % (alpha,))
    return lambda x: tf.maximum(alpha * x, x)
def safe_log(x, eps=1e-12):
    """Natural log of ``x`` with its input floored at ``eps``.

    Flooring keeps zeros (and anything below ``eps``) from producing
    ``-inf``/``NaN`` in the graph.

    Args:
        x: tensor (or tensor-convertible value) to take the log of.
        eps: element-wise lower bound applied with ``tf.maximum`` before the log.

    Returns:
        ``tf.log`` of the floored input.
    """
    floored = tf.maximum(x, eps)
    return tf.log(floored)