Skip to content

Commit 0a789d3

Browse files
authored
Merge pull request #64 from ant-research/origin_tf_torch_v250403
1. fix metrics when mask is empty; 2. fix generation runner
2 parents 1d2ce04 + 6771913 commit 0a789d3

File tree

5 files changed

+26
-11
lines changed

5 files changed

+26
-11
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@ log
2525
.idea
2626

2727
examples/checkpoints/*
28+
29+
notebooks/checkpoints/*

easy_tpp/default_registers/register_metrics.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@ def rmse_metric_function(predictions, labels, **kwargs):
1616
float: average rmse of the time predictions.
1717
"""
1818
seq_mask = kwargs.get('seq_mask')
19-
pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
20-
label = labels[PredOutputIndex.TimePredIndex][seq_mask]
19+
if seq_mask is None or len(seq_mask) == 0:
20+
# If mask is empty or None, use all predictions
21+
pred = predictions[PredOutputIndex.TimePredIndex]
22+
label = labels[PredOutputIndex.TimePredIndex]
23+
else:
24+
pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
25+
label = labels[PredOutputIndex.TimePredIndex][seq_mask]
2126

2227
pred = np.reshape(pred, [-1])
2328
label = np.reshape(label, [-1])
@@ -36,8 +41,13 @@ def acc_metric_function(predictions, labels, **kwargs):
3641
float: accuracy ratio of the type predictions.
3742
"""
3843
seq_mask = kwargs.get('seq_mask')
39-
pred = predictions[PredOutputIndex.TypePredIndex][seq_mask]
40-
label = labels[PredOutputIndex.TypePredIndex][seq_mask]
44+
if seq_mask is None or len(seq_mask) == 0:
45+
# If mask is empty or None, use all predictions
46+
pred = predictions[PredOutputIndex.TypePredIndex]
47+
label = labels[PredOutputIndex.TypePredIndex]
48+
else:
49+
pred = predictions[PredOutputIndex.TypePredIndex][seq_mask]
50+
label = labels[PredOutputIndex.TypePredIndex][seq_mask]
4151
pred = np.reshape(pred, [-1])
4252
label = np.reshape(label, [-1])
4353
return np.mean(pred == label)

easy_tpp/model/torch_model/torch_basemodel.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,18 @@ def predict_multi_step_since_last_event(self, batch, forward=False):
205205
"""Multi-step prediction since last event in the sequence.
206206
207207
Args:
208-
time_seqs (tensor): [batch_size, seq_len].
209-
time_delta_seqs (tensor): [batch_size, seq_len].
210-
type_seqs (tensor): [batch_size, seq_len].
211-
num_step (int): num of steps for prediction.
208+
batch (tuple): A tuple containing:
209+
- time_seq_label (tensor): Timestamps of events [batch_size, seq_len].
210+
- time_delta_seq_label (tensor): Time intervals between events [batch_size, seq_len].
211+
- event_seq_label (tensor): Event types [batch_size, seq_len].
212+
- batch_non_pad_mask_label (tensor): Mask for non-padding elements [batch_size, seq_len].
213+
- attention_mask (tensor): Mask for attention [batch_size, seq_len].
214+
forward (bool, optional): Whether to use the entire sequence for prediction. Defaults to False.
212215
213216
Returns:
214217
tuple: tensors of dtime and type prediction, [batch_size, seq_len].
215218
"""
216-
time_seq_label, time_delta_seq_label, event_seq_label, batch_non_pad_mask_label, _, type_mask_label = batch
219+
time_seq_label, time_delta_seq_label, event_seq_label, _, _ = batch
217220

218221
num_step = self.gen_config.num_step_gen
219222

notebooks/easytpp_1_dataset.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
],
9898
"source": [
9999
"# ues the latest release\n",
100-
"# !pip install easy_tpp\n",
100+
"# !pip install easy-tpp\n",
101101
"\n",
102102
"# or use the git main branch\n",
103103
"!pip install git+https://github.com/ant-research/EasyTemporalPointProcess.git"

version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.0'
1+
__version__ = '0.1.2'

0 commit comments

Comments
 (0)