forked from IsaacHaze/main_core_retention
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: eval.py
59 lines (48 loc) · 1.29 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import sys
import os
from glob import iglob
def load_unigrams(fname):
    """Return the set of distinct whitespace-separated tokens in a file.

    Args:
        fname: path of the text file to read.

    Returns:
        A ``set`` with every token that appears at least once in the file.
        An empty or whitespace-only file yields an empty set.
    """
    unigrams = set()
    with open(fname) as f:
        for line in f:
            # str.split() with no argument splits on any run of whitespace
            # and never yields empty strings, so blank lines add nothing.
            unigrams.update(line.split())
    return unigrams
def avg(lst):
    """Return the arithmetic mean of the numbers in *lst* as a float.

    Args:
        lst: a sized collection of numbers (must support ``len``).

    Returns:
        The mean as a float, or ``0.0`` when *lst* is empty so that a
        summary over zero items does not raise ``ZeroDivisionError``.
    """
    if not lst:
        return 0.0
    # float() forces true division even under Python 2, where int / int
    # would otherwise truncate.
    return float(sum(lst)) / len(lst)
def load_unigram_dir(dirname):
    """Load the unigram set of every regular file directly inside *dirname*.

    Args:
        dirname: directory whose entries are read with ``load_unigrams``.

    Returns:
        A dict mapping each file's base name (the part of the file name
        before its first ``'.'``) to that file's set of unigrams.
    """
    unigrams_by_name = {}
    for path in iglob(os.path.join(dirname, '*')):
        # The glob also matches sub-directories, which cannot be opened as
        # text files; skip anything that is not a regular file.
        if not os.path.isfile(path):
            continue
        # Keep everything before the first dot.  Indexing (instead of
        # unpacking into two names) also accepts file names without a dot.
        base = os.path.basename(path).split('.', 1)[0]
        unigrams_by_name[base] = load_unigrams(path)
    return unigrams_by_name


def precision_recall_f1(gold_set, eval_set):
    """Score *eval_set* against *gold_set*.

    Args:
        gold_set: set of reference unigrams.
        eval_set: set of unigrams being evaluated.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats.  A measure whose
        denominator would be zero (empty *eval_set*, empty *gold_set*, or
        precision + recall == 0) is reported as ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    overlap = len(gold_set.intersection(eval_set))
    p = float(overlap) / len(eval_set) if eval_set else 0.0
    r = float(overlap) / len(gold_set) if gold_set else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1


def main(argv):
    """Compare two directories of unigram files and print macro-averages.

    Args:
        argv: command-line vector; ``argv[1]`` is the gold directory and
            ``argv[2]`` the directory under evaluation.  Files are paired
            by base name.

    Per-file scores go to stderr; the number of gold files and the averaged
    precision, recall and F1 go to stdout.  A gold file with no counterpart
    in the evaluated directory is logged as "ugh" and scored 0.0 on all
    three measures.  Exits with an error message when arguments are missing
    or the gold directory contains no files.
    """
    if len(argv) < 3:
        sys.exit("usage: eval.py GOLD_DIR EVAL_DIR")
    gold_dir, eval_dir = argv[1], argv[2]

    gold_unigrams = load_unigram_dir(gold_dir)
    eval_unigrams = load_unigram_dir(eval_dir)
    if not gold_unigrams:
        # Nothing to average over; report it instead of dividing by zero.
        sys.exit("no gold files found in {}".format(gold_dir))

    ps = []
    rs = []
    f1s = []
    # Sorted so the per-file log is reproducible between runs.
    for bn in sorted(gold_unigrams):
        if bn in eval_unigrams:
            p, r, f1 = precision_recall_f1(gold_unigrams[bn],
                                           eval_unigrams[bn])
        else:
            sys.stderr.write("ugh: {}\n".format(bn))
            p = r = f1 = 0.0
        sys.stderr.write(
            "bn: {: >4}, p: {:1.3f}, r: {:1.3f}, f1: {:1.3f}\n".format(
                bn, p, r, f1))
        ps.append(p)
        rs.append(r)
        f1s.append(f1)

    # Single-argument print(...) produces the same text under Python 2
    # (print statement) and Python 3 (print function).
    print("num: {}".format(len(f1s)))
    print("precision: {}".format(avg(ps)))
    print("recall: {}".format(avg(rs)))
    print("f1: {}".format(avg(f1s)))


if __name__ == "__main__":
    main(sys.argv)