
NaN during training #5

@chrhck

Description


This gives NaN after a few epochs:

pdf = jammy_flows.pdf("e1+s2", "ggg+v", conditional_input_dim=4, hidden_mlp_dims_sub_pdfs="64-128-64")
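
For reference, the full training cell that triggers the failure (per the traceback below) looks roughly like the following sketch. The pdf construction and the weighted negative-log-likelihood step are taken from the report; the device handling, optimizer, synthetic batch, and the coordinate convention for the s2 part are assumptions added only to make the sketch self-contained.

import math
import torch
import jammy_flows

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pdf construction exactly as above: 1D Euclidean + 2-sphere target, 4D conditional input
pdf = jammy_flows.pdf("e1+s2", "ggg+v",
                      conditional_input_dim=4,
                      hidden_mlp_dims_sub_pdfs="64-128-64").to(device)
optimizer = torch.optim.Adam(pdf.parameters(), lr=1e-3)

# synthetic stand-in for the real dataset; only the shapes matter here
data = torch.rand(200, 4)                 # column 3 holds the per-event weight, as in the cell
theta = torch.rand(200, 1) * math.pi      # polar angle for the s2 part (convention assumed)
phi = torch.rand(200, 1) * 2.0 * math.pi  # azimuth for the s2 part (convention assumed)
inp = torch.cat([torch.randn(200, 1), theta, phi], dim=1).to(device)
labels = torch.randn(200, 4)

for epoch in range(100):
    w = (data[:, 3] * data.shape[0] / data[:, 3].sum()).to(device)  # per-event weights, as in the cell
    labels = labels.to(device)

    log_pdf, _, _ = pdf(inp, conditional_input=labels)
    neg_log_loss = (-log_pdf * w).mean()

    optimizer.zero_grad()
    neg_log_loss.backward()
    optimizer.step()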
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [20], line 20
     17 w = data[:, 3] *data.shape[0]/ sum(data[:, 3])
     18 labels = labels.to(device)
---> 20 log_pdf, _, _ = pdf(inp, conditional_input=labels) 
     21 neg_log_loss = (-log_pdf * w).mean()
     22 neg_log_loss.backward()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/jammy_flows/flows.py:975, in pdf.forward(self, x, conditional_input, amortization_parameters, force_embedding_coordinates, force_intrinsic_coordinates)
    968 tot_log_det = torch.zeros(x.shape[0]).type_as(x)
    970 base_pos, tot_log_det=self.all_layer_inverse(x, tot_log_det, conditional_input, amortization_parameters=amortization_parameters, force_embedding_coordinates=force_embedding_coordinates, force_intrinsic_coordinates=force_intrinsic_coordinates)
    972 log_pdf = torch.distributions.MultivariateNormal(
    973     torch.zeros_like(base_pos).to(x),
    974     covariance_matrix=torch.eye(self.total_base_dim).type_as(x).to(x),
--> 975 ).log_prob(base_pos)
    978 return log_pdf + tot_log_det, log_pdf, base_pos

File ~/.local/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:210, in MultivariateNormal.log_prob(self, value)
    208 def log_prob(self, value):
    209     if self._validate_args:
--> 210         self._validate_sample(value)
    211     diff = value - self.loc
    212     M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)

File ~/.local/lib/python3.10/site-packages/torch/distributions/distribution.py:293, in Distribution._validate_sample(self, value)
    291 valid = support.check(value)
    292 if not valid.all():
--> 293     raise ValueError(
    294         "Expected value argument "
    295         f"({type(value).__name__} of shape {tuple(value.shape)}) "
    296         f"to be within the support ({repr(support)}) "
    297         f"of the distribution {repr(self)}, "
    298         f"but found invalid values:\n{value}"
    299     )

ValueError: Expected value argument (Tensor of shape (200, 3)) to be within the support (IndependentConstraint(Real(), 1)) of the distribution MultivariateNormal(loc: torch.Size([200, 3]), covariance_matrix: torch.Size([200, 3, 3])), but found invalid values:
tensor([[    nan,  0.1067, -2.2454],
        [    nan, -0.4479, -1.3993],
        [    nan,  1.1414, -0.2839],
        [    nan,  0.2720, -0.9769],
        [    nan,  0.4975,  0.5888],
        [    nan,  0.3729,  0.7307],
        [    nan, -0.5783, -0.6921],
        [    nan, -0.0498,  1.1616],
        [    nan,  1.1821, -1.6822],
        [    nan,  1.7657,  1.9744],
        [    nan, -1.0785,  1.1321],
....
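
The NaNs appear only in the first (Euclidean) component of the base position, which suggests the inverse pass of the "e1" sub-flow is the one producing them; a common cause for this pattern is flow parameters that already became non-finite after an earlier bad gradient step. Below is a sketch of how that step can be localized with standard PyTorch tooling, reusing pdf, inp, labels, w, and optimizer from the sketch above; it is a debugging aid under those assumptions, not a fix from the library.

import torch

# Make autograd report the first backward op that produces NaN/Inf gradients (slow; debug only).
torch.autograd.set_detect_anomaly(True)

for epoch in range(100):
    # the batch itself should be finite before it ever reaches the flow
    assert torch.isfinite(inp).all() and torch.isfinite(labels).all()

    log_pdf, _, base_pos = pdf(inp, conditional_input=labels)
    neg_log_loss = (-log_pdf * w).mean()

    optimizer.zero_grad()
    neg_log_loss.backward()

    # common mitigation while the root cause is tracked down: keep one bad batch
    # from pushing the flow parameters to extreme values
    torch.nn.utils.clip_grad_norm_(pdf.parameters(), max_norm=1.0)
    optimizer.step()

    # flag the exact epoch at which any parameter first becomes non-finite
    if not all(torch.isfinite(p).all() for p in pdf.parameters()):
        raise RuntimeError(f"non-finite flow parameters after epoch {epoch}")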
