from typing import List, Optional, Union

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH

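# Tensor-name suffixes as used throughout this file (inferred from the shapes
# in the code below):
#   B: batch, N: sequence length (events t_0, ..., t_N), L: layer stack,
#   P: SSM hidden-state dimension, H: residual stream dimension,
#   G: Monte Carlo grid points per interval, M: number of event types.
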
class ComplexEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ComplexEmbedding, self).__init__()
        self.real_embedding = nn.Embedding(*args, **kwargs)
        self.imag_embedding = nn.Embedding(*args, **kwargs)

        self.real_embedding.weight.data *= 1e-3
        self.imag_embedding.weight.data *= 1e-3

    def forward(self, x):
        return torch.complex(
            self.real_embedding(x),
            self.imag_embedding(x),
        )

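# ComplexEmbedding pairs two real nn.Embedding tables into one complex-valued
# embedding, initialized near zero (weights scaled by 1e-3). For example,
# ComplexEmbedding(10, 4)(torch.tensor([2])) yields a complex64 tensor of
# shape [1, 4]; see the smoke test at the bottom of this file.
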

class IntensityNet(nn.Module):
    def __init__(self, input_dim, bias, num_event_types):
        super().__init__()
        self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias)
        self.softplus = ScaledSoftplus(num_event_types)

    def forward(self, x):
        return self.softplus(self.intensity_net(x))

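# IntensityNet maps hidden states [..., H] to positive per-type intensities
# [..., M] via a linear layer followed by a scaled softplus, e.g.
# IntensityNet(input_dim=H, bias=True, num_event_types=M); see the smoke test
# at the bottom of this file.
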

class S2P2(TorchBaseModel):
    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(S2P2, self).__init__(model_config)
        self.n_layers = model_config.num_layers
        self.P = model_config.model_specs["P"]  # Hidden state dimension
        self.H = model_config.hidden_size  # Residual stream dimension
        self.beta = model_config.model_specs.get("beta", 1.0)
        self.bias = model_config.model_specs.get("bias", True)
        self.simple_mark = model_config.model_specs.get("simple_mark", True)

        layer_kwargs = dict(
            P=self.P,
            H=self.H,
            dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4),
            dt_init_max=model_config.model_specs.get("dt_init_max", 0.1),
            act_func=model_config.model_specs.get("act_func", "full_glu"),
            dropout_rate=model_config.model_specs.get("dropout_rate", 0.0),
            for_loop=model_config.model_specs.get("for_loop", False),
            pre_norm=model_config.model_specs.get("pre_norm", True),
            post_norm=model_config.model_specs.get("post_norm", False),
            simple_mark=self.simple_mark,
            relative_time=model_config.model_specs.get("relative_time", False),
            complex_values=model_config.model_specs.get("complex_values", True),
        )

        int_forward_variant = model_config.model_specs.get("int_forward_variant", False)
        int_backward_variant = model_config.model_specs.get(
            "int_backward_variant", False
        )
        # At most one of the two integrator variants may be enabled.
        assert int_forward_variant + int_backward_variant <= 1

        if int_forward_variant:
            llh_layer = Int_Forward_LLH
        elif int_backward_variant:
            llh_layer = Int_Backward_LLH
        else:
            llh_layer = LLH

        self.backward_variant = int_backward_variant

        self.layers = nn.ModuleList(
            [
                llh_layer(**layer_kwargs, is_first_layer=i == 0)
                for i in range(self.n_layers)
            ]
        )
        # One embedding shared among layers, used as input to a layer-specific
        # and input-aware impulse.
        self.layers_mark_emb = nn.Embedding(
            self.num_event_types_pad,
            self.H,
        )
        self.layer_type_emb = None  # Remove old embeddings from EasyTPP
        self.intensity_net = IntensityNet(
            input_dim=self.H,
            bias=self.bias,
            num_event_types=self.num_event_types,
        )

    def _get_intensity(
        self, x_LP: Union[torch.Tensor, List[torch.Tensor]], right_us_BNH
    ) -> torch.Tensor:
        """
        Assume time has already been evolved; take a vertical stack of hidden
        states and produce intensities.
        """
        left_u_H = None
        for i, layer in enumerate(self.layers):
            # Sometimes it is convenient to pass a list over the layers rather
            # than a single tensor.
            if isinstance(x_LP, list):
                left_u_H = layer.depth_pass(
                    x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i]
                )
            else:
                left_u_H = layer.depth_pass(
                    x_LP[..., i, :],
                    current_left_u_H=left_u_H,
                    prev_right_u_H=right_us_BNH[i],
                )

        return self.intensity_net(left_u_H)

    def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H):
        """Evolve each layer's right-limit state over the sampled dts, then produce intensities."""
        left_u_GH = None
        for i, layer in enumerate(self.layers):
            x_GP = layer.get_left_limit(
                right_limit_P=x_LP[..., i, :],
                dt_G=dt_G,
                next_left_u_GH=left_u_GH,
                current_right_u_H=right_us_H[i],
            )
            left_u_GH = layer.depth_pass(
                current_left_x_P=x_GP,
                current_left_u_H=left_u_GH,
                prev_right_u_H=right_us_H[i],
            )
        return self.intensity_net(left_u_GH)

    def forward(
        self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs
    ) -> dict:
        """
        Batched forward pass over the layer stack; see the shape notes at the
        end of this method for what is returned.
        """
        t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch

        right_xs_BNP = []  # including both t_0 and t_N
        left_xs_BNm1P = []
        # Start with None as this is the 'input' to the first layer.
        right_us_BNH = [None]
        left_u_BNH, right_u_BNH = None, None
        alpha_BNP = self.layers_mark_emb(marks_BN)

        for l_i, layer in enumerate(self.layers):
            # For each event, compute the fixed impulse via alpha_m for event i of type m.
            init_state = (
                initial_state_BLP[:, l_i] if initial_state_BLP is not None else None
            )

            # Returns right limits of xs and us for [t_0, t_1, ..., t_N]:
            # "layer" returns the right limit of xs at the current layer, and us
            # for the next layer (as transformations of ys).
            # x_BNP: at times [t_0, t_1, ..., t_{N-1}, t_N]
            # next_left_u_BNH: at times [t_0, ..., t_N] -- only available for the backward variant
            # next_right_u_BNH: at times [t_0, ..., t_N] -- always returned but only used for RT
            x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward(
                left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state
            )
            assert next_layer_right_u_BNH is not None

            right_xs_BNP.append(x_BNP)
            if next_layer_left_u_BNH is None:  # NOT the backward variant
                left_xs_BNm1P.append(
                    layer.get_left_limit(  # current and next at event level
                        x_BNP[..., :-1, :],  # at times [t_0, t_1, ..., t_{N-1}]
                        dt_BN[..., 1:].unsqueeze(-1),  # dts [t_1-t_0, t_2-t_1, ..., t_N-t_{N-1}]
                        current_right_u_H=None
                        if right_u_BNH is None
                        else right_u_BNH[..., :-1, :],  # at times [t_0, ..., t_{N-1}]
                        next_left_u_GH=None
                        if left_u_BNH is None
                        else left_u_BNH[..., 1:, :].unsqueeze(-2),  # at times [t_1, ..., t_N]
                    ).squeeze(-2)
                )
            right_us_BNH.append(next_layer_right_u_BNH)

            left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH

        right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2)

        ret_val = {
            "right_xs_BNLP": right_xs_BNLP,  # [t_0, ..., t_N]
            "right_us_BNH": right_us_BNH,  # [t_0, ..., t_N]; list starting with None
        }

        if left_u_BNH is not None:  # backward variant
            # The next inputs after the last layer -> transformation of ys.
            ret_val["left_u_BNm1H"] = left_u_BNH[..., 1:, :]
        else:  # NOT the backward variant
            ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2)

        # 'seq_len - 1' left limits for [t_1, ..., t_N] at events (u if available, x if not)
        # 'seq_len' right limits for [t_0, t_1, ..., t_{N-1}, t_N] at events (xs and us)
        return ret_val

    def loglike_loss(self, batch, **kwargs):
        # Hidden states at the left and right limits around each event time; note
        # the shift by 1 in indices. Consider a sequence [t_0, t_1, ..., t_N].
        # The forward pass produces:
        # left_x: x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, x_{t_3-}, ..., x_{t_N-} (note the index shift) for all layers
        #   OR ==> <-> u_{t_1-}, u_{t_2-}, u_{t_3-}, ..., u_{t_N-} for the last layer
        #
        # right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers
        # right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers
        forward_results = self.forward(batch)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )
        right_us_BNm1H = [
            None if right_u_BNH is None else right_u_BNH[:, :-1, :]
            for right_u_BNH in right_us_BNH
        ]

        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

        # Evaluate intensity values at each event *from the left limit*; _get_intensity: [LP] -> [M].
        # left_xs_B_Nm1_LP = left_xs_BNm1LP[:, :-1, ...]  # discard the left limit of t_N
        # Note: no need to discard the left limit of t_N because "marks_mask" will deal with it.
        if "left_u_BNm1H" in forward_results:  # ONLY the backward variant
            intensity_B_Nm1_M = self.intensity_net(forward_results["left_u_BNm1H"])
        else:  # NOT the backward variant
            intensity_B_Nm1_M = self._get_intensity(
                forward_results["left_xs_BNm1LP"], right_us_BNm1H
            )

        # Sample dts in each interval for MC: [batch_size, num_times=N-1, num_mc_sample].
        # N-1 because we only consider the intervals between events; G for grid points.
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])

        # Evaluate intensity at the dt samples for MC *from the left limit* after decay -> shape (B, N-1, G, M).
        intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts(
            # x_{t_i+} evolves up to x_{t_{i+1}-}, and many times in between, for i = 0, ..., N-1.
            right_xs_BNLP[:, :-1],
            dts_sample_B_Nm1_G,
            right_us_BNm1H,
        )

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:],
        )

        # Compute the loss to optimize.
        loss = -(event_ll - non_event_ll).sum()

        return loss, num_events
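
    # Note on the objective above: the standard TPP log-likelihood (assembled in
    # the inherited compute_loglikelihood) is
    #     log L = sum_i log lambda_{m_i}(t_i^-) - int_{t_0}^{t_N} sum_m lambda_m(s) ds,
    # where the integral over each interval (t_i, t_{i+1}] is approximated by
    # Monte Carlo as (t_{i+1} - t_i) * mean_g sum_m lambda_m((t_i + dt_g)^-).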

    def compute_intensities_at_sample_times(
        self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs
    ):
        """Compute the intensity at sampled times, not only event times, *from the left limit*.

        Args:
            event_times_BN (tensor): [batch_size, seq_len], event time seqs.
            inter_event_times_BN (tensor): [batch_size, seq_len], time delta seqs.
            marks_BN (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each timestamp for each event type.
        """

        compute_last_step_only = kwargs.get("compute_last_step_only", False)

        # Assume inter_event_times_BN always starts from 0.
        _input = event_times_BN, inter_event_times_BN, marks_BN, None, None

        # 'seq_len - 1' left limits for [t_1, ..., t_N]
        # 'seq_len' right limits for [t_0, t_1, ..., t_{N-1}, t_N]
        forward_results = self.forward(_input)
        right_xs_BNLP, right_us_BNH = (
            forward_results["right_xs_BNLP"],
            forward_results["right_us_BNH"],
        )

        if compute_last_step_only:
            # Fix indices for right_us_BNH: a list [None, tensor([B, N, H]), ...].
            right_us_B1H = [
                None if right_u_BNH is None else right_u_BNH[:, -1:, :]
                for right_u_BNH in right_us_BNH
            ]
            # right_xs_BNLP[:, -1:, :, :] is equivalent to right_xs_BNLP[:, -1, :, :][:, None, ...]
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H
            )
        else:
            sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts(
                right_xs_BNLP, sample_dtimes, right_us_BNH
            )
        return sampled_intensity  # [B, N, MC, M]
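

# ----------------------------------------------------------------------------
# Minimal smoke test for the self-contained submodules above. This is an
# illustrative sketch with hypothetical sizes, not part of the EasyTPP training
# pipeline (the full S2P2 model additionally requires an EasyTPP ModelConfig).
if __name__ == "__main__":
    emb = ComplexEmbedding(10, 4)  # 10 mark types, embedding dim 4
    z = emb(torch.tensor([[0, 3], [7, 1]]))
    assert z.dtype == torch.complex64 and z.shape == (2, 2, 4)

    net = IntensityNet(input_dim=8, bias=True, num_event_types=5)
    lam_BNM = net(torch.randn(2, 6, 8))  # [B, N, H] -> [B, N, M]
    assert lam_BNM.shape == (2, 6, 5)  # softplus keeps intensities positive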