@@ -199,24 +199,24 @@ def add_edge(self, start, end, prob):
199199
200200 if start == self .start :
201201 if self .starts is None :
202- self .starts = torch .empty ( n , dtype = self .dtype ,
203- device = self .device ) - inf
202+ self .starts = torch .full (( n ,), NEGINF , dtype = self .dtype ,
203+ device = self .device )
204204
205205 idx = self .distributions .index (end )
206206 self .starts [idx ] = math .log (prob )
207207
208208 elif end == self .end :
209209 if self .ends is None :
210- self .ends = torch .empty ( n , dtype = self .dtype ,
211- device = self .device ) - inf
210+ self .ends = torch .full (( n ,), NEGINF , dtype = self .dtype ,
211+ device = self .device )
212212
213213 idx = self .distributions .index (start )
214214 self .ends [idx ] = math .log (prob )
215215
216216 else :
217217 if self .edges is None :
218- self .edges = torch .empty ((n , n ), dtype = self .dtype ,
219- device = self .device ) - inf
218+ self .edges = torch .full ((n , n ), NEGINF , dtype = self .dtype ,
219+ device = self .device )
220220
221221 idx1 = self .distributions .index (start )
222222 idx2 = self .distributions .index (end )
@@ -250,8 +250,8 @@ def sample(self, n):
250250 + "end probabilities." )
251251
252252 if self .ends is None :
253- ends = torch .zeros ( self .n_distributions , dtype = self .edges .dtype ,
254- device = self .edges .device ) + float ( "-inf" )
253+ ends = torch .full (( self .n_distributions ,), NEGINF , dtype = self .edges .dtype ,
254+ device = self .edges .device )
255255 else :
256256 ends = self .ends
257257
@@ -454,8 +454,8 @@ def backward(self, X=None, emissions=None, priors=None):
454454 emissions = _check_inputs (self , X , emissions , priors )
455455 n , l , _ = emissions .shape
456456
457- b = torch .zeros ( l , n , self .n_distributions , dtype = self .dtype ,
458- device = self .device ) + float ( "-inf" )
457+ b = torch .full (( l , n , self .n_distributions ), NEGINF , dtype = self .dtype ,
458+ device = self .device )
459459 b [- 1 ] = self .ends
460460
461461 t_max = self .edges .max ()
0 commit comments