File tree 2 files changed +19
-0
lines changed
2 files changed +19
-0
lines changed Original file line number Diff line number Diff line change 2
2
from Variable import Variable
3
3
import numpy as np
4
4
import unittest
5
+ import Utils as u
5
6
6
7
def Test1 ():
7
8
x = Variable (11 )
@@ -87,3 +88,13 @@ def test_backward(self):
87
88
88
89
expected = np .array (6. )
89
90
self .assertEqual (X .grad , expected )
91
+
92
+ def test_backward_auto (self ):
93
+ X = Variable (np .array (np .random .rand (1 )))
94
+ y = square (X )
95
+ y .backward ()
96
+ expected = u .numerical_diff (square , X )
97
+ print (expected )
98
+ print (X .grad )
99
+ flag = np .allclose (X .grad , expected )
100
+ self .assertTrue (flag )
Original file line number Diff line number Diff line change 1
1
import numpy as np
2
+ from Variable import Variable
2
3
3
4
def as_array (x ) -> np .ndarray :
4
5
if np .isscalar (x ):
5
6
return np .array (x )
6
7
return x
8
+
9
+ def numerical_diff (fx , x : Variable , eps = 1e-4 ):
10
+ x_l = Variable (as_array (x .data - eps ))
11
+ x_r = Variable (as_array (x .data + eps ))
12
+ y_l = fx (x_l )
13
+ y_r = fx (x_r )
14
+ return (y_r .data - y_l .data ) / (2 * eps )
You can’t perform that action at this time.
0 commit comments