Commit 70038ed

Merge pull request #73 from yuxinc17/s2p2
Merge the state-space point process (S2P2) model implementation.
2 parents 191f0d9 + 3c298eb commit 70038ed

File tree

8 files changed: +1381, -2 lines changed


README.md

Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@ We provide reference implementations of various state-of-the-art TPP papers:
 | 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) |
 | 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
 | 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) |
+| 9 | NeurIPS'25 | S2P2 | Deep Continuous-Time State-Space Models for Marked Event Sequences | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) |

easy_tpp/model/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -6,6 +6,7 @@
 from easy_tpp.model.torch_model.torch_nhp import NHP as TorchNHP
 from easy_tpp.model.torch_model.torch_ode_tpp import ODETPP as TorchODETPP
 from easy_tpp.model.torch_model.torch_rmtpp import RMTPP as TorchRMTPP
+from easy_tpp.model.torch_model.torch_s2p2 import S2P2 as TorchS2P2
 from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP
 from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP

@@ -18,4 +19,5 @@
 'TorchIntensityFree',
 'TorchODETPP',
 'TorchRMTPP',
-'TorchANHN']
+'TorchANHN',
+'TorchS2P2']
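
With this export in place, the model is importable alongside the other PyTorch backends; a minimal sanity check, assuming the package is installed from this branch:

from easy_tpp.model import TorchS2P2  # alias for the S2P2 class defined in torch_s2p2.py

print(TorchS2P2.__name__)  # prints "S2P2"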
easy_tpp/model/torch_model/torch_s2p2.py

Lines changed: 322 additions & 0 deletions

@@ -0,0 +1,322 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH


class ComplexEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ComplexEmbedding, self).__init__()
        self.real_embedding = nn.Embedding(*args, **kwargs)
        self.imag_embedding = nn.Embedding(*args, **kwargs)

        self.real_embedding.weight.data *= 1e-3
        self.imag_embedding.weight.data *= 1e-3

    def forward(self, x):
        return torch.complex(
            self.real_embedding(x),
            self.imag_embedding(x),
        )


class IntensityNet(nn.Module):
    def __init__(self, input_dim, bias, num_event_types):
        super().__init__()
        self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias)
        self.softplus = ScaledSoftplus(num_event_types)

    def forward(self, x):
        return self.softplus(self.intensity_net(x))


class S2P2(TorchBaseModel):
    def __init__(self, model_config):
        """Initialize the model

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(S2P2, self).__init__(model_config)
        self.n_layers = model_config.num_layers
        self.P = model_config.model_specs["P"]  # Hidden state dimension
        self.H = model_config.hidden_size  # Residual stream dimension
        self.beta = model_config.model_specs.get("beta", 1.0)
        self.bias = model_config.model_specs.get("bias", True)
        self.simple_mark = model_config.model_specs.get("simple_mark", True)

        layer_kwargs = dict(
            P=self.P,
            H=self.H,
            dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4),
            dt_init_max=model_config.model_specs.get("dt_init_max", 0.1),
            act_func=model_config.model_specs.get("act_func", "full_glu"),
            dropout_rate=model_config.model_specs.get("dropout_rate", 0.0),
            for_loop=model_config.model_specs.get("for_loop", False),
            pre_norm=model_config.model_specs.get("pre_norm", True),
            post_norm=model_config.model_specs.get("post_norm", False),
            simple_mark=self.simple_mark,
            relative_time=model_config.model_specs.get("relative_time", False),
            complex_values=model_config.model_specs.get("complex_values", True),
        )

        int_forward_variant = model_config.model_specs.get("int_forward_variant", False)
        int_backward_variant = model_config.model_specs.get(
            "int_backward_variant", False
        )
        assert (
            int_forward_variant + int_backward_variant
        ) <= 1  # Only one at most is allowed to be specified

        if int_forward_variant:
            llh_layer = Int_Forward_LLH
        elif int_backward_variant:
            llh_layer = Int_Backward_LLH
        else:
            llh_layer = LLH

        self.backward_variant = int_backward_variant

        self.layers = nn.ModuleList(
            [
                llh_layer(**layer_kwargs, is_first_layer=i == 0)
                for i in range(self.n_layers)
            ]
        )
        self.layers_mark_emb = nn.Embedding(
            self.num_event_types_pad,
            self.H,
        )  # One embedding to share amongst layers to be used as input into a layer-specific and input-aware impulse
        self.layer_type_emb = None  # Remove old embeddings from EasyTPP
        self.intensity_net = IntensityNet(
            input_dim=self.H,
            bias=self.bias,
            num_event_types=self.num_event_types,
        )

    def _get_intensity(
        self, x_LP: Union[torch.tensor, List[torch.tensor]], right_us_BNH
    ) -> torch.Tensor:
        """
        Assume time has already been evolved, take a vertical stack of hidden states and produce intensity.
        """
        left_u_H = None
        for i, layer in enumerate(self.layers):
            if isinstance(
                x_LP, list
            ):  # Sometimes it is convenient to pass as a list over the layers rather than a single tensor
                left_u_H = layer.depth_pass(
                    x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i]
                )
            else:
                left_u_H = layer.depth_pass(
                    x_LP[..., i, :],
                    current_left_u_H=left_u_H,
                    prev_right_u_H=right_us_BNH[i],
                )

        return self.intensity_net(left_u_H)  # self.ScaledSoftplus(self.linear(left_u_H))

    def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H):
        left_u_GH = None
        for i, layer in enumerate(self.layers):
            x_GP = layer.get_left_limit(
                right_limit_P=x_LP[..., i, :],
                dt_G=dt_G,
                next_left_u_GH=left_u_GH,
                current_right_u_H=right_us_H[i],
            )
            left_u_GH = layer.depth_pass(
                current_left_x_P=x_GP,
                current_left_u_H=left_u_GH,
                prev_right_u_H=right_us_H[i],
            )
        return self.intensity_net(left_u_GH)  # self.ScaledSoftplus(self.linear(left_u_GH))

    def forward(
        self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Batch operations of self._forward
        """
        t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch

        right_xs_BNP = []  # including both t_0 and t_N
        left_xs_BNm1P = []
        right_us_BNH = [
            None
        ]  # Start with None as this is the 'input' to the first layer
        left_u_BNH, right_u_BNH = None, None
        alpha_BNP = self.layers_mark_emb(marks_BN)

        for l_i, layer in enumerate(self.layers):
            # for each event, compute the fixed impulse via alpha_m for event i of type m
            init_state = (
                initial_state_BLP[:, l_i] if initial_state_BLP is not None else None
            )

            # Returns right limit of xs and us for [t0, t1, ..., tN]
            # "layer" returns the right limit of xs at current layer, and us for the next layer (as transformations of ys)
            # x_BNP: at time [t_0, t_1, ..., t_{N-1}, t_N]
            # next_left_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- only available for backward variant
            # next_right_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- always returned but only used for RT
            x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward(
                left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state
            )
            assert next_layer_right_u_BNH is not None

            right_xs_BNP.append(x_BNP)
            if next_layer_left_u_BNH is None:  # NOT backward variant
                left_xs_BNm1P.append(
                    layer.get_left_limit(  # current and next at event level
                        x_BNP[..., :-1, :],  # at time [t_0, t_1, ..., t_{N-1}]
                        dt_BN[..., 1:].unsqueeze(
                            -1
                        ),  # with dts [t1-t0, t2-t1, ..., t_N-t_{N-1}]
                        current_right_u_H=right_u_BNH
                        if right_u_BNH is None
                        else right_u_BNH[
                            ..., :-1, :
                        ],  # at time [t_0, t_1, ..., t_{N-1}]
                        next_left_u_GH=left_u_BNH
                        if left_u_BNH is None
                        else left_u_BNH[..., 1:, :].unsqueeze(
                            -2
                        ),  # at time [t_1, t_2 ..., t_N]
                    ).squeeze(-2)
                )
            right_us_BNH.append(next_layer_right_u_BNH)

            left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH

        right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2)

        ret_val = {
            "right_xs_BNLP": right_xs_BNLP,  # [t_0, ..., t_N]
            "right_us_BNH": right_us_BNH,  # [t_0, ..., t_N]; list starting with None
        }

        if left_u_BNH is not None:  # backward variant
            ret_val["left_u_BNm1H"] = left_u_BNH[
                ..., 1:, :
            ]  # The next inputs after last layer -> transformation of ys
        else:  # NOT backward variant
            ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2)

        # 'seq_len - 1' left limit for [t_1, ..., t_N] for events (u if available, x if not)
        # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] for events xs or us
        return ret_val

    def loglike_loss(self, batch, **kwargs):
        # hidden states at the left and right limits around event time; note for the shift by 1 in indices:
        # consider a sequence [t0, t1, ..., tN]
        # Produces the following:
        # left_x: x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, x_{t_3-}, ..., x_{t_N-} (note the shift in indices) for all layers
        #   OR ==>  <-> u_{t_1-}, u_{t_2-}, u_{t_3-}, ..., u_{t_N-} for last layer
        #
        # right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers
        # right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers
        forward_results = self.forward(
            batch
        )  # N minus 1 comparing with sequence lengths
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )
        right_us_BNm1H = [
            None if right_u_BNH is None else right_u_BNH[:, :-1, :]
            for right_u_BNH in right_us_BNH
        ]

        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

        # evaluate intensity values at each event *from the left limit*, _get_intensity: [LP] -> [M]
        # left_xs_B_Nm1_LP = left_xs_BNm1LP[:, :-1, ...]  # discard the left limit of t_N
        # Note: no need to discard the left limit of t_N because "marks_mask" will deal with it
        if "left_u_BNm1H" in forward_results:  # ONLY backward variant
            intensity_B_Nm1_M = self.intensity_net(
                forward_results["left_u_BNm1H"]
            )  # self.ScaledSoftplus(self.linear(forward_results["left_u_BNm1H"]))
        else:  # NOT backward variant
            intensity_B_Nm1_M = self._get_intensity(
                forward_results["left_xs_BNm1LP"], right_us_BNm1H
            )

        # sample dt in each interval for MC: [batch_size, num_times=N-1, num_mc_sample]
        # N-1 because we only consider the intervals between N events
        # G for grid points
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])

        # evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, MC, M)
        intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts(
            right_xs_BNLP[
                :, :-1
            ],  # x_{t_i+} will evolve up to x_{t_{i+1}-} and many times between for i=0,...,N-1
            dts_sample_B_Nm1_G,
            right_us_BNm1H,
        )

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:],
        )

        # compute loss to optimize
        loss = -(event_ll - non_event_ll).sum()

        return loss, num_events

    def compute_intensities_at_sample_times(
        self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs
    ):
        """Compute the intensity at sampled times, not only event times. *from the left limit*

        Args:
            time_seq (tensor): [batch_size, seq_len], times seqs.
            time_delta_seq (tensor): [batch_size, seq_len], time delta seqs.
            event_seq (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
                intensity at each timestamp for each event type.
        """

        compute_last_step_only = kwargs.get("compute_last_step_only", False)

        # assume inter_event_times_BN always starts from 0
        _input = event_times_BN, inter_event_times_BN, marks_BN, None, None

        # 'seq_len - 1' left limit for [t_1, ..., t_N]
        # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N]

        forward_results = self.forward(
            _input
        )  # N minus 1 comparing with sequence lengths
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )

        if (
            compute_last_step_only
        ):  # fix indices for right_us_BNH: list [None, tensor([BNH]), ...]
            right_us_B1H = [
                None if right_u_BNH is None else right_u_BNH[:, -1:, :]
                for right_u_BNH in right_us_BNH
            ]
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H
            )  # equiv. to right_xs_BNLP[:, -1, :, :][:, None, ...]
        else:
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP, sample_dtimes, right_us_BNH
            )
        return sampled_intensity  # [B, N, MC, M]
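
For reference, S2P2.__init__ above pulls its hyperparameters from model_config.hidden_size, model_config.num_layers, and the model_config.model_specs dictionary. The sketch below lists the model_specs keys the constructor reads, with the defaults taken from the code; it is a reference sketch rather than an official config file, "P" has no default and the value 32 is purely illustrative, and how the dictionary is supplied (YAML vs. code) follows the usual EasyTPP ModelConfig flow.

# Reference sketch: keys read by S2P2.__init__ via model_config.model_specs.
# "P" is required; all other values shown are the constructor defaults.
s2p2_model_specs = {
    "P": 32,                        # hidden state dimension of each state-space layer (32 is illustrative)
    "beta": 1.0,
    "bias": True,
    "simple_mark": True,
    "dt_init_min": 1e-4,
    "dt_init_max": 0.1,
    "act_func": "full_glu",
    "dropout_rate": 0.0,
    "for_loop": False,
    "pre_norm": True,
    "post_norm": False,
    "relative_time": False,
    "complex_values": True,
    "int_forward_variant": False,   # at most one of the two variant flags may be True
    "int_backward_variant": False,
}
# In addition, model_config.hidden_size sets the residual stream dimension H and
# model_config.num_layers sets the number of stacked LLH layers.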

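loglike_loss delegates the final assembly of the log-likelihood to TorchBaseModel.compute_loglikelihood. As a standalone illustration of the quantity being optimized (not the EasyTPP implementation; the function name and shapes here are hypothetical, following the B/N/G/M naming used in the file), the event term sums log-intensities of the observed marks evaluated from the left limits, and the compensator integral is approximated by Monte Carlo over the G sampled offsets in each inter-event interval:

import torch

def mc_tpp_loglik(lambda_at_event_BN, lambdas_mc_BNGM, dt_BN, mask_BN):
    # Event term: sum of log lambda_{m_i}(t_i^-) over observed, non-padded events,
    # assuming the intensity of the observed mark has already been gathered.
    event_ll = (torch.log(lambda_at_event_BN + 1e-12) * mask_BN).sum()
    # Non-event term: integral of the total intensity over each interval, estimated
    # as the mean over the G sampled offsets of sum_m lambda_m, times the interval length.
    non_event_ll = (lambdas_mc_BNGM.sum(dim=-1).mean(dim=-1) * dt_BN * mask_BN).sum()
    return event_ll, non_event_ll

Maximizing event_ll - non_event_ll over the batch corresponds to minimizing the loss = -(event_ll - non_event_ll).sum() used above.
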
easy_tpp/ssm/__init__.py

Whitespace-only changes.
