Skip to content

Commit f9461b5

Browse files
committed
Avoid operations on uninitialised memory in HMMs
1 parent 56ab929 commit f9461b5

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

pomegranate/hmm/dense_hmm.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,24 +199,24 @@ def add_edge(self, start, end, prob):
199199

200200
if start == self.start:
201201
if self.starts is None:
202-
self.starts = torch.empty(n, dtype=self.dtype,
203-
device=self.device) - inf
202+
self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
203+
device=self.device)
204204

205205
idx = self.distributions.index(end)
206206
self.starts[idx] = math.log(prob)
207207

208208
elif end == self.end:
209209
if self.ends is None:
210-
self.ends = torch.empty(n, dtype=self.dtype,
211-
device=self.device) - inf
210+
self.ends = torch.full((n,), NEGINF, dtype=self.dtype,
211+
device=self.device)
212212

213213
idx = self.distributions.index(start)
214214
self.ends[idx] = math.log(prob)
215215

216216
else:
217217
if self.edges is None:
218-
self.edges = torch.empty((n, n), dtype=self.dtype,
219-
device=self.device) - inf
218+
self.edges = torch.full((n, n), NEGINF, dtype=self.dtype,
219+
device=self.device)
220220

221221
idx1 = self.distributions.index(start)
222222
idx2 = self.distributions.index(end)
@@ -250,8 +250,8 @@ def sample(self, n):
250250
+ "end probabilities.")
251251

252252
if self.ends is None:
253-
ends = torch.zeros(self.n_distributions, dtype=self.edges.dtype,
254-
device=self.edges.device) + float("-inf")
253+
ends = torch.full((self.n_distributions,), NEGINF, dtype=self.edges.dtype,
254+
device=self.edges.device)
255255
else:
256256
ends = self.ends
257257

@@ -454,8 +454,8 @@ def backward(self, X=None, emissions=None, priors=None):
454454
emissions = _check_inputs(self, X, emissions, priors)
455455
n, l, _ = emissions.shape
456456

457-
b = torch.zeros(l, n, self.n_distributions, dtype=self.dtype,
458-
device=self.device) + float("-inf")
457+
b = torch.full((l, n, self.n_distributions), NEGINF, dtype=self.dtype,
458+
device=self.device)
459459
b[-1] = self.ends
460460

461461
t_max = self.edges.max()

pomegranate/hmm/sparse_hmm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def unpack_edges(self, edges, starts, ends):
6060
self.starts = _cast_as_parameter(torch.log(starts))
6161

6262
if ends is None:
63-
self.ends = torch.empty(n, dtype=self.dtype, device=self.device) - inf
63+
self.ends = torch.full((n,), NEGINF, dtype=self.dtype, device=self.device)
6464
else:
6565
ends = _check_parameter(_cast_as_tensor(ends), "ends", ndim=1,
6666
shape=(n,), min_value=0., max_value=1.)
@@ -93,8 +93,8 @@ def unpack_edges(self, edges, starts, ends):
9393

9494
if ni is self.start:
9595
if self.starts is None:
96-
self.starts = torch.zeros(n, dtype=self.dtype,
97-
device=self.device) - inf
96+
self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
97+
device=self.device)
9898

9999
j = self.distributions.index(nj)
100100
self.starts[j] = math.log(probability)
@@ -302,9 +302,9 @@ def sample(self, n):
302302
+ "end probabilities.")
303303

304304
if self.ends is None:
305-
ends = torch.zeros(self.n_distributions,
305+
ends = torch.full((self.n_distributions,), NEGINF,
306306
dtype=self._edge_log_probs.dtype,
307-
device=self._edge_log_probs.device) + float("-inf")
307+
device=self._edge_log_probs.device)
308308
else:
309309
ends = self.ends
310310

0 commit comments

Comments
 (0)