Skip to content

Commit bfdf9aa

Browse files
Improve Mixed Precision Training Support
1 parent 8c2e42d commit bfdf9aa

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

modules.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def __init__(self, lmbda):
172172

173173
def call(self, x, training):
174174
"""Computes rate and distortion losses."""
175-
entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior, coding_rank=3, compression=False)
175+
entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior,
176+
coding_rank=3,
177+
compression=False,
178+
bottleneck_dtype=tf.keras.mixed_precision.global_policy().variable_dtype)
176179

177180
y = self.analysis_transform(x)
178181
y_hat, bits = entropy_model(y, training=training)

0 commit comments

Comments
 (0)