Skip to content

Commit f1f0472

Browse files
committed
Get MNIST data running
1 parent decf538 commit f1f0472

File tree

6 files changed

+217
-4
lines changed

6 files changed

+217
-4
lines changed

mnist.html

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<!doctype html>
2+
<html lang="en">
3+
<head>
4+
<meta charset="UTF-8" />
5+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
6+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/semantic.min.css">
7+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
8+
<title>Vite + TS</title>
9+
</head>
10+
<body>
11+
<div id="mnist" class="ui container" style="text-align: center">
12+
<div style="width: 56px; height: 56px; margin: 20px auto; padding-top: 40px; padding-bottom: 50px">
13+
<canvas id="digit" style="zoom: 200%;"></canvas>
14+
</div>
15+
16+
<div id="output"></div>
17+
<button class="ui button primary" id="next">Next Digit</button>
18+
<button class="ui button secondary" id="rerun">Run More Batches</button>
19+
20+
<div id="cycles"></div>
21+
<div id="correctness"></div>
22+
</div>
23+
<div id="app"></div>
24+
<script type="module" src="/src/mnist.ts"></script>
25+
</body>
26+
</html>

package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"dependencies": {
1616
"graphology": "^0.25.4",
1717
"graphology-types": "^0.24.7",
18-
"sigma": "^3.0.0-beta.26"
18+
"mnist": "^1.1.0",
19+
"sigma": "^3.0.0-beta.26",
20+
"threads": "^1.7.0"
1921
}
2022
}

src/mnist.ts

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import Network from "./network/network.ts";
2+
import Backpropagation from "./trainer/backpropagation.ts";
3+
import { Pool, spawn, Worker } from "threads"
4+
import mnist from 'mnist';
5+
import { ActivationFunction } from "./neuron/neuron.ts";
6+
7+
document.querySelector<HTMLDivElement>('#app')!.innerHTML = `
8+
<div id="graph" style="margin: 0 auto; width: 100%; height: 400px;"></div>
9+
`
10+
11+
const network = new Network(
12+
28 * 28,
13+
{ neurons: 10 },
14+
[
15+
{ neurons: 100, activationFunction: ActivationFunction.Relu },
16+
//{ neurons: 100, activationFunction: ActivationFunction.Sigmoid },
17+
]
18+
)
19+
20+
const trainer = new Backpropagation(network, 0.2)
21+
const formatSamples = (mnistSamples) => {
22+
return mnistSamples.map(({ input, output }) => {
23+
return { inputs: input, outputs: output }
24+
})
25+
}
26+
27+
let cycles = 0
28+
29+
async function main() {
30+
const pool = Pool(() => {
31+
const worker = new Worker(new URL('./worker.ts', import.meta.url), {
32+
type: 'module'
33+
})
34+
35+
return spawn(worker)
36+
}, 1)
37+
38+
for (let i=0; i<50; i++) {
39+
pool.queue(async worker => {
40+
const exportedNetwork = network.export()
41+
const adjustments = await worker.trainBatch(exportedNetwork)
42+
trainer.applyAdjustments(adjustments)
43+
44+
cycles++
45+
46+
document.getElementById("cycles")!.innerHTML = `<h4 style="margin: 10px 0">${cycles} Batches Run</h4>`
47+
})
48+
}
49+
50+
await pool.completed()
51+
await pool.terminate()
52+
53+
for (const { inputs, outputs } of formatSamples(mnist.get(10))) {
54+
const computed = network.compute(inputs)
55+
const expected = mnist.toNumber(outputs)
56+
const calculated = mnist.toNumber(computed)
57+
58+
console.log("Expecting", expected, " == ", calculated, computed)
59+
}
60+
61+
showDigit()
62+
document.getElementById("correctness")!.innerHTML = `Network Correctness: ${correctness().toFixed(3)}%`
63+
}
64+
65+
function correctness(): number {
66+
let total = 0;
67+
let correct = 0;
68+
69+
for (const { inputs, outputs } of formatSamples(mnist.set(0, 2000).test)) {
70+
const computed = network.compute(inputs)
71+
const expected = mnist.toNumber(outputs)
72+
const calculated = mnist.toNumber(computed)
73+
74+
total++
75+
76+
if (expected === calculated) {
77+
correct++
78+
}
79+
}
80+
81+
return (100 * correct / total)
82+
}
83+
84+
function showDigit() {
85+
const data = mnist.get(1)[0]
86+
87+
const computed = network.compute(data.input)
88+
const expected = mnist.toNumber(data.output)
89+
const calculated = mnist.toNumber(computed)
90+
91+
const context = (document.getElementById<HTMLCanvasElement>('digit')).getContext('2d');
92+
document.getElementById("output")!.innerHTML = `
93+
<h5 style="margin: 5px;">Digit: ${expected} - Network Calculated: ${calculated}</h5>
94+
`
95+
mnist.draw(data.input, context); // draws a '1' mnist digit in the canvas
96+
}
97+
98+
document.getElementById("next")!.addEventListener("click", showDigit)
99+
document.getElementById("rerun")!.addEventListener("click", async () => {
100+
await main()
101+
})
102+
103+
await main()

