
Commit 4e7a2c0

Fix a critical bug in regularization
Fix a critical bug in regularization that caused weights to blow up. Also fix other bugs.
1 parent 78fd151 commit 4e7a2c0

9 files changed (+23 -33 lines)


RL/deepqlearn.js (+2 -10)

@@ -122,18 +122,10 @@ class DQN {
         // compute the value of doing any action in this state
         // and return the argmax action and its value
         let action_values = this.value_net.forward(new Vol(s));
-        let maxk = action_values.max_index();
+        let maxk = action_values.max_index;
         return { action: maxk, value: action_values.w[maxk] };
     }
 
-    _toarray(arr) {
-        let a = [];
-        for (let i in arr) {
-            a.push(arr[i]);
-        }
-        return a;
-    }
-
     getNetInput(xt) {
         // return s = (x, a, x, a, x, a, xt) state vector.
         // It's a concatenation of last window_size (x,a) pairs and current state x
@@ -150,7 +142,7 @@ class DQN {
             // we dont want weight regularization to undervalue this information, as it only exists once
             let action1ofk = one_hot(this.num_actions, action, 1.0 * this.num_states);
 
-            w = w.concat(this._toarray(action1ofk)); // do not concat array & floatarray
+            w = w.concat(Array.prototype.slice.call(action1ofk)); // do not concat array & floatarray
         }
         return w;
     }
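Note on the first hunk: `max_index` is evidently a getter on Vol (topology/vallia.js below reads `this.output.max_index` the same way), so calling it as a method threw at runtime. A minimal sketch of the pitfall, with a made-up, stripped-down Vol:

class Vol {
    constructor(w) { this.w = w; }
    get max_index() {                // a getter, not a method
        let maxi = 0;
        for (let i = 1; i < this.w.length; i++) {
            if (this.w[i] > this.w[maxi]) maxi = i;
        }
        return maxi;
    }
}

let v = new Vol([0.1, 0.7, 0.2]);
console.log(v.max_index);            // 1
// v.max_index();                    // TypeError: v.max_index is not a function

On the second hunk, `Array.prototype.slice.call(action1ofk)` copies the typed array into a plain Array, so `concat` appends its elements; concatenating the typed array directly would nest it as a single item.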

backend.js (+4 -8)

@@ -6,16 +6,13 @@ function TensorVectorProduct(ov, m, v) {
     let ncol = m.axis(-1) | 0;
     let nrow = m.axis(-2) | 0;
     let new_shape = m.shape.slice(); new_shape.pop();
-    let bs = ncol * nrow | 0;
-    let N = (m.size / bs) | 0;
+    let N = (m.size / ncol) | 0;
 
     let mw = m.w, vw = v.w, ow = ov.w;
     ow.fill(0.);
-    for (let z = 0; z < N; z++) {
-        for (let i = 0; i < nrow; i++) {
-            for (let j = 0; j < ncol; j++) {
-                ow[z * nrow + i] += mw[z * bs + i * ncol + j] * vw[j];
-            }
+    for (let i = 0; i < N; i++) {
+        for (let j = 0; j < ncol; j++) {
+            ow[i] += mw[i * ncol + j] * vw[j];
         }
     }
 }
@@ -49,7 +46,6 @@ function TransposedTensorVectorProductAdd(ov, m, v) {
 }
 
 
-
 /**
  * HadmardProduct apply to self
  */
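The rewrite is sound because the product contracts only over the last axis: in the row-major buffer, every contiguous run of ncol elements produces exactly one output element, so the z/i loop pair collapses into a single loop over N = m.size / ncol rows (ow[z * nrow + i] and z * bs + i * ncol flatten to ow[i] and i * ncol, since bs = nrow * ncol). A standalone equivalence check, with made-up sizes:

// Hypothetical 2x3x4 tensor times a length-4 vector, flattened row-major.
const ncol = 4, nrow = 3, bs = ncol * nrow;
const mw = Float64Array.from({ length: 2 * bs }, (_, i) => i + 1);
const vw = Float64Array.of(1, 2, 3, 4);

// Old triple loop.
const oldOut = new Float64Array(2 * nrow);
for (let z = 0; z < 2; z++)
    for (let i = 0; i < nrow; i++)
        for (let j = 0; j < ncol; j++)
            oldOut[z * nrow + i] += mw[z * bs + i * ncol + j] * vw[j];

// New flattened double loop.
const N = mw.length / ncol;
const newOut = new Float64Array(N);
for (let i = 0; i < N; i++)
    for (let j = 0; j < ncol; j++)
        newOut[i] += mw[i * ncol + j] * vw[j];

console.log(oldOut.join() === newOut.join()); // true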

layers/layer.js (+1 -1)

@@ -19,7 +19,7 @@ class Layer {
 
     compile(options) {
         // setup objects for training
-        this.updated.forEach(function(V) {
+        this.updated.forEach((V) => {
             V.dw = V.zeros_like();
             V.optimizer = get_optimizer(V.size, options);
             if (V.allow_regl) V.regularizer = new Regularization(options.l2_decay, options.l1_decay, this.l2_decay_mul, this.l1_decay_mul);
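The arrow function is a real fix, not style: the callback reads `this.l2_decay_mul` and `this.l1_decay_mul`, and a plain `function` expression gets its own `this` (undefined here, since class bodies are strict code), while an arrow inherits `compile`'s. A minimal sketch with illustrative names:

class Example {
    constructor() { this.scale = 2; }
    broken(xs) { xs.forEach(function (x) { console.log(x * this.scale); }); }
    fixed(xs)  { xs.forEach((x) => console.log(x * this.scale)); }
}

new Example().fixed([1, 2]);   // 2, 4
new Example().broken([1, 2]);  // TypeError: cannot read 'scale' of undefined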

objective.js (+2 -2)

@@ -6,9 +6,9 @@
 function meanSquaredError(x, y) {
     let N = x.size;
     let loss = 0.;
-    let aw = a.w, yw = x.w, adw = a.dw;
+    let xw = x.w, yw = y.w, xdw = x.dw;
     for (let i = 0; i < N; i++) {
-        let dx = aw[i] - yw[i];
+        let dx = xw[i] - yw[i];
         xdw[i] += dx;
         loss += 0.5 * dx * dx;
     }
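The old body read from a variable `a` that was never in scope, so the loss was a ReferenceError. The fixed loop computes L = sum of 0.5 * (x_i - y_i)^2 and accumulates the gradient dL/dx_i = x_i - y_i into `x.dw`. A runnable sketch with plain objects standing in for Vol (the enclosing function's return is assumed):

function meanSquaredError(x, y) {
    let N = x.size;
    let loss = 0.;
    let xw = x.w, yw = y.w, xdw = x.dw;
    for (let i = 0; i < N; i++) {
        let dx = xw[i] - yw[i]; // dL/dx_i for L = 0.5 * (x_i - y_i)^2
        xdw[i] += dx;           // accumulate into the gradient buffer
        loss += 0.5 * dx * dx;
    }
    return loss;                // assumed: the hunk ends before the return
}

let x = { size: 2, w: [1, 3], dw: [0, 0] };
console.log(meanSquaredError(x, { w: [0, 0] })); // 0.5 + 4.5 = 5
console.log(x.dw);                               // [1, 3]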

regularization.js (+1 -1)

@@ -23,7 +23,7 @@ class Regularization {
             decay_loss += l1_decay * Math.abs(p);
             let l1grad = l1_decay * (p > 0 ? 1 : -1);
             let l2grad = l2_decay * (p);
-            dx[i] -= (l2grad + l1grad);
+            dx[i] += (l2grad + l1grad);
         }
         return decay_loss;
     }
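This is the critical fix named in the commit message. `dx` is evidently the parameter's accumulated gradient buffer (punish is called from trainer.js right before the optimizer step), and optimizers step weights against the gradient, so the decay gradient (l2_decay * p for L2, l1_decay * sign(p) for L1) must be added to it. Subtracting flipped the sign: each update pushed weights away from zero, and they blew up. A minimal sketch under an assumed plain SGD step w -= lr * dw, with illustrative numbers:

let w = 1.0, lr = 0.1, l2_decay = 0.5;

// Fixed: decay gradient added to dw, so the step shrinks the weight.
let dw_good = l2_decay * w;    // d/dw of 0.5 * l2_decay * w^2
console.log(w - lr * dw_good); // 0.95 -- decays toward zero

// Bug: the flipped sign grows the weight every step, compounding.
let dw_bad = -l2_decay * w;
console.log(w - lr * dw_bad);  // 1.05 -- anti-decay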

topology/vallia.js (+1 -1)

@@ -80,7 +80,7 @@ class Net {
 
     // this is a convenience function for returning the argmax
     // return index of the class with highest class probability
-    get prediction(x) {
+    prediction(x) {
         if (typeof x !== 'undefined') this.forward(x);
         // assume output is a vector
         return this.output.max_index;
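Beyond convenience, `get prediction(x)` is invalid JavaScript, since a getter must declare zero formal parameters; the class would fail to parse. Dropping `get` makes it an ordinary method that can accept the optional input:

// class Net { get prediction(x) { ... } } // SyntaxError: getter must have no formal parameters
class Net {
    prediction(x) { /* ... */ }            // ok: plain method
}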

trainer.js (+3 -1)

@@ -40,6 +40,8 @@ class Trainer {
         let updates = this.net.trainables;
         for (let i in updates) {
             let T = updates[i];
+
+
             if (T.regularizer) regular_loss += T.regularizer.punish(T);
             // make raw batch gradient
             T.batchGrad(this.batch_size);
@@ -51,7 +53,7 @@
 
         return {
             fwd_time: timer.getTime('forward'),
-            bwd_time: getTime('backward'),
+            bwd_time: timer.getTime('backward'),
 
             regular_loss: regular_loss,
             cost_loss: cost_loss,
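The bwd_time fix is scoping: `getTime` exists only as a method on the Timer instance, so the bare call threw a ReferenceError the first time stats were returned.

// bwd_time: getTime('backward')       // ReferenceError: getTime is not defined
// bwd_time: timer.getTime('backward') // ok: call through the timer instance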

util/timing.js (+8 -8)

@@ -2,25 +2,25 @@ class Timer {
     constructor() {
         this.lasttime = {};
         this.sum = {};
-        if (performance.now) {
-            this.get_time = performance.now;
-        } else {
-            this.get_time = new Date.now;
-        }
+        // if (performance.now) {
+        //     this.get_time = performance.now;
+        // } else {
+        //     this.get_time = new Date.now;
+        // }
     }
 
     start(name) {
         if (!this.sum[name]) this.sum[name]
         this.lastname = name;
-        lasttime[name] = this.get_time();
+        this.lasttime[name] = performance.now();
     }
 
     stop(name) {
-        this.sum[name] += this.get_time() - this.lasttime[name];
+        this.sum[name] += performance.now() - this.lasttime[name];
     }
 
     stoplast() {
-        this.sum[this.lastname] += this.get_time() - this.lasttime[this.lastname];
+        this.stop(this.lastname);
     }
 
     passto(name) {
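The underlying problem in the old constructor: `this.get_time = performance.now` detaches the function from its receiver, so `this.get_time()` later throws (in browsers, `TypeError: Illegal invocation`, because `performance.now` must be called on `performance`); `new Date.now` was doubly wrong, as `Date.now` is not a constructor. Calling `performance.now()` inline, as the commit does, sidesteps all of it; binding would also have worked. A sketch:

const detached = performance.now;
// detached();              // TypeError: Illegal invocation -- receiver lost

const bound = performance.now.bind(performance);
console.log(bound());       // ok: receiver baked in

// new Date.now             // TypeError: Date.now is not a constructor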

vol.js (+1 -1)

@@ -28,7 +28,7 @@ class Vol {
         }
     }
 
-    // this.dw = this.zeros_like(); // -- save memory, allocmem at training?
+    this.dw = this.zeros_like(); // -- save memory, allocmem at training?
     this.length = this.size;
 }
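Re-enabling the eager `dw` allocation trades the memory saving for safety: with it commented out, anything that accumulates gradients before training allocates the buffer, like the fixed meanSquaredError above writing `x.dw[i] += dx`, hits an undefined `dw`. A small sketch of the failure mode:

let lazy = { w: new Float64Array(3) };  // dw deferred, as before this fix
// lazy.dw[0] += 1;                     // TypeError: cannot read properties of undefined

let eager = { w: new Float64Array(3), dw: new Float64Array(3) }; // after the fix
eager.dw[0] += 1;                       // fine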
