Skip to content

Commit 34fbf84

Browse files
committed
fix
1 parent f651165 commit 34fbf84

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

Function_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from Variable import Variable
33
import numpy as np
44
import unittest
5+
import Utils as u
56

67
def Test1():
78
x = Variable(11)
@@ -87,3 +88,13 @@ def test_backward(self):
8788

8889
expected = np.array(6.)
8990
self.assertEqual(X.grad, expected)
91+
92+
def test_backward_auto(self):
93+
X = Variable(np.array(np.random.rand(1)))
94+
y = square(X)
95+
y.backward()
96+
expected = u.numerical_diff(square, X)
97+
print(expected)
98+
print(X.grad)
99+
flag = np.allclose(X.grad, expected)
100+
self.assertTrue(flag)

Utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import numpy as np
2+
from Variable import Variable
23

34
def as_array(x) -> np.ndarray:
45
if np.isscalar(x):
56
return np.array(x)
67
return x
8+
9+
def numerical_diff(fx, x: Variable, eps=1e-4):
10+
x_l = Variable(as_array(x.data-eps))
11+
x_r = Variable(as_array(x.data+eps))
12+
y_l = fx(x_l)
13+
y_r = fx(x_r)
14+
return (y_r.data - y_l.data) / (2*eps)

0 commit comments

Comments
 (0)