src/trainer/backpropagation.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export default class Backpropagation {
1414

1515
train(sample: TrainingSample) {
1616
const adjustments = this.#getAdjustments(sample.inputs, sample.outputs)
17-
this.#applyAdjustments(adjustments)
17+
this.applyAdjustments(adjustments)
1818
}
1919

2020
trainBatch(samples: TrainingSample[]) {
@@ -43,7 +43,9 @@ export default class Backpropagation {
4343
})
4444
}
4545

46-
this.#applyAdjustments(finalAdjustment)
46+
this.applyAdjustments(finalAdjustment)
47+
48+
return finalAdjustment
4749
}
4850

4951
#getAdjustments(inputs: number[], outputs: number[]) {
@@ -72,7 +74,8 @@ export default class Backpropagation {
7274
return adjustments
7375
}
7476

75-
#applyAdjustments(adjustments: Adjustments) {
77+
78+
applyAdjustments(adjustments: Adjustments) {
7679
adjustments.forEach((neurons, layerIndex) => {
7780
neurons.forEach((synapses, neuronIndex) => {
7881
synapses.forEach((adjustment, synapseIndex) => {

src/worker.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import { expose } from "threads/worker"
2+
import Network from "./network/network.ts";
3+
import Backpropagation from "./trainer/backpropagation.ts";
4+
import mnist from 'mnist';
5+
6+
const formatSamples = (mnistSamples) => {
7+
return mnistSamples.map(({ input, output }) => {
8+
return { inputs: input, outputs: output }
9+
})
10+
}
11+
12+
expose({
13+
trainBatch(exportedNetwork) {
14+
const network = Network.fromNetworkExport(exportedNetwork)
15+
const trainer = new Backpropagation(network, 0.2)
16+
17+
console.time("trainWorker")
18+
const adjustments = trainer.trainBatch(formatSamples(mnist.get(100)))
19+
console.timeEnd("trainWorker")
20+
21+
return adjustments
22+
}
23+
})

yarn.lock

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,18 @@
202202
resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.5.tgz#a6ce3e556e00fd9895dd872dd172ad0d4bd687f4"
203203
integrity sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==
204204

205+
callsites@^3.1.0:
206+
version "3.1.0"
207+
resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73"
208+
integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==
209+
210+
debug@^4.2.0:
211+
version "4.3.6"
212+
resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.6.tgz#2ab2c38fbaffebf8aa95fdfe6d88438c7a13c52b"
213+
integrity sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==
214+
dependencies:
215+
ms "2.1.2"
216+
205217
esbuild@^0.20.1:
206218
version "0.20.2"
207219
resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.20.2.tgz#9d6b2386561766ee6b5a55196c6d766d28c87ea1"
@@ -231,6 +243,11 @@ esbuild@^0.20.1:
231243
"@esbuild/win32-ia32" "0.20.2"
232244
"@esbuild/win32-x64" "0.20.2"
233245

246+
esm@^3.2.25:
247+
version "3.2.25"
248+
resolved "https://registry.yarnpkg.com/esm/-/esm-3.2.25.tgz#342c18c29d56157688ba5ce31f8431fbb795cc10"
249+
integrity sha512-U1suiZ2oDVWv4zPO56S0NcR5QriEahGtdN2OR6FiOG4WJvcjBVFB0qI4+eKoWFH483PKGuLuu6V8Z4T5g63UVA==
250+
234251
events@^3.3.0:
235252
version "3.3.0"
236253
resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400"
@@ -259,6 +276,21 @@ graphology@^0.25.4:
259276
events "^3.3.0"
260277
obliterator "^2.0.2"
261278

279+
is-observable@^2.1.0:
280+
version "2.1.0"
281+
resolved "https://registry.yarnpkg.com/is-observable/-/is-observable-2.1.0.tgz#5c8d733a0b201c80dff7bb7c0df58c6a255c7c69"
282+
integrity sha512-DailKdLb0WU+xX8K5w7VsJhapwHLZ9jjmazqCJq4X12CTgqq73TKnbRcnSLuXYPOoLQgV5IrD7ePiX/h1vnkBw==
283+
284+
mnist@^1.1.0:
285+
version "1.1.0"
286+
resolved "https://registry.yarnpkg.com/mnist/-/mnist-1.1.0.tgz#b83efc6af88d8db53b196665acdb50cf524bd2ba"
287+
integrity sha512-x+SfS5tSJnOEjad6jkuL91Cuq9EAVmy1IYt6vdWEsYCjJZsj0oB2xCq+sg5cgiEt9W+xBT31q9IwI2h7h1VtOw==
288+
289+
290+
version "2.1.2"
291+
resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"
292+
integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==
293+
262294
nanoid@^3.3.7:
263295
version "3.3.7"
264296
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.7.tgz#d0c301a691bc8d54efa0a2226ccf3fe2fd656bd8"
@@ -269,6 +301,11 @@ obliterator@^2.0.2:
269301
resolved "https://registry.yarnpkg.com/obliterator/-/obliterator-2.0.4.tgz#fa650e019b2d075d745e44f1effeb13a2adbe816"
270302
integrity sha512-lgHwxlxV1qIg1Eap7LgIeoBWIMFibOjbrYPIPJZcI1mmGAI2m3lNYpK12Y+GBdPQ0U1hRwSord7GIaawz962qQ==
271303

304+
observable-fns@^0.6.1:
305+
version "0.6.1"
306+
resolved "https://registry.yarnpkg.com/observable-fns/-/observable-fns-0.6.1.tgz#636eae4fdd1132e88c0faf38d33658cc79d87e37"
307+
integrity sha512-9gRK4+sRWzeN6AOewNBTLXir7Zl/i3GB6Yl26gK4flxz8BXVpD3kt8amREmWNb0mxYOGDotvE5a4N+PtGGKdkg==
308+
272309
picocolors@^1.0.0:
273310
version "1.0.0"
274311
resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c"
@@ -321,6 +358,25 @@ source-map-js@^1.2.0:
321358
resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.2.0.tgz#16b809c162517b5b8c3e7dcd315a2a5c2612b2af"
322359
integrity sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==
323360

361+
threads@^1.7.0:
362+
version "1.7.0"
363+
resolved "https://registry.yarnpkg.com/threads/-/threads-1.7.0.tgz#d9e9627bfc1ef22ada3b733c2e7558bbe78e589c"
364+
integrity sha512-Mx5NBSHX3sQYR6iI9VYbgHKBLisyB+xROCBGjjWm1O9wb9vfLxdaGtmT/KCjUqMsSNW6nERzCW3T6H43LqjDZQ==
365+
dependencies:
366+
callsites "^3.1.0"
367+
debug "^4.2.0"
368+
is-observable "^2.1.0"
369+
observable-fns "^0.6.1"
370+
optionalDependencies:
371+
tiny-worker ">= 2"
372+
373+
"tiny-worker@>= 2":
374+
version "2.3.0"
375+
resolved "https://registry.yarnpkg.com/tiny-worker/-/tiny-worker-2.3.0.tgz#715ae34304c757a9af573ae9a8e3967177e6011e"
376+
integrity sha512-pJ70wq5EAqTAEl9IkGzA+fN0836rycEuz2Cn6yeZ6FRzlVS5IDOkFHpIoEsksPRQV34GDqXm65+OlnZqUSyK2g==
377+
dependencies:
378+
esm "^3.2.25"
379+
324380
typescript@^5.2.2:
325381
version "5.4.5"
326382
resolved "https://registry.yarnpkg.com/typescript/-/typescript-5.4.5.tgz#42ccef2c571fdbd0f6718b1d1f5e6e5ef006f611"

0 commit comments

Comments
 (0)