-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_functions.py
74 lines (48 loc) · 2.28 KB
/
test_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# All test code
import unittest
from IPython.display import Markdown, display
# Helper functions for printing markdown text (text in color/bold/etc)
def printmd(string):
display(Markdown(string))
# Print a test failed message, given an error
def print_fail():
printmd('**<span style="color: red;">TEST FAILED</span>**')
# Print a test passed message
def print_pass():
printmd('**<span style="color: green;">TEST PASSED</span>**')
# A class holding all tests
class Tests(unittest.TestCase):
# Tests the `one_hot_encode` function, which is passed in as an argument
def test_one_hot(self, one_hot_function):
# Test that the generated one-hot labels match the expected one-hot label
# For all three cases (red, yellow, green)
try:
self.assertEqual([1,0,0], one_hot_function('red'))
self.assertEqual([0,1,0], one_hot_function('yellow'))
self.assertEqual([0,0,1], one_hot_function('green'))
# If the function does *not* pass all 3 tests above, it enters this exception
except self.failureException as e:
# Print out an error message
print_fail()
print("Your function did not return the expected one-hot label.")
print('\n'+str(e))
return
# Print out a "test passed" message
print_pass()
# Tests if any misclassified images are red but mistakenly classified as green
def test_red_as_green(self, misclassified_images):
# Loop through each misclassified image and the labels
for im, predicted_label, true_label in misclassified_images:
# Check if the image is one of a red light
if(true_label == [1,0,0]):
try:
# Check that it is NOT labeled as a green light
self.assertNotEqual(predicted_label, [0, 0, 1])
except self.failureException as e:
# Print out an error message
print_fail()
print("Warning: A red light is classified as green.")
print('\n'+str(e))
return
# No red lights are classified as green; test passed
print_pass()