forked from lawaii/DLMI2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMulticlassTruePositives.py
23 lines (18 loc) · 995 Bytes
/
MulticlassTruePositives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import tensorflow as tf
class MulticlassTruePositives(tf.keras.metrics.Metric):
    """Streaming count of correctly classified samples for multiclass output.

    The predicted class of each sample is the argmax of ``y_pred`` along its
    last axis; a sample counts as a true positive when that index equals the
    integer label in ``y_true``. The (optionally weighted) count accumulates
    in a single scalar weight across ``update_state`` calls until the metric
    is reset.
    """

    def __init__(self, name='multiclass_true_positives', **kwargs):
        """Create the metric and its scalar accumulator.

        Args:
            name: Name of the metric instance.
            **kwargs: Forwarded unchanged to ``tf.keras.metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the number of correct predictions in this batch to the total.

        Args:
            y_true: Integer class labels (sparse, not one-hot). Any layout
                that flattens to one label per sample is accepted, e.g.
                ``(batch,)`` or ``(batch, 1)``.
            y_pred: Per-class scores with the class dimension last, e.g.
                ``(batch, num_classes)``.
            sample_weight: Optional per-sample weights (or a single scalar).
                When given, each match contributes its weight rather than 1.
        """
        # Flatten labels and predictions to 1-D before comparing. Comparing a
        # (batch,) tensor against a (batch, 1) tensor would broadcast to a
        # (batch, batch) matrix and silently inflate the count.
        predicted = tf.reshape(
            tf.argmax(y_pred, axis=-1, output_type=tf.int32), shape=(-1,)
        )
        labels = tf.reshape(tf.cast(y_true, 'int32'), shape=(-1,))

        # tf.equal is element-wise regardless of whether tensor equality
        # overloading is enabled.
        matches = tf.cast(tf.equal(labels, predicted), 'float32')

        if sample_weight is not None:
            # Same flattening for the weights so a (batch,) or (batch, 1)
            # weight vector lines up one-to-one with the matches; a scalar
            # becomes shape (1,) and still broadcasts.
            weights = tf.reshape(
                tf.cast(sample_weight, 'float32'), shape=(-1,)
            )
            matches = tf.multiply(matches, weights)

        self.true_positives.assign_add(tf.reduce_sum(matches))

    def result(self):
        """Return the accumulated (weighted) true-positive count."""
        return self.true_positives

    def reset_state(self):
        """Zero the accumulator.

        The state of the metric will be reset at the start of each epoch.
        """
        self.true_positives.assign(0.)

    def reset_states(self):
        """Backward-compatible alias for ``reset_state``.

        Older Keras versions invoke ``reset_states``; newer ones invoke
        ``reset_state``. Both end up zeroing the accumulator.
        """
        return self.reset_state()