import torch
from torch import nn
from torch.nn import functional as F

from easy_tpp.model.torch_model.torch_thinning import EventSampler
from easy_tpp.utils import set_device
@@ -29,6 +30,7 @@ def __init__(self, model_config):
2930 self .gen_config = model_config .thinning
3031 self .event_sampler = None
3132 self .device = set_device (model_config .gpu )
33+ self .use_mc_samples = model_config .use_mc_samples
3234
3335 self .to (self .device )
3436
@@ -81,8 +83,7 @@ def get_logits_at_last_step(logits, batch_non_pad_mask, sample_len=None):
8183 last_logits = torch .gather (logits , dim = 1 , index = select_index ).squeeze (1 )
8284 return last_logits
8385
def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask, type_seq):
    """Compute the loglikelihood of the event sequence based on Equation (8) of NHP paper.

    Args:
        time_delta_seq (tensor): [batch_size, seq_len], inter-event times.
        lambda_at_event (tensor): [batch_size, seq_len, num_event_types],
            intensity of each mark evaluated at the observed events.
        lambdas_loss_samples (tensor): [batch_size, seq_len, num_sample, num_event_types],
            intensity at sampling times.
        seq_mask (tensor): [batch_size, seq_len], sequence mask vector to mask the padded events.
        type_seq (tensor): [batch_size, seq_len], sequence of mark ids, with padded events
            having a mark of self.pad_token_id.

    Returns:
        tuple: event loglike, non-event loglike, number of (non-padded) events.
    """
    # Add an epsilon to every marked intensity for numerical stability of the log.
    lambda_at_event = lambda_at_event + self.eps
    lambdas_loss_samples = lambdas_loss_samples + self.eps

    log_marked_event_lambdas = lambda_at_event.log()
    # Total intensity (summed over marks) at each sampled time.
    # [batch_size, seq_len, num_sample]
    total_sampled_lambdas = lambdas_loss_samples.sum(dim=-1)

    # Event log-likelihood: log intensity of the observed mark at each event.
    # nll_loss picks -log_lambda[b, type_seq[b, t], t]; padded positions
    # (type_seq == pad_token_id) are replaced by 0 instead of a log-intensity.
    # [batch_size, seq_len]
    event_ll = -F.nll_loss(
        log_marked_event_lambdas.permute(0, 2, 1),  # mark dim must be second for nll_loss
        target=type_seq,
        ignore_index=self.pad_token_id,
        reduction='none',
    )

    # Non-event log-likelihood (the compensator integral in Eq. (8)):
    # interval_integral = interval_length * average of sampled total intensity.
    # [batch_size, seq_len]
    if self.use_mc_samples:
        non_event_ll = total_sampled_lambdas.mean(dim=-1) * time_delta_seq * seq_mask
    else:  # Use the trapezoidal rule over consecutive sample points.
        non_event_ll = 0.5 * (total_sampled_lambdas[..., 1:]
                              + total_sampled_lambdas[..., :-1]).mean(dim=-1) * time_delta_seq * seq_mask

    # Count events as positions with a non-zero event log-likelihood
    # (padded positions were zeroed by ignore_index above).
    num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0]
    return event_ll, non_event_ll, num_events
128126
129127 def make_dtime_loss_samples (self , time_delta_seq ):
def predict_one_step_at_every_event(self, batch):
    """One-step-ahead prediction of event time and type at every event of the batch.

    Args:
        batch (tuple): (time_seq, time_delta_seq, event_seq, batch_non_pad_mask, attention_mask)
            — assumed ordering; confirm against the data loader.

    Returns:
        tuple: tensors of dtime and type prediction, [batch_size, seq_len].
    """
    time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _ = batch

    # Remove the last event, as the prediction based on the last event has no label.
    # Note: the first dt is 0.
    # [batch_size, seq_len]
    time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, :-1], event_seq[:, :-1]

    # Upper bound for the thinning sampler; element-wise max of a multiplicative
    # and an additive extension of the observed inter-event time.
    # [batch_size, seq_len]
    dtime_boundary = torch.max(time_delta_seq * self.event_sampler.dtime_max,
                               time_delta_seq + self.event_sampler.dtime_max)

    # [batch_size, seq_len, num_sample]
    accepted_dtimes, weights = self.event_sampler.draw_next_time_one_step(
        time_seq,
        time_delta_seq,
        event_seq,
        dtime_boundary,
        self.compute_intensities_at_sample_times,
        compute_last_step_only=False)  # make it explicit

    # We condition on each accepted time to sample the event mark, not on the
    # expected event time.
    # 1. Use all accepted_dtimes to get intensity.
    # [batch_size, seq_len, num_sample, num_marks]
    intensities_at_times = self.compute_intensities_at_sample_times(
        time_seq,
        time_delta_seq,
        event_seq,
        accepted_dtimes)

    # 2. Normalize over the mark dimension so each slice is a categorical
    #    distribution over marks.
    # [batch_size, seq_len, num_sample, num_marks]
    intensities_normalized = intensities_at_times / intensities_at_times.sum(dim=-1, keepdim=True)

    # 3. Weighted sum of the distributions over the `num_sample` dimension.
    # [batch_size, seq_len, num_marks]
    intensities_weighted = torch.einsum('...s,...sm->...m', weights, intensities_normalized)

    # [batch_size, seq_len]
    types_pred = torch.argmax(intensities_weighted, dim=-1)

    # Expected next event time: weighted average of accepted sample times.
    # [batch_size, seq_len]
    dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1)
    return dtimes_pred, types_pred
195203
196204 def predict_multi_step_since_last_event (self , batch , forward = False ):
0 commit comments