@@ -94,11 +94,14 @@ def mog_loss_model(n_components, d_t):
     # Use logsumexp for numeric stability:
     # LL = C - log(sum(exp(-d2/(2*sig^2) + log(pi_i/sig^d))))
     def make_logloss(d2, sig, pi):
+        # values = pi / K.pow(sig, d_t) * K.exp(-d2 / (2 * K.square(sig)))
+        # return -K.log(K.sum(values, axis=-1))
+
         # logsumexp doesn't exist in keras 2.4; simulate it
         values = -d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t))
         # logsumexp(a,b,c) = log(exp(a)+exp(b)+exp(c)) = log((exp(a-k)+exp(b-k)+exp(c-k))*exp(k))
         #                  = log((exp(a-k)+exp(b-k)+exp(c-k))) + k
-        mx = K.max(values, axis=-1)
+        mx = K.stop_gradient(K.max(values, axis=-1))
         return -K.log(K.sum(K.exp(values - L.Reshape((-1, 1))(mx)), axis=-1)) - mx

     ll = L.Lambda(lambda dsp: make_logloss(*dsp), output_shape=(1,))([d2, sig, pi])
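For context, wrapping the max in `K.stop_gradient` is mathematically a no-op here: since `log(sum(exp(v - mx))) + mx` depends on `mx` with derivative `-1 + 1 = 0`, the gradient contributions through the shift cancel analytically, so stopping them just removes a redundant (and occasionally NaN-prone) backward path through `K.max`. A minimal standalone sketch of the same logsumexp trick, assuming a TF-backed Keras where these backend ops exist; the helper name `stable_logsumexp` and the test values are illustrative, not from this change:

```python
import numpy as np
from keras import backend as K

def stable_logsumexp(values):
    # Shift by the row-wise max so exp() cannot overflow, then add the
    # shift back outside the log. Stopping gradients through the max is
    # safe because its gradient contribution cancels analytically.
    mx = K.stop_gradient(K.max(values, axis=-1, keepdims=True))
    return K.log(K.sum(K.exp(values - mx), axis=-1)) + K.squeeze(mx, axis=-1)

v = K.constant(np.array([[1000.0, 1000.0]]))
print(K.eval(K.log(K.sum(K.exp(v), axis=-1))))  # naive version overflows: [inf]
print(K.eval(stable_logsumexp(v)))              # stable version: [~1000.6931]
```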
@@ -350,7 +353,7 @@ def fit(self, Y, T, X, Z, *, inference=None):

         ll = mog_loss_model(n_components, d_t)([pi, mu, sig, t_in])

-        model = Model([z_in, x_in, t_in], [ll])
+        model = Model([z_in, x_in, t_in], [])
         model.add_loss(L.Lambda(K.mean)(ll))
         model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?
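For readers unfamiliar with the idiom being adopted here: moving `ll` out of the output list and attaching it via `add_loss` means the model compiles without a `loss` argument and fits without target arrays. A hedged sketch of the pattern with a toy squared-error loss standing in for the mixture log-likelihood; the shapes and data are illustrative, not from this change:

```python
import numpy as np
from keras import backend as K
from keras import layers as L
from keras.models import Model

z_in = L.Input(shape=(3,))
t_in = L.Input(shape=(1,))
pred = L.Dense(1)(z_in)
# Toy per-sample loss tensor built from the inputs themselves:
loss = L.Lambda(lambda pt: K.square(pt[0] - pt[1]))([pred, t_in])

model = Model([z_in, t_in], [])          # no outputs, as in this change
model.add_loss(L.Lambda(K.mean)(loss))   # the objective comes from add_loss alone
model.compile('adam')                    # hence no loss= argument
# With no outputs there are no targets; pass [] (or None, depending on version):
model.fit([np.random.normal(size=(100, 3)),
           np.random.normal(size=(100, 1))], [], epochs=1)
```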
@@ -365,7 +368,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
                           self._n_samples, self._use_upper_bound_loss, self._n_gradient_samples)

         rl = lm([z_in, x_in, y_in])
-        response_model = Model([z_in, x_in, y_in], [rl])
+        response_model = Model([z_in, x_in, y_in], [])
         response_model.add_loss(L.Lambda(K.mean)(rl))
         response_model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?