Skip to content

Commit 191f0d9

Browse files
authored
Merge pull request #70 from ajboyd2/patch-2
Correct scaling of output intensities in FullyNN.
2 parents da59374 + 66406ef commit 191f0d9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

easy_tpp/model/torch_model/torch_fullynn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def forward(self, hidden_states, time_delta_seqs):
6464
derivative_integral_lambdas = []
6565
for i in range(integral_lambda.shape[-1]): # iterate over marks
6666
derivative_integral_lambdas.append(grad(
67-
integral_lambda[..., i].mean(),
67+
integral_lambda[..., i].sum(),
6868
time_delta_seqs,
6969
create_graph=True, retain_graph=True)[0])
7070
derivative_integral_lambda = torch.stack(derivative_integral_lambdas, dim=-1) # TODO: Check that it is okay to iterate over marks like this
7171
else:
72-
derivative_integral_lambda = grad(
73-
integral_lambda.sum(dim=-1).mean(),
72+
derivative_integral_lambda = grad(
73+
integral_lambda.sum(),
7474
time_delta_seqs,
7575
create_graph=True, retain_graph=True)[0]
7676
derivative_integral_lambda = derivative_integral_lambda.unsqueeze(-1).expand(*derivative_integral_lambda.shape, self.num_event_types) / self.num_event_types

0 commit comments

Comments
 (0)