
Commit ea238c3

Merge commit: 2 parents 26899b6 + f1583af

File tree

15 files changed: +485 −560 lines changed

easy_tpp/config_factory/model_config.py

Lines changed: 2 additions & 4 deletions

@@ -202,9 +202,9 @@ def __init__(self, **kwargs):
         self.time_emb_size = kwargs.get('time_emb_size', 16)
         self.num_layers = kwargs.get('num_layers', 2)
         self.num_heads = kwargs.get('num_heads', 2)
-        self.mc_num_sample_per_step = kwargs.get('mc_num_sample_per_step', 20)
         self.sharing_param_layer = kwargs.get('sharing_param_layer', False)
-        self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20)
+        self.use_mc_samples = kwargs.get('use_mc_samples', True)  # if using MC samples in computing log-likelihood
+        self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20)  # mc_num_sample_per_step
         self.dropout_rate = kwargs.get('dropout_rate', 0.0)
         self.use_ln = kwargs.get('use_ln', False)
         self.thinning = ThinningConfig.parse_from_yaml_config(kwargs.get('thinning'))

@@ -227,7 +227,6 @@ def get_yaml_config(self):
         'hidden_size': self.hidden_size,
         'time_emb_size': self.time_emb_size,
         'num_layers': self.num_layers,
-        'mc_num_sample_per_step': self.mc_num_sample_per_step,
         'sharing_param_layer': self.sharing_param_layer,
         'loss_integral_num_sample_per_step': self.loss_integral_num_sample_per_step,
         'dropout_rate': self.dropout_rate,

@@ -265,7 +264,6 @@ def copy(self):
         hidden_size=self.hidden_size,
         time_emb_size=self.time_emb_size,
         num_layers=self.num_layers,
-        mc_num_sample_per_step=self.mc_num_sample_per_step,
         sharing_param_layer=self.sharing_param_layer,
         loss_integral_num_sample_per_step=self.loss_integral_num_sample_per_step,
         dropout_rate=self.dropout_rate,
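The removed mc_num_sample_per_step is folded into loss_integral_num_sample_per_step, and the new use_mc_samples flag chooses how the sampled intensities are turned into the integral term of the log-likelihood. A toy sketch (not part of the commit) of the two estimators the flag switches between, for one event interval of length dt with evenly spaced sample points:

import torch

# Estimate the integral of lambda(t) over a single event interval.
lambdas = torch.tensor([0.8, 1.1, 0.9, 1.3])  # intensities at evenly spaced sample points
dt = 0.5                                      # interval length

mc_integral = lambdas.mean() * dt                                # use_mc_samples=True: Monte Carlo average
trapz_integral = 0.5 * (lambdas[1:] + lambdas[:-1]).mean() * dt  # use_mc_samples=False: trapezoid rule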

easy_tpp/model/torch_model/torch_attnhp.py

Lines changed: 13 additions & 14 deletions

@@ -3,7 +3,7 @@
 import torch
 from torch import nn

-from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention
+from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, ScaledSoftplus
 from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel

@@ -52,7 +52,7 @@ def __init__(self, model_config):
         if self.use_norm:
             self.norm = nn.LayerNorm(self.d_model)
         self.inten_linear = nn.Linear(self.d_model * self.n_head, self.num_event_types)
-        self.softplus = nn.Softplus()
+        self.softplus = ScaledSoftplus(self.num_event_types)  # learnable mark-specific beta
         self.layer_event_emb = nn.Linear(self.d_model + self.d_time, self.d_model)
         self.layer_intensity = nn.Sequential(self.inten_linear, self.softplus)
         self.eps = torch.finfo(torch.float32).eps

@@ -151,7 +151,7 @@ def make_layer_mask(self, attention_mask):
             a diagonal matrix, [batch_size, seq_len, seq_len]
         """
         # [batch_size, seq_len, seq_len]
-        layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask)
+        layer_mask = (torch.eye(attention_mask.size(1), device=self.device) < 1).unsqueeze(0).expand_as(attention_mask)
         return layer_mask

     def make_combined_att_mask(self, attention_mask, layer_mask):

@@ -205,11 +205,11 @@ def loglike_loss(self, batch):
         Returns:
             list: loglike loss, num events.
         """
-        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch
+        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch
         # 1. compute event-loglik
         # the prediction of last event has no label, so we proceed to the last but one
         # att mask => diag is False, not mask.
-        enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, 1:, :-1], time_seqs[:, 1:])
+        enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1], time_seqs[:, 1:])
         # [batch_size, seq_len, num_event_types]
         lambda_at_event = self.layer_intensity(enc_out)

@@ -227,17 +227,16 @@ def loglike_loss(self, batch):
             time_delta_seqs[:, :-1],  # not used
             type_seqs[:, :-1],
             sample_times,
-            attention_mask=attention_mask[:, 1:, :-1])
+            attention_mask=attention_mask[:, :-1, :-1])

         event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                         lambdas_loss_samples=lambda_t_sample,
                                                                         time_delta_seq=time_delta_seqs[:, 1:],
                                                                         seq_mask=batch_non_pad_mask[:, 1:],
-                                                                        lambda_type_mask=type_mask[:, 1:])
+                                                                        type_seq=type_seqs[:, 1:])

-        # return enc_inten to compute accuracy
+        # compute loss to minimize
         loss = - (event_ll - non_event_ll).sum()
-
         return loss, num_events

     def compute_states_at_sample_times(self,

@@ -285,7 +284,7 @@ def compute_states_at_sample_times(self,
         encoder_output = encoder_output.permute((1, 2, 0, 3))
         return encoder_output

-    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs):
+    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
         """Compute the intensity at sampled times.

         Args:

@@ -302,17 +301,17 @@ def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):

         if attention_mask is None:
             batch_size, seq_len = time_seqs.size()
-            attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
+            attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0).to(type_seqs.device)
             attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

-        if sample_times.size()[1] < time_seqs.size()[1]:
+        if sample_dtimes.size()[1] < time_seqs.size()[1]:
             # we pass sample_dtimes for last time step here
             # we do a temp solution
             # [batch_size, seq_len, num_samples]
-            sample_times = time_seqs[:, :, None] + torch.tile(sample_times, [1, time_seqs.size()[1], 1])
+            sample_dtimes = time_seqs[:, :, None] + torch.tile(sample_dtimes, [1, time_seqs.size()[1], 1])

         # [batch_size, seq_len, num_samples, hidden_size]
-        encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_times)
+        encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_dtimes)

         if compute_last_step_only:
             lambdas = self.layer_intensity(encoder_output[:, -1:, :, :])
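The attention-mask fix above is easiest to see on a small example: the encoder consumes the first L-1 events, so it needs the top-left (L-1)x(L-1) block of the causal mask, i.e. attention_mask[:, :-1, :-1]; the old slice [:, 1:, :-1] shifted the rows by one. A toy check (shapes invented, not from the repo):

import torch

L = 4
# upper-triangular causal mask: True = masked (diagonal itself is not masked)
full_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1).unsqueeze(0)
sub_mask = full_mask[:, :-1, :-1]  # aligned with the truncated inputs time_seqs[:, :-1]
print(sub_mask.shape)  # torch.Size([1, 3, 3])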

easy_tpp/model/torch_model/torch_baselayer.py

Lines changed: 22 additions & 0 deletions

@@ -16,6 +16,28 @@ def attention(query, key, value, mask=None, dropout=None):
     return torch.matmul(p_attn, value), p_attn


+class ScaledSoftplus(nn.Module):
+    '''
+    Use a different beta for each mark-specific intensity
+    '''
+    def __init__(self, num_marks, threshold=20.):
+        super(ScaledSoftplus, self).__init__()
+        self.threshold = threshold
+        self.log_beta = nn.Parameter(torch.zeros(num_marks), requires_grad=True)  # [num_marks]
+
+    def forward(self, x):
+        '''
+        :param x: [..., num_marks]
+        '''
+        beta = self.log_beta.exp()
+        beta_x = beta * x
+        return torch.where(
+            beta_x <= self.threshold,
+            torch.log1p(beta_x.clamp(max=math.log(1e5)).exp()) / beta,
+            x,  # if above threshold, the transform is effectively linear
+        )
+
+
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_head, d_input, d_model, dropout=0.1, output_linear=False):
         super(MultiHeadAttention, self).__init__()
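Per mark k, the new activation computes softplus_beta(x) = log(1 + exp(beta_k * x)) / beta_k with a learnable beta_k (stored as log_beta), falling back to the identity once beta_k * x exceeds the threshold. A short usage sketch, assuming `math` is imported at the top of torch_baselayer.py (the import is not visible in this hunk):

import torch
from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus

softplus = ScaledSoftplus(num_marks=3)
logits = torch.randn(2, 5, 3)   # [batch_size, seq_len, num_marks]
intensities = softplus(logits)  # [batch_size, seq_len, num_marks]
assert (intensities > 0).all()  # softplus keeps intensities strictly positive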

easy_tpp/model/torch_model/torch_basemodel.py

Lines changed: 45 additions & 37 deletions

@@ -2,6 +2,7 @@

 import torch
 from torch import nn
+from torch.nn import functional as F

 from easy_tpp.model.torch_model.torch_thinning import EventSampler
 from easy_tpp.utils import set_device

@@ -29,6 +30,7 @@ def __init__(self, model_config):
         self.gen_config = model_config.thinning
         self.event_sampler = None
         self.device = set_device(model_config.gpu)
+        self.use_mc_samples = model_config.use_mc_samples

         self.to(self.device)

@@ -81,8 +83,7 @@ def get_logits_at_last_step(logits, batch_non_pad_mask, sample_len=None):
         last_logits = torch.gather(logits, dim=1, index=select_index).squeeze(1)
         return last_logits

-    def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask,
-                              lambda_type_mask):
+    def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask, type_seq):
         """Compute the loglikelihood of the event sequence based on Equation (8) of NHP paper.

         Args:

@@ -92,38 +93,35 @@ def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples,
             lambdas_loss_samples (tensor): [batch_size, seq_len, num_sample, num_event_types],
                 intensity at sampling times.
             seq_mask (tensor): [batch_size, seq_len], sequence mask vector to mask the padded events.
-            lambda_type_mask (tensor): [batch_size, seq_len, num_event_types], type mask matrix to mask the
-                padded event types.
+            type_seq (tensor): [batch_size, seq_len], sequence of mark ids, with padded events having a mark of self.pad_token_id

         Returns:
             tuple: event loglike, non-event loglike, intensity at event with padding events masked
         """

-        # Sum of lambda over every type and every event point
-        # [batch_size, seq_len]
-        event_lambdas = torch.sum(lambda_at_event * lambda_type_mask, dim=-1) + self.eps
-
-        # mask the pad event
-        event_lambdas = event_lambdas.masked_fill_(~seq_mask, 1.0)
-
-        # [batch_size, seq_len]
-        event_ll = torch.log(event_lambdas)
+        # First, add an epsilon to every marked intensity for stability
+        lambda_at_event = lambda_at_event + self.eps
+        lambdas_loss_samples = lambdas_loss_samples + self.eps

-        # Compute the big lambda integral in equation (8) of NHP paper
-        # 1 - take num_mc_sample rand points in each event interval
-        # 2 - compute its lambda value for every sample point
-        # 3 - take average of these sample points
-        # 4 - times the interval length
+        log_marked_event_lambdas = lambda_at_event.log()
+        total_sampled_lambdas = lambdas_loss_samples.sum(dim=-1)

-        # [batch_size, seq_len, n_loss_sample]
-        lambdas_total_samples = lambdas_loss_samples.sum(dim=-1)
+        # Compute event LL - [batch_size, seq_len]
+        event_ll = -F.nll_loss(
+            log_marked_event_lambdas.permute(0, 2, 1),  # mark dimension needs to come second, not third, to match nll_loss specs
+            target=type_seq,
+            ignore_index=self.pad_token_id,  # padded events have pad_token_id as a value
+            reduction='none',  # does not aggregate, and replaces what would have been the log(marked intensity) with 0
+        )

-        # interval_integral - [batch_size, seq_len]
+        # Compute non-event LL - [batch_size, seq_len]
         # interval_integral = length_interval * average of sampled lambda(t)
-        non_event_ll = lambdas_total_samples.mean(dim=-1) * time_delta_seq * seq_mask
+        if self.use_mc_samples:
+            non_event_ll = total_sampled_lambdas.mean(dim=-1) * time_delta_seq * seq_mask
+        else:  # use trapezoid rule
+            non_event_ll = 0.5 * (total_sampled_lambdas[..., 1:] + total_sampled_lambdas[..., :-1]).mean(dim=-1) * time_delta_seq * seq_mask

         num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0]
         return event_ll, non_event_ll, num_events

     def make_dtime_loss_samples(self, time_delta_seq):
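Under Equation (8) of the NHP paper, the log-likelihood is the sum of log lambda_{k_i}(t_i) over observed events minus the integral of the total intensity. The rewritten event term computes the first sum with F.nll_loss: negated, with reduction='none', it gathers the log-intensity of each observed mark per position and writes 0 wherever the target equals ignore_index. A small equivalence sketch (toy shapes, with 3 standing in for pad_token_id):

import torch
from torch.nn import functional as F

log_lambdas = torch.randn(2, 4, 3)      # [batch, seq_len, num_marks] pretend log-intensities
type_seq = torch.tensor([[0, 2, 1, 3],  # 3 plays the role of pad_token_id here
                         [1, 0, 3, 3]])

event_ll = -F.nll_loss(log_lambdas.permute(0, 2, 1),  # [batch, num_marks, seq_len]
                       target=type_seq, ignore_index=3, reduction='none')

# the same thing written as an explicit gather plus mask
picked = log_lambdas.gather(-1, type_seq.clamp(max=2).unsqueeze(-1)).squeeze(-1)
assert torch.allclose(event_ll, picked * (type_seq != 3))

This is also why num_events counts the nonzero entries of event_ll: padded positions are exactly zero after the ignore_index substitution.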
@@ -160,37 +158,47 @@ def predict_one_step_at_every_event(self, batch):
         Returns:
             tuple: tensors of dtime and type prediction, [batch_size, seq_len].
         """
-        time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _, type_mask = batch
+        time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _ = batch

         # remove the last event, as the prediction based on the last event has no label
-        # time_delta_seq should start from 1, because the first one is zero
-        time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, 1:], event_seq[:, :-1]
+        # note: the first dt is 0
+        # [batch_size, seq_len]
+        time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, :-1], event_seq[:, :-1]

         # [batch_size, seq_len]
-        dtime_boundary = time_delta_seq + self.event_sampler.dtime_max
+        dtime_boundary = torch.max(time_delta_seq * self.event_sampler.dtime_max,
+                                   time_delta_seq + self.event_sampler.dtime_max)

         # [batch_size, seq_len, num_sample]
         accepted_dtimes, weights = self.event_sampler.draw_next_time_one_step(time_seq,
                                                                              time_delta_seq,
                                                                              event_seq,
                                                                              dtime_boundary,
-                                                                             self.compute_intensities_at_sample_times)
+                                                                             self.compute_intensities_at_sample_times,
+                                                                             compute_last_step_only=False)  # make it explicit

-        # [batch_size, seq_len]
-        dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1)
-
-        # [batch_size, seq_len, 1, event_num]
+        # Condition the mark prediction on each accepted time, not on the single expected event time.
+        # 1. Use all accepted_dtimes to get intensities.
+        # [batch_size, seq_len, num_sample, num_marks]
         intensities_at_times = self.compute_intensities_at_sample_times(time_seq,
                                                                         time_delta_seq,
                                                                         event_seq,
-                                                                        dtimes_pred[:, :, None],
-                                                                        max_steps=event_seq.size()[1])
+                                                                        accepted_dtimes)

-        # [batch_size, seq_len, event_num]
-        intensities_at_times = intensities_at_times.squeeze(dim=-2)
+        # 2. Normalize the intensities over the last dim, then compute the weighted sum over the `num_sample` dimension.
+        # Each slice along the last dimension is a categorical distribution over all marks.
+        # [batch_size, seq_len, num_sample, num_marks]
+        intensities_normalized = intensities_at_times / intensities_at_times.sum(dim=-1, keepdim=True)

-        types_pred = torch.argmax(intensities_at_times, dim=-1)
+        # 3. Compute the weighted sum of the distributions, then take the argmax.
+        # [batch_size, seq_len, num_marks]
+        intensities_weighted = torch.einsum('...s,...sm->...m', weights, intensities_normalized)

+        # [batch_size, seq_len]
+        types_pred = torch.argmax(intensities_weighted, dim=-1)
+
+        # [batch_size, seq_len]
+        dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1)  # compute the expected next event time
         return dtimes_pred, types_pred

     def predict_multi_step_since_last_event(self, batch, forward=False):
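The prediction change is the subtle part: rather than evaluating intensities once at the expected next-event time, each accepted sample time yields its own categorical distribution over marks, and those distributions are averaged with the sampler's weights before the argmax. A toy check of the einsum contraction (values invented):

import torch

weights = torch.tensor([[0.25, 0.75]])             # [batch, num_sample]
probs = torch.tensor([[[0.9, 0.1],                 # [batch, num_sample, num_marks]
                       [0.2, 0.8]]])
avg = torch.einsum('...s,...sm->...m', weights, probs)
print(avg)  # tensor([[0.3750, 0.6250]]) -> argmax picks mark 1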
