We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8c2e42d commit bfdf9aaCopy full SHA for bfdf9aa
modules.py
@@ -172,7 +172,10 @@ def __init__(self, lmbda):
172
173
def call(self, x, training):
174
"""Computes rate and distortion losses."""
175
- entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior, coding_rank=3, compression=False)
+ entropy_model = tfc.ContinuousBatchedEntropyModel(self.prior,
176
+ coding_rank=3,
177
+ compression=False,
178
+ bottleneck_dtype=tf.keras.mixed_precision.global_policy().variable_dtype)
179
180
y = self.analysis_transform(x)
181
y_hat, bits = entropy_model(y, training=training)
0 commit comments