import torch
import math
import ReLU
+import Tanh

dtype = torch.double
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class RNN:
-    def __init__(self, input_dim, hidden_dim, output_dim, mx=1.0e10):
+    def __init__(self, input_dim, hidden_dim, output_dim, mx=1.0e4):
        self.max = mx
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

-        self.weights_hh = torch.randn(hidden_dim, hidden_dim, dtype=dtype, device=device) * math.sqrt(2.0 / self.hidden_dim)
-        self.weights_hx = torch.randn(hidden_dim, input_dim, dtype=dtype, device=device) * math.sqrt(2.0 / self.hidden_dim)
-        self.weights_hy = torch.randn(output_dim, hidden_dim, dtype=dtype, device=device) * math.sqrt(2.0 / self.hidden_dim)
-        self.bias_h = torch.randn(hidden_dim, 1, dtype=dtype, device=device) * math.sqrt(2.0 / self.hidden_dim)  # hidden_dim X 1
-        self.bias_y = torch.randn(output_dim, 1, dtype=dtype, device=device) * math.sqrt(2.0 / self.hidden_dim)  # output_dim X 1
+        self.weights_hh = torch.randn(hidden_dim, hidden_dim, dtype=dtype, device=device) * 0.01
+        self.weights_hx = torch.randn(hidden_dim, input_dim, dtype=dtype, device=device) * 0.01
+        self.weights_hy = torch.randn(output_dim, hidden_dim, dtype=dtype, device=device) * 0.01
+        self.bias_h = torch.randn(hidden_dim, 1, dtype=dtype, device=device) * 0.01  # hidden_dim X 1
+        self.bias_y = torch.randn(output_dim, 1, dtype=dtype, device=device) * 0.01  # output_dim X 1
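A note on the scale change above: the removed math.sqrt(2.0 / self.hidden_dim) factor is the He scaling commonly paired with ReLU units; with the switch to tanh in this commit, the parameters are instead drawn at a small fixed scale of 0.01. For comparison only (this helper is an illustration, not code from this repository), a Xavier/Glorot-style scale, the usual choice for tanh, would look like:

import math
import torch

def glorot_scaled(rows, cols, dtype=torch.double, device="cpu"):
    # Xavier/Glorot scaling: variance ~ 2 / (fan_in + fan_out); illustrative helper only,
    # the name and signature are assumptions, not part of this commit.
    return torch.randn(rows, cols, dtype=dtype, device=device) * math.sqrt(2.0 / (rows + cols))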

        self.y = None
        self.h = None
@@ -32,15 +33,14 @@ def __init__(self, input_dim, hidden_dim, output_dim, mx=1.0e10):
        self.grad_bias_y = None
        self.grad_inp = None
        self.grad_prev = None
-        self.r = ReLU.ReLU()
+        self.r = Tanh.Tanh()
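The Tanh module itself is not part of this diff, so its contents are an assumption; judging from the call self.r.backward(self.h_bef_act[i], ...) in backward below, it presumably mirrors the ReLU module with a forward/backward pair along these lines (a sketch, not the committed file):

import torch

class Tanh:
    def forward(self, x):
        # element-wise tanh of the pre-activation
        return torch.tanh(x)

    def backward(self, x, gradOutput):
        # derivative of tanh at the pre-activation x: 1 - tanh(x)^2, applied element-wise
        return gradOutput * (1.0 - torch.tanh(x) ** 2)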

    def forward(self, input, isTrain=False):
        # if istrain:
        self.y = []
        # print(input)
        self.h = [torch.zeros(input[0].size()[0], self.hidden_dim, dtype=dtype, device=device)]
        self.h_bef_act = [torch.zeros(input[0].size()[0], self.hidden_dim, dtype=dtype, device=device)]
-        self.prev_h = []
        self.x = input

        for i in range(len(input)):
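            # The loop body lies outside the changed hunks, so it is not shown in this
            # commit. A hedged reconstruction of what each step presumably computes,
            # based on the shapes and names used in backward():
            #   pre = input[i].mm(self.weights_hx.t()) + self.h[-1].mm(self.weights_hh.t()) + self.bias_h.t()   # batch X hidden
            #   self.h_bef_act.append(pre)
            #   self.h.append(self.r.forward(pre))                                                              # batch X hidden
            #   self.y.append(self.h[-1].mm(self.weights_hy.t()) + self.bias_y.t())                             # batch X output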
@@ -74,9 +74,9 @@ def backward(self, input, gradOutput):
            self.grad_bias_y = self.grad_bias_y.add(grad_y.sum(dim=0).reshape(self.output_dim, 1))
            self.grad_Why = self.grad_Why.add(grad_y.transpose(0, 1).mm(self.h[i]))  # output X hidden
            # print(self.h_bef_act[i], grad_ht)
-            grad_act = self.r.backward(self.h_bef_act[i], grad_ht) + grad_y.mm(self.weights_hy)  # batch X hidden
+            grad_act = self.r.backward(self.h_bef_act[i], grad_ht + grad_y.mm(self.weights_hy))  # batch X hidden
            self.grad_bias_h = self.grad_bias_h.add(grad_act.sum(dim=0).reshape(self.hidden_dim, 1))  # hidden X 1
-            self.grad_Whh = self.grad_Whh.add(grad_act.transpose(0, 1).mm(self.h[i - 1]))
+            self.grad_Whh = self.grad_Whh.add(grad_act.transpose(0, 1).mm(self.h_bef_act[i - 1]))
            # print(self.grad_Whx.size(), grad_act.size(), input[i].size())
            self.grad_Whx = self.grad_Whx.add(grad_act.transpose(0, 1).mm(input[i]))  # hidden X input

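The first change in this hunk folds the output-to-hidden term into the activation backward, which is the chain-rule-consistent form for h_t = tanh(a_t): the gradient reaching the pre-activation a_t is tanh'(a_t) applied element-wise to the sum of the gradient arriving from the next time step and grad_y @ W_hy from the output at this step, rather than tanh'(a_t) applied to only one of the two terms. A small standalone check against autograd (illustrative only, not part of the commit):

import torch

a = torch.randn(3, 4, dtype=torch.double, requires_grad=True)    # pre-activation
g_next = torch.randn(3, 4, dtype=torch.double)                   # gradient from the next time step
g_out = torch.randn(3, 4, dtype=torch.double)                    # gradient from the output at this step
torch.tanh(a).backward(g_next + g_out)
manual = (1.0 - torch.tanh(a.detach()) ** 2) * (g_next + g_out)  # what the '+' line computes
print(torch.allclose(a.grad, manual))                            # True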
@@ -92,28 +92,41 @@ def clearGradParam(self):
        self.grad_bias_h = torch.zeros(self.hidden_dim, 1, dtype=dtype, device=device)
        self.grad_bias_y = torch.zeros(self.output_dim, 1, dtype=dtype, device=device)

+    def clip(self, M):
+        M[M > self.max] = self.max
+        M[M < -self.max] = -self.max
+        return M
+
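clip clamps its argument element-wise to [-self.max, self.max] and does so in place (the boolean-mask assignments write into the tensor it was given), so the grad_Whh = self.clip(self.grad_Whh) calls below work even though the returned values are never used afterwards. An equivalent formulation, not part of this commit, using PyTorch's built-in in-place clamp:

def clip(self, M):
    # clamp every element of M into [-self.max, self.max], in place
    return M.clamp_(-self.max, self.max)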
    def updateParam(self, learningRate, alpha=0, regularizer=0):
        # print('update')
-        # print(self.grad_Whx)
-
-        self.grad_Whh[self.grad_Whh > self.max] = self.max
-        self.grad_Whx[self.grad_Whx > self.max] = self.max
-        self.grad_Why[self.grad_Why > self.max] = self.max
-        self.grad_bias_h[self.grad_bias_h > self.max] = self.max
-        self.grad_bias_y[self.grad_bias_y > self.max] = self.max
-
-        self.grad_Whh[self.grad_Whh < -self.max] = -self.max
-        self.grad_Whx[self.grad_Whx < -self.max] = -self.max
-        self.grad_Why[self.grad_Why < -self.max] = -self.max
-        self.grad_bias_h[self.grad_bias_h < -self.max] = -self.max
-        self.grad_bias_y[self.grad_bias_y < -self.max] = -self.max
-
-        self.weights_hh -= self.grad_Whh * learningRate
-        self.weights_hx -= self.grad_Whx * learningRate
-        self.weights_hy -= self.grad_Why * learningRate
-        self.bias_h -= self.grad_bias_h * learningRate
-        self.bias_y -= self.grad_bias_y * learningRate
+
+        grad_Whh = self.clip(self.grad_Whh)
+        grad_Whx = self.clip(self.grad_Whx)
+        grad_Why = self.clip(self.grad_Why)
+        grad_bias_h = self.clip(self.grad_bias_h)
+        grad_bias_y = self.clip(self.grad_bias_y)
+
+        self.weights_hh -= (self.grad_Whh * learningRate + 2 * regularizer * self.weights_hh)
+        self.weights_hx -= (self.grad_Whx * learningRate + 2 * regularizer * self.weights_hx)
+        self.weights_hy -= (self.grad_Why * learningRate + 2 * regularizer * self.weights_hy)
+        self.bias_h -= (self.grad_bias_h * learningRate + 2 * regularizer * self.bias_h)
+        self.bias_y -= (self.grad_bias_y * learningRate + 2 * regularizer * self.bias_y)
+
+        # print(self.weights_hh)

        # self.W += (self.momentumW - 2*regularizer*self.W)
        # self.B += (self.momentumB - 2*regularizer*self.B)
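For context, the class is driven entirely by clearGradParam, forward, backward, and updateParam. A minimal training-step sketch under assumed shapes (input is a list of length seq_len holding batch X input_dim tensors, gradOutput is a matching list of batch X output_dim gradients, and the loss, data, and forward's return value are assumptions, not from this repository):

seq_len, batch, n_in, n_hid, n_out = 5, 8, 10, 32, 4
rnn = RNN(n_in, n_hid, n_out)

xs = [torch.randn(batch, n_in, dtype=dtype, device=device) for _ in range(seq_len)]
targets = [torch.randn(batch, n_out, dtype=dtype, device=device) for _ in range(seq_len)]

rnn.clearGradParam()                                      # zero the accumulated gradients
ys = rnn.forward(xs, isTrain=True)                        # assumed to return the list of per-step outputs
grad_out = [2.0 * (y - t) for y, t in zip(ys, targets)]   # gradient of a per-step squared-error loss
rnn.backward(xs, grad_out)                                # accumulate grad_Whh, grad_Whx, grad_Why and biases
rnn.updateParam(1e-2, regularizer=1e-4)                   # SGD step with gradient clipping and L2 penalty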