forked from UCSBarchlab/OpenTPU
-
Notifications
You must be signed in to change notification settings - Fork 0
/
checker.py
77 lines (62 loc) · 2.17 KB
/
checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
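""" The checker assumes results are always written to host memory (hostmem)
consecutively, starting from location 0.
If a result row is no wider than the HW width, each result row occupies one
HW row (X is don't care):
--------HW WIDTH---------
D D D D D D D X X X X X X    <- result row 0
D D D D D D D X X X X X X    <- result row 1
Otherwise, each result row continues on the next HW row:
--------HW WIDTH---------
D D D D D D D D D D D D D    <- result row 0, first part
D D X X X X X X X X X X X    <- result row 0, remainder
D D D D D D D D D D D D D    <- result row 1, first part
D D X X X X X X X X X X X    <- result row 1, remainder
"""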
import argparse
import numpy as np
# Parsed command-line namespace. It stays None until parse_args() assigns it
# (via `global args`); the __main__ block reads its fields.
args = None
def equal(a1, a2):
assert a1.shape == a2.shape, 'result file shape mismatch.'
if a1.dtype == np.int8:
a1 = a1.astype(np.uint8)
if a2.dtype == np.int8:
a2 = a2.astype(np.uint8)
for x, y in np.nditer([a1, a2]):
assert x == y, 'result value mismatch.'
def check(p1, p2, width=None):
r1 = np.load(p1)
r2 = np.load(p2)
if not width:
# Checking sim8 against hw8.
equal(r1, r2)
else:
# Checking gt32 against sim32.
#assert width == r2.shape[1]
r_width = r1.shape[1]
if r_width <= width:
r2 = r2[:, :r_width]
equal(r1, r2)
else:
r2 = np.concatenate((r2[::2], r2[1::2]), axis=1)
r2 = r2[:, :r_width]
equal(r1, r2)
def parse_args():
global args
parser = argparse.ArgumentParser()
parser.add_argument('--width', action='store', type=int, default=16,
help='HW WIDTH.')
parser.add_argument('--gt32', action='store', default='gt32.npy',
help='path to f32 ground truth result.')
parser.add_argument('--sim32', action='store', default='sim32.npy',
help='path to f32 simulator result.')
parser.add_argument('--sim8', action='store', default='sim8.npy',
help='path to i8 simulator result.')
parser.add_argument('--hw8', action='store', default='hw8.npy',
help='path to i8 hardware result.')
args = parser.parse_args()
if __name__ == '__main__':
parse_args()
print 'HW width set to %d.' % args.width
check(args.gt32, args.sim32, args.width)
print '32-bit passed.'
check(args.sim8, args.hw8)
print '8-bit passed.'