Neg partial log likelihood loss is 0 every time the batch size is 1 #35

Closed
mayurmallya opened this issue May 23, 2024 · 5 comments
Labels
good first issue Good for newcomers question Further information is requested


@mayurmallya

Hi there,

Thanks for sharing this wonderful library!

I was trying to run a survival analysis using the Cox proportional hazards model, and due to GPU constraints I have to use a batch size of 1. Every time I run the model, the loss value is 0 when I use cox.neg_partial_log_likelihood.

I looked into the implementation of _partial_likelihood_cox, and it seems that log_denominator takes the same value as log_hz_sorted when the batch size is 1, so the loss is 0.

Is there a workaround for this? Here is the corresponding line of code:

```python
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
```
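For illustration, here is a minimal sketch of the same computation in plain PyTorch. It assumes no tied event times and averages over events; both are assumptions, and torchsurv's actual tie handling and reduction may differ. `neg_partial_log_likelihood_sketch` is a hypothetical name, not the library's API. With a single subject, `log_denominator` equals `log_hz_sorted`, so the loss is 0, exactly as observed.

```python
import torch


def neg_partial_log_likelihood_sketch(
    log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor
) -> torch.Tensor:
    """Negative Cox partial log likelihood, averaged over events (no tie correction).

    Hypothetical illustrative helper, not the torchsurv implementation.
    """
    # Sort by ascending time: the risk set of position i is positions i, i+1, ...
    order = torch.argsort(time)
    log_hz_sorted = log_hz[order]
    event_sorted = event[order]
    # Reverse cumulative log-sum-exp = log of the summed hazards over each risk set
    # (same construction as the line quoted above).
    log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
    # Only subjects with an observed event contribute a term.
    terms = (log_hz_sorted - log_denominator)[event_sorted]
    return -terms.sum() / event_sorted.sum().clamp(min=1)


# Batch of one subject with an event: its risk set is itself, so the loss is 0.
single = neg_partial_log_likelihood_sketch(
    torch.tensor([0.7]), torch.tensor([True]), torch.tensor([5.0])
)
# Two subjects with events at t=1 and t=2: subject 1 is compared with both.
pair = neg_partial_log_likelihood_sketch(
    torch.zeros(2), torch.tensor([True, True]), torch.tensor([1.0, 2.0])
)
print(single.item(), pair.item())
```

As soon as the batch holds two or more subjects, the denominator sums over more than one hazard and the loss carries information.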

Thank you in advance!

@melodiemonod
Collaborator

Dear @mayurmallya,

Thank you for your interest and for your question!

Summary
When the batch size is 1, the Cox partial log likelihood evaluates to 0. This follows from the formula and the nature of the Cox partial log likelihood; it is not a bug in the code. To get meaningful results, you need to increase the sample size (here, the batch size): the more subjects you include, the more informative the partial log likelihood will be. However, there is a trade-off between statistical efficiency and computational cost.

More details
The Cox partial likelihood is constructed as a product of conditional probabilities: for each observed event, the probability that it is this particular subject who experiences the event, among all subjects still at risk at that time. Unlike most likelihood functions, the likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply them together.

Including more subjects in your sample results in a more refined and accurate ranking, which improves the estimation of the log hazards. Conversely, with only one subject, the likelihood provides no information for parameter estimation (because the subject is compared to ... no one).

In maths
From the documentation, the Cox partial log likelihood is

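$$ pll = \sum_{i: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right) $$

where $\delta_i$ is the event indicator of subject $i$, $\tau_i$ its event or censoring time, $R(\tau_i)$ the set of subjects still at risk at $\tau_i$, and $\theta_i$ its hazard (the exponential of the model's log hazard output).

With only one subject $i$: if $\delta_i = 0$, the sum is empty and $pll = 0$; if $\delta_i = 1$, the risk set contains only subject $i$ itself, so

$$ pll = \log \theta_i - \log \theta_i = 0. $$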
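For contrast, a worked example with the same notation: two subjects, both with observed events at $\tau_1 < \tau_2$, so that $R(\tau_1) = \{1, 2\}$ and $R(\tau_2) = \{2\}$:

$$ pll = \left(\log \theta_1 - \log\left(\theta_1 + \theta_2\right)\right) + \left(\log \theta_2 - \log \theta_2\right) = \log \frac{\theta_1}{\theta_1 + \theta_2} $$

Only the first event contributes: the term is negative and increases toward 0 as $\theta_1$ grows relative to $\theta_2$, i.e. it rewards assigning the higher hazard to the subject who fails first.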

Side note
The partial log likelihood can be viewed as a ranking measure: Raykar et al. (2007) demonstrated that maximizing the Cox partial likelihood is approximately equivalent to maximizing the C-index. You can find more details in their paper.
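To make the ranking view concrete, here is a minimal sketch of Harrell's C-index in plain PyTorch. `harrell_c_index` is our own illustrative helper, not a library function; it counts a pair as comparable when the earlier time is an observed event, gives 0.5 credit for tied predictions, and does not handle tied times.

```python
import torch


def harrell_c_index(log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
    """Harrell's concordance index (illustrative sketch; tied times not handled)."""
    # Pair (i, j) is comparable if subject i had an event strictly before time_j.
    comparable = event.unsqueeze(1) & (time.unsqueeze(1) < time.unsqueeze(0))
    r_i, r_j = log_hz.unsqueeze(1), log_hz.unsqueeze(0)
    # Concordant if the subject who failed first has the higher predicted hazard.
    concordant = (r_i > r_j).float() + 0.5 * (r_i == r_j).float()
    comparable = comparable.float()
    # Returns nan if there are no comparable pairs.
    return (concordant * comparable).sum() / comparable.sum()


time = torch.tensor([1.0, 2.0, 3.0])
event = torch.tensor([True, True, True])
perfect = harrell_c_index(torch.tensor([2.0, 1.0, 0.0]), event, time)
reversed_ = harrell_c_index(torch.tensor([0.0, 1.0, 2.0]), event, time)
print(perfect.item())  # 1.0
```

Like the partial likelihood, the C-index only depends on the ordering of the predicted hazards among comparable subjects, not on their absolute values.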

I hope this helps,

Melodie

@melodiemonod melodiemonod self-assigned this May 24, 2024
@melodiemonod melodiemonod added good first issue Good for newcomers question Further information is requested labels May 24, 2024
@mayurmallya
Author

Thank you @melodiemonod for the detailed answer, much appreciated!

If I have 300 samples in the dataset, I believe the ideal scenario would be a batch size of 300 (right?). But because of computational constraints, the batch size would be lower, let's say 10. In that case, the likelihood would be calculated separately for subjects 1-10, 11-20, and so on, right?

I'm just trying to understand what you meant by:

> Unlike most likelihood functions, the likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply them together.

Also, based on your experience, what batch size would you recommend? Or is it simply higher the better?

Thank you once again :)

@melodiemonod
Collaborator

Hi @mayurmallya,

1/

> I'm just trying to understand what you meant by:
>
> Unlike most likelihood functions, the likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply them together.

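When dealing with the Cox partial likelihood, you cannot decompose the likelihood of two groups of subjects into the product of the likelihoods of each group:
$p(Y_1, Y_2) \neq p(Y_1) \times p(Y_2)$

So with a batch size of 10, each batch computes its own partial likelihood, with risk sets restricted to the subjects in that batch. The resulting loss approximates the full-data partial likelihood; it is not an exact decomposition of it.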
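A minimal sketch of this non-decomposition, using plain PyTorch, a hypothetical helper `partial_log_likelihood` (no tie handling) and random toy data in which every subject is assumed to have an event. It compares the partial log likelihood of 20 subjects with the sum over two batches of 10; summing log likelihoods is the same as multiplying likelihoods.

```python
import torch


def partial_log_likelihood(log_hz, event, time):
    """Cox partial log likelihood summed over events (hypothetical helper, no ties)."""
    order = torch.argsort(time)
    log_hz_sorted, event_sorted = log_hz[order], event[order]
    # Log of the summed hazards over each subject's risk set.
    log_den = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
    return (log_hz_sorted - log_den)[event_sorted].sum()


torch.manual_seed(0)
n = 20
log_hz = torch.randn(n)
time = torch.rand(n) * 10
event = torch.ones(n, dtype=torch.bool)  # toy assumption: every subject has an event

full = partial_log_likelihood(log_hz, event, time)
split = partial_log_likelihood(log_hz[:10], event[:10], time[:10]) + partial_log_likelihood(
    log_hz[10:], event[10:], time[10:]
)
print(full.item(), split.item())
```

The split value is always at least as large as the full one, because each batch denominator sums over only a subset of the full risk set; it is strictly larger here because the earliest subject's risk set shrinks from 20 subjects to 10.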

2/

> Also, based on your experience, what batch size would you recommend? Or is it simply higher the better?

It's a trade-off between converging faster and computational cost. The primary constraint is often the memory available on your GPU or TPU. On the other hand, larger batch sizes provide more accurate estimates of the gradient, potentially leading to more stable and faster convergence. For the Cox loss specifically, the batch size also sets the size of the risk sets, so very small batches make each partial likelihood less informative. Practical guidelines advise starting small: begin with a small batch size (e.g., 32 or 64) to ensure your model trains correctly without memory issues. Monitor the loss and evaluation metrics (e.g., the C-index) on both the training and validation sets to see whether increasing the batch size improves performance. If it does and you have the memory capacity, gradually increase the batch size to see whether it speeds up training without compromising the model’s ability to generalize.

Best regards

Melodie

@mayurmallya
Author

Thank you @ahmedhshahin and @melodiemonod
Much appreciated! :)

@Novartis Novartis deleted a comment from ahmedhshahin Jun 26, 2024
@tcoroller tcoroller pinned this issue Jul 8, 2024
@PierpaoloV

Hi!
Thank you for your nice library!!

I was wondering:

I am working with inputs that are tensors of dimension [x, 1024], with x being the number of patches in a histology image.
The goal is to predict survival and I would like to use a CoxPH model with this function.
Because of the nature of the input, I cannot initialize a dataloader with a batch size > 1. So during the training step I perform n forward passes (usually 8), stack all the predictions, times, and events, and then compute a single loss and backward pass.
The model does seem to train, as the C-index on the test set is > 0.6, but I am wondering whether I could do that step differently.
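The procedure described above (several single-slide forward passes, then one loss over the stacked outputs) can be sketched as follows. `MeanPoolCox`, `neg_pll` and the random data are hypothetical stand-ins, not torchsurv code. Because the loss is computed once over the stacked log hazards, each risk set spans all 8 slides; computing a separate loss per slide and accumulating gradients would instead give a loss of 0 for every slide. The cost is that all 8 computation graphs are held in memory until `backward()`.

```python
import torch
import torch.nn as nn


class MeanPoolCox(nn.Module):
    """Hypothetical toy model: mean-pool patch features, then a linear log-hazard head."""

    def __init__(self, in_dim: int = 1024):
        super().__init__()
        self.head = nn.Linear(in_dim, 1)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: [x, in_dim], where x varies per slide -> scalar log hazard
        return self.head(patches.mean(dim=0)).squeeze(-1)


def neg_pll(log_hz, event, time):
    """Negative Cox partial log likelihood averaged over events (sketch, no ties)."""
    order = torch.argsort(time)
    lh, ev = log_hz[order], event[order]
    log_den = torch.logcumsumexp(lh.flip(0), dim=0).flip(0)
    return -(lh - log_den)[ev].sum() / ev.sum().clamp(min=1)


torch.manual_seed(0)
model = MeanPoolCox()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Hypothetical "virtual batch" of 8 slides with different numbers of patches.
slides = [torch.randn(torch.randint(50, 200, (1,)).item(), 1024) for _ in range(8)]
time = torch.rand(8) * 10
event = torch.tensor([True, False, True, True, False, True, True, False])

opt.zero_grad()
log_hz = torch.stack([model(s) for s in slides])  # 8 forward passes -> shape [8]
loss = neg_pll(log_hz, event, time)  # one loss over the whole stacked batch
loss.backward()
opt.step()
```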
