[BUG] Improvement: nan while training DenseHMM #1119

@MBradbury

Describe the bug
I am attempting to train a DenseHMM (code below) on 1000 sequences, each 61 observations long, where every observation is a single categorical emission. Training does not succeed: every iteration reports an improvement of nan. I presume that I am using the API incorrectly, so some direction would be appreciated.

Based on the documentation, I believe I should pass Categorical(n_categories=[n_categories]), but Categorical._initialize requires n_categories to be an int, so I use Categorical(n_categories=n_categories) instead.

Output from the script:

torch.Size([1000, 61, 1])
[1] Improvement: nan, Time: 0.6118s
[2] Improvement: nan, Time: 0.6179s
[3] Improvement: nan, Time: 0.614s
[4] Improvement: nan, Time: 0.6247s
[5] Improvement: nan, Time: 0.621s
[6] Improvement: nan, Time: 0.6254s
[7] Improvement: nan, Time: 0.612s
[8] Improvement: nan, Time: 0.6098s
[9] Improvement: nan, Time: 0.6171s
[10] Improvement: nan, Time: 1.235s
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

Version information

$ python
Python 3.12.3 (main, Sep 11 2024, 14:17:37) [GCC 13.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pomegranate
>>> pomegranate.__version__
'1.1.1'

To Reproduce

from pomegranate.hmm import DenseHMM
from pomegranate.distributions import Categorical

import torch

n_hidden_states = 123
n_categories = 61

distributions = [
    Categorical(n_categories=n_categories)
    for _ in range(n_hidden_states)
]

model = DenseHMM(
    distributions,
    max_iter=10,
    verbose=True
)

xs = torch.randint(low=0, high=n_categories, size=(1000, 61, 1))

print(xs)
print(xs.shape)

assert xs.shape[0] == 1000  # number of sequences
assert xs.shape[1] == 61  # observations per sequence; happens to equal n_categories here
assert xs.shape[2] == 1  # one categorical feature per observation

model.fit(xs)

print(model.predict(xs))
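
A variant that might help narrow this down (a sketch under stated assumptions, not a confirmed fix): instead of passing only n_categories and leaving the emission probabilities to be initialized during fit, each Categorical is given explicit starting probabilities. The assumptions are that Categorical accepts a probs argument holding one row of probabilities per feature, and that distributions created this way skip the automatic initialization. The helper random_probs is hypothetical and exists only for this illustration. The pomegranate import is guarded so the probability rows can be inspected even where the library is not installed.

```python
import random

n_hidden_states = 123  # same values as the reproduction script
n_categories = 61


def random_probs(n_categories, rng):
    """Return one random probability row of length n_categories that sums to 1."""
    # Adding 1.0 keeps every weight strictly positive, so no category starts at zero.
    weights = [rng.random() + 1.0 for _ in range(n_categories)]
    total = sum(weights)
    return [w / total for w in weights]


rng = random.Random(0)  # fixed seed so the starting point is reproducible

# One row per hidden state. Each state gets a different row so that the
# states are not identical at the start of training.
all_probs = [random_probs(n_categories, rng) for _ in range(n_hidden_states)]

try:
    from pomegranate.distributions import Categorical
    from pomegranate.hmm import DenseHMM
except ImportError:
    print("pomegranate is not installed; only the probability rows were built")
else:
    # Assumption: `probs` takes one row per feature, so a single categorical
    # feature is passed as a list containing exactly one row.
    distributions = [Categorical(probs=[row]) for row in all_probs]
    model = DenseHMM(distributions, max_iter=10, verbose=True)
```

If the model builds, model.fit(xs) and model.predict(xs) can be called exactly as in the script above. If the improvement is still nan with explicit starting probabilities, the automatic initialization is probably not the cause.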
