-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
decf538
commit f1f0472
Showing
6 changed files
with
217 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
<!doctype html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="UTF-8" /> | ||
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> | ||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/semantic.min.css"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> | ||
<title>Vite + TS</title> | ||
</head> | ||
<body> | ||
<div id="mnist" class="ui container" style="text-align: center"> | ||
<div style="width: 56px; height: 56px; margin: 20px auto; padding-top: 40px; padding-bottom: 50px"> | ||
<canvas id="digit" style="zoom: 200%;"></canvas> | ||
</div> | ||
|
||
<div id="output"></div> | ||
<button class="ui button primary" id="next">Next Digit</button> | ||
<button class="ui button secondary" id="rerun">Run More Batches</button> | ||
|
||
<div id="cycles"></div> | ||
<div id="correctness"></div> | ||
</div> | ||
<div id="app"></div> | ||
<script type="module" src="/src/mnist.ts"></script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import Network from "./network/network.ts"; | ||
import Backpropagation from "./trainer/backpropagation.ts"; | ||
import { Pool, spawn, Worker } from "threads" | ||
import mnist from 'mnist'; | ||
import { ActivationFunction } from "./neuron/neuron.ts"; | ||
|
||
document.querySelector<HTMLDivElement>('#app')!.innerHTML = ` | ||
<div id="graph" style="margin: 0 auto; width: 100%; height: 400px;"></div> | ||
` | ||
|
||
const network = new Network( | ||
28 * 28, | ||
{ neurons: 10 }, | ||
[ | ||
{ neurons: 100, activationFunction: ActivationFunction.Relu }, | ||
//{ neurons: 100, activationFunction: ActivationFunction.Sigmoid }, | ||
] | ||
) | ||
|
||
const trainer = new Backpropagation(network, 0.2) | ||
const formatSamples = (mnistSamples) => { | ||
return mnistSamples.map(({ input, output }) => { | ||
Check failure on line 22 in src/mnist.ts GitHub Actions / deploy
|
||
return { inputs: input, outputs: output } | ||
}) | ||
} | ||
|
||
let cycles = 0 | ||
|
||
async function main() { | ||
const pool = Pool(() => { | ||
const worker = new Worker(new URL('./worker.ts', import.meta.url), { | ||
type: 'module' | ||
}) | ||
|
||
return spawn(worker) | ||
}, 1) | ||
|
||
for (let i=0; i<50; i++) { | ||
pool.queue(async worker => { | ||
const exportedNetwork = network.export() | ||
const adjustments = await worker.trainBatch(exportedNetwork) | ||
trainer.applyAdjustments(adjustments) | ||
|
||
cycles++ | ||
|
||
document.getElementById("cycles")!.innerHTML = `<h4 style="margin: 10px 0">${cycles} Batches Run</h4>` | ||
}) | ||
} | ||
|
||
await pool.completed() | ||
await pool.terminate() | ||
|
||
for (const { inputs, outputs } of formatSamples(mnist.get(10))) { | ||
const computed = network.compute(inputs) | ||
const expected = mnist.toNumber(outputs) | ||
const calculated = mnist.toNumber(computed) | ||
|
||
console.log("Expecting", expected, " == ", calculated, computed) | ||
} | ||
|
||
showDigit() | ||
document.getElementById("correctness")!.innerHTML = `Network Correctness: ${correctness().toFixed(3)}%` | ||
} | ||
|
||
function correctness(): number { | ||
let total = 0; | ||
let correct = 0; | ||
|
||
for (const { inputs, outputs } of formatSamples(mnist.set(0, 2000).test)) { | ||
const computed = network.compute(inputs) | ||
const expected = mnist.toNumber(outputs) | ||
const calculated = mnist.toNumber(computed) | ||
|
||
total++ | ||
|
||
if (expected === calculated) { | ||
correct++ | ||
} | ||
} | ||
|
||
return (100 * correct / total) | ||
} | ||
|
||
function showDigit() { | ||
const data = mnist.get(1)[0] | ||
|
||
const computed = network.compute(data.input) | ||
const expected = mnist.toNumber(data.output) | ||
const calculated = mnist.toNumber(computed) | ||
|
||
const context = (document.getElementById<HTMLCanvasElement>('digit')).getContext('2d'); | ||
Check failure on line 91 in src/mnist.ts GitHub Actions / deploy
Check failure on line 91 in src/mnist.ts GitHub Actions / deploy
|
||
document.getElementById("output")!.innerHTML = ` | ||
<h5 style="margin: 5px;">Digit: ${expected} - Network Calculated: ${calculated}</h5> | ||
` | ||
mnist.draw(data.input, context); // draws a '1' mnist digit in the canvas | ||
} | ||
|
||
document.getElementById("next")!.addEventListener("click", showDigit) | ||
document.getElementById("rerun")!.addEventListener("click", async () => { | ||
await main() | ||
}) | ||
|
||
await main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import { expose } from "threads/worker" | ||
import Network from "./network/network.ts"; | ||
import Backpropagation from "./trainer/backpropagation.ts"; | ||
import mnist from 'mnist'; | ||
|
||
const formatSamples = (mnistSamples) => { | ||
return mnistSamples.map(({ input, output }) => { | ||
return { inputs: input, outputs: output } | ||
}) | ||
} | ||
|
||
expose({ | ||
trainBatch(exportedNetwork) { | ||
const network = Network.fromNetworkExport(exportedNetwork) | ||
const trainer = new Backpropagation(network, 0.2) | ||
|
||
console.time("trainWorker") | ||
const adjustments = trainer.trainBatch(formatSamples(mnist.get(100))) | ||
console.timeEnd("trainWorker") | ||
|
||
return adjustments | ||
} | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -202,6 +202,18 @@ | |
resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.5.tgz#a6ce3e556e00fd9895dd872dd172ad0d4bd687f4" | ||
integrity sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw== | ||
|
||
callsites@^3.1.0: | ||
version "3.1.0" | ||
resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73" | ||
integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ== | ||
|
||
debug@^4.2.0: | ||
version "4.3.6" | ||
resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.6.tgz#2ab2c38fbaffebf8aa95fdfe6d88438c7a13c52b" | ||
integrity sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg== | ||
dependencies: | ||
ms "2.1.2" | ||
|
||
esbuild@^0.20.1: | ||
version "0.20.2" | ||
resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.20.2.tgz#9d6b2386561766ee6b5a55196c6d766d28c87ea1" | ||
|
@@ -231,6 +243,11 @@ esbuild@^0.20.1: | |
"@esbuild/win32-ia32" "0.20.2" | ||
"@esbuild/win32-x64" "0.20.2" | ||
|
||
esm@^3.2.25: | ||
version "3.2.25" | ||
resolved "https://registry.yarnpkg.com/esm/-/esm-3.2.25.tgz#342c18c29d56157688ba5ce31f8431fbb795cc10" | ||
integrity sha512-U1suiZ2oDVWv4zPO56S0NcR5QriEahGtdN2OR6FiOG4WJvcjBVFB0qI4+eKoWFH483PKGuLuu6V8Z4T5g63UVA== | ||
|
||
events@^3.3.0: | ||
version "3.3.0" | ||
resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400" | ||
|
@@ -259,6 +276,21 @@ graphology@^0.25.4: | |
events "^3.3.0" | ||
obliterator "^2.0.2" | ||
|
||
is-observable@^2.1.0: | ||
version "2.1.0" | ||
resolved "https://registry.yarnpkg.com/is-observable/-/is-observable-2.1.0.tgz#5c8d733a0b201c80dff7bb7c0df58c6a255c7c69" | ||
integrity sha512-DailKdLb0WU+xX8K5w7VsJhapwHLZ9jjmazqCJq4X12CTgqq73TKnbRcnSLuXYPOoLQgV5IrD7ePiX/h1vnkBw== | ||
|
||
mnist@^1.1.0: | ||
version "1.1.0" | ||
resolved "https://registry.yarnpkg.com/mnist/-/mnist-1.1.0.tgz#b83efc6af88d8db53b196665acdb50cf524bd2ba" | ||
integrity sha512-x+SfS5tSJnOEjad6jkuL91Cuq9EAVmy1IYt6vdWEsYCjJZsj0oB2xCq+sg5cgiEt9W+xBT31q9IwI2h7h1VtOw== | ||
|
||
[email protected]: | ||
version "2.1.2" | ||
resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" | ||
integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== | ||
|
||
nanoid@^3.3.7: | ||
version "3.3.7" | ||
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.7.tgz#d0c301a691bc8d54efa0a2226ccf3fe2fd656bd8" | ||
|
@@ -269,6 +301,11 @@ obliterator@^2.0.2: | |
resolved "https://registry.yarnpkg.com/obliterator/-/obliterator-2.0.4.tgz#fa650e019b2d075d745e44f1effeb13a2adbe816" | ||
integrity sha512-lgHwxlxV1qIg1Eap7LgIeoBWIMFibOjbrYPIPJZcI1mmGAI2m3lNYpK12Y+GBdPQ0U1hRwSord7GIaawz962qQ== | ||
|
||
observable-fns@^0.6.1: | ||
version "0.6.1" | ||
resolved "https://registry.yarnpkg.com/observable-fns/-/observable-fns-0.6.1.tgz#636eae4fdd1132e88c0faf38d33658cc79d87e37" | ||
integrity sha512-9gRK4+sRWzeN6AOewNBTLXir7Zl/i3GB6Yl26gK4flxz8BXVpD3kt8amREmWNb0mxYOGDotvE5a4N+PtGGKdkg== | ||
|
||
picocolors@^1.0.0: | ||
version "1.0.0" | ||
resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" | ||
|
@@ -321,6 +358,25 @@ source-map-js@^1.2.0: | |
resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.2.0.tgz#16b809c162517b5b8c3e7dcd315a2a5c2612b2af" | ||
integrity sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg== | ||
|
||
threads@^1.7.0: | ||
version "1.7.0" | ||
resolved "https://registry.yarnpkg.com/threads/-/threads-1.7.0.tgz#d9e9627bfc1ef22ada3b733c2e7558bbe78e589c" | ||
integrity sha512-Mx5NBSHX3sQYR6iI9VYbgHKBLisyB+xROCBGjjWm1O9wb9vfLxdaGtmT/KCjUqMsSNW6nERzCW3T6H43LqjDZQ== | ||
dependencies: | ||
callsites "^3.1.0" | ||
debug "^4.2.0" | ||
is-observable "^2.1.0" | ||
observable-fns "^0.6.1" | ||
optionalDependencies: | ||
tiny-worker ">= 2" | ||
|
||
"tiny-worker@>= 2": | ||
version "2.3.0" | ||
resolved "https://registry.yarnpkg.com/tiny-worker/-/tiny-worker-2.3.0.tgz#715ae34304c757a9af573ae9a8e3967177e6011e" | ||
integrity sha512-pJ70wq5EAqTAEl9IkGzA+fN0836rycEuz2Cn6yeZ6FRzlVS5IDOkFHpIoEsksPRQV34GDqXm65+OlnZqUSyK2g== | ||
dependencies: | ||
esm "^3.2.25" | ||
|
||
typescript@^5.2.2: | ||
version "5.4.5" | ||
resolved "https://registry.yarnpkg.com/typescript/-/typescript-5.4.5.tgz#42ccef2c571fdbd0f6718b1d1f5e6e5ef006f611" | ||
|