Skip to content

Commit

Permalink
Get MNIST data running
Browse files Browse the repository at this point in the history
  • Loading branch information
isometriks committed Sep 6, 2024
1 parent decf538 commit b9f3936
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 5 deletions.
1 change: 1 addition & 0 deletions mnist.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
declare module 'mnist'
26 changes: 26 additions & 0 deletions mnist.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/semantic.min.css">
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Vite + TS</title>
</head>
<body>
<div id="mnist" class="ui container" style="text-align: center">
<div style="width: 56px; height: 56px; margin: 20px auto; padding-top: 40px; padding-bottom: 50px">
<canvas id="digit" style="zoom: 200%;"></canvas>
</div>

<div id="output"></div>
<button class="ui button primary" id="next">Next Digit</button>
<button class="ui button secondary" id="rerun">Run More Batches</button>

<div id="cycles"></div>
<div id="correctness"></div>
</div>
<div id="app"></div>
<script type="module" src="/src/mnist.ts"></script>
</body>
</html>
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"dependencies": {
"graphology": "^0.25.4",
"graphology-types": "^0.24.7",
"sigma": "^3.0.0-beta.26"
"mnist": "^1.1.0",
"sigma": "^3.0.0-beta.26",
"threads": "^1.7.0"
}
}
103 changes: 103 additions & 0 deletions src/mnist.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import Network from "./network/network.ts";
import Backpropagation from "./trainer/backpropagation.ts";
import { Pool, spawn, Worker } from "threads"
import mnist from 'mnist';
import { ActivationFunction } from "./neuron/neuron.ts";

document.querySelector<HTMLDivElement>('#app')!.innerHTML = `
<div id="graph" style="margin: 0 auto; width: 100%; height: 400px;"></div>
`

const network = new Network(
28 * 28,
{ neurons: 10 },
[
{ neurons: 100, activationFunction: ActivationFunction.Relu },
//{ neurons: 100, activationFunction: ActivationFunction.Sigmoid },
]
)

const trainer = new Backpropagation(network, 0.2)
const formatSamples = (mnistSamples) => {
return mnistSamples.map(({ input, output }) => {
return { inputs: input, outputs: output }
})
}

let cycles = 0

async function main() {
const pool = Pool(() => {
const worker = new Worker(new URL('./worker.ts', import.meta.url), {
type: 'module'
})

return spawn(worker)
}, 1)

for (let i=0; i<50; i++) {
pool.queue(async worker => {
const exportedNetwork = network.export()
const adjustments = await worker.trainBatch(exportedNetwork)
trainer.applyAdjustments(adjustments)

cycles++

document.getElementById("cycles")!.innerHTML = `<h4 style="margin: 10px 0">${cycles} Batches Run</h4>`
})
}

await pool.completed()
await pool.terminate()

for (const { inputs, outputs } of formatSamples(mnist.get(10))) {
const computed = network.compute(inputs)
const expected = mnist.toNumber(outputs)
const calculated = mnist.toNumber(computed)

console.log("Expecting", expected, " == ", calculated, computed)
}

showDigit()
document.getElementById("correctness")!.innerHTML = `Network Correctness: ${correctness().toFixed(3)}%`
}

function correctness(): number {
let total = 0;
let correct = 0;

for (const { inputs, outputs } of formatSamples(mnist.set(0, 2000).test)) {
const computed = network.compute(inputs)
const expected = mnist.toNumber(outputs)
const calculated = mnist.toNumber(computed)

total++

if (expected === calculated) {
correct++
}
}

return (100 * correct / total)
}

function showDigit() {
const data = mnist.get(1)[0]

const computed = network.compute(data.input)
const expected = mnist.toNumber(data.output)
const calculated = mnist.toNumber(computed)

const context = (document.querySelector<HTMLCanvasElement>('#digit')!).getContext('2d');
document.getElementById("output")!.innerHTML = `
<h5 style="margin: 5px;">Digit: ${expected} - Network Calculated: ${calculated}</h5>
`
mnist.draw(data.input, context); // draws a '1' mnist digit in the canvas
}

document.getElementById("next")!.addEventListener("click", showDigit)
document.getElementById("rerun")!.addEventListener("click", async () => {
await main()
})

await main()
2 changes: 1 addition & 1 deletion src/network/network.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ interface HiddenLayerConfiguration extends LayerConfiguration {
bias?: number
}

interface NetworkExport {
export interface NetworkExport {
shape: {
inputs: number,
outputs: LayerConfiguration,
Expand Down
9 changes: 6 additions & 3 deletions src/trainer/backpropagation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export default class Backpropagation {

train(sample: TrainingSample) {
const adjustments = this.#getAdjustments(sample.inputs, sample.outputs)
this.#applyAdjustments(adjustments)
this.applyAdjustments(adjustments)
}

trainBatch(samples: TrainingSample[]) {
Expand Down Expand Up @@ -43,7 +43,9 @@ export default class Backpropagation {
})
}

this.#applyAdjustments(finalAdjustment)
this.applyAdjustments(finalAdjustment)

return finalAdjustment
}

#getAdjustments(inputs: number[], outputs: number[]) {
Expand Down Expand Up @@ -72,7 +74,8 @@ export default class Backpropagation {
return adjustments
}

#applyAdjustments(adjustments: Adjustments) {

applyAdjustments(adjustments: Adjustments) {
adjustments.forEach((neurons, layerIndex) => {
neurons.forEach((synapses, neuronIndex) => {
synapses.forEach((adjustment, synapseIndex) => {
Expand Down
23 changes: 23 additions & 0 deletions src/worker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { expose } from "threads/worker"
import Network, { NetworkExport } from "./network/network.ts";
import Backpropagation from "./trainer/backpropagation.ts";
import mnist from 'mnist';

const formatSamples = (mnistSamples: { input: number[], output: number[] }[]) => {
return mnistSamples.map(({ input, output }) => {
return { inputs: input, outputs: output }
})
}

expose({
trainBatch(exportedNetwork: NetworkExport) {
const network = Network.fromNetworkExport(exportedNetwork)
const trainer = new Backpropagation(network, 0.2)

console.time("trainWorker")
const adjustments = trainer.trainBatch(formatSamples(mnist.get(100)))
console.timeEnd("trainWorker")

return adjustments
}
})
1 change: 1 addition & 0 deletions tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"module": "ESNext",
"lib": ["ES2023", "DOM", "DOM.Iterable"],
"skipLibCheck": true,
"noImplicitAny": false,

/* Bundler mode */
"moduleResolution": "bundler",
Expand Down
56 changes: 56 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@
resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.5.tgz#a6ce3e556e00fd9895dd872dd172ad0d4bd687f4"
integrity sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==

callsites@^3.1.0:
version "3.1.0"
resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73"
integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==

debug@^4.2.0:
version "4.3.6"
resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.6.tgz#2ab2c38fbaffebf8aa95fdfe6d88438c7a13c52b"
integrity sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==
dependencies:
ms "2.1.2"

esbuild@^0.20.1:
version "0.20.2"
resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.20.2.tgz#9d6b2386561766ee6b5a55196c6d766d28c87ea1"
Expand Down Expand Up @@ -231,6 +243,11 @@ esbuild@^0.20.1:
"@esbuild/win32-ia32" "0.20.2"
"@esbuild/win32-x64" "0.20.2"

esm@^3.2.25:
version "3.2.25"
resolved "https://registry.yarnpkg.com/esm/-/esm-3.2.25.tgz#342c18c29d56157688ba5ce31f8431fbb795cc10"
integrity sha512-U1suiZ2oDVWv4zPO56S0NcR5QriEahGtdN2OR6FiOG4WJvcjBVFB0qI4+eKoWFH483PKGuLuu6V8Z4T5g63UVA==

events@^3.3.0:
version "3.3.0"
resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400"
Expand Down Expand Up @@ -259,6 +276,21 @@ graphology@^0.25.4:
events "^3.3.0"
obliterator "^2.0.2"

is-observable@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/is-observable/-/is-observable-2.1.0.tgz#5c8d733a0b201c80dff7bb7c0df58c6a255c7c69"
integrity sha512-DailKdLb0WU+xX8K5w7VsJhapwHLZ9jjmazqCJq4X12CTgqq73TKnbRcnSLuXYPOoLQgV5IrD7ePiX/h1vnkBw==

mnist@^1.1.0:
version "1.1.0"
resolved "https://registry.yarnpkg.com/mnist/-/mnist-1.1.0.tgz#b83efc6af88d8db53b196665acdb50cf524bd2ba"
integrity sha512-x+SfS5tSJnOEjad6jkuL91Cuq9EAVmy1IYt6vdWEsYCjJZsj0oB2xCq+sg5cgiEt9W+xBT31q9IwI2h7h1VtOw==

[email protected]:
version "2.1.2"
resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"
integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==

nanoid@^3.3.7:
version "3.3.7"
resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.7.tgz#d0c301a691bc8d54efa0a2226ccf3fe2fd656bd8"
Expand All @@ -269,6 +301,11 @@ obliterator@^2.0.2:
resolved "https://registry.yarnpkg.com/obliterator/-/obliterator-2.0.4.tgz#fa650e019b2d075d745e44f1effeb13a2adbe816"
integrity sha512-lgHwxlxV1qIg1Eap7LgIeoBWIMFibOjbrYPIPJZcI1mmGAI2m3lNYpK12Y+GBdPQ0U1hRwSord7GIaawz962qQ==

observable-fns@^0.6.1:
version "0.6.1"
resolved "https://registry.yarnpkg.com/observable-fns/-/observable-fns-0.6.1.tgz#636eae4fdd1132e88c0faf38d33658cc79d87e37"
integrity sha512-9gRK4+sRWzeN6AOewNBTLXir7Zl/i3GB6Yl26gK4flxz8BXVpD3kt8amREmWNb0mxYOGDotvE5a4N+PtGGKdkg==

picocolors@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c"
Expand Down Expand Up @@ -321,6 +358,25 @@ source-map-js@^1.2.0:
resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.2.0.tgz#16b809c162517b5b8c3e7dcd315a2a5c2612b2af"
integrity sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==

threads@^1.7.0:
version "1.7.0"
resolved "https://registry.yarnpkg.com/threads/-/threads-1.7.0.tgz#d9e9627bfc1ef22ada3b733c2e7558bbe78e589c"
integrity sha512-Mx5NBSHX3sQYR6iI9VYbgHKBLisyB+xROCBGjjWm1O9wb9vfLxdaGtmT/KCjUqMsSNW6nERzCW3T6H43LqjDZQ==
dependencies:
callsites "^3.1.0"
debug "^4.2.0"
is-observable "^2.1.0"
observable-fns "^0.6.1"
optionalDependencies:
tiny-worker ">= 2"

"tiny-worker@>= 2":
version "2.3.0"
resolved "https://registry.yarnpkg.com/tiny-worker/-/tiny-worker-2.3.0.tgz#715ae34304c757a9af573ae9a8e3967177e6011e"
integrity sha512-pJ70wq5EAqTAEl9IkGzA+fN0836rycEuz2Cn6yeZ6FRzlVS5IDOkFHpIoEsksPRQV34GDqXm65+OlnZqUSyK2g==
dependencies:
esm "^3.2.25"

typescript@^5.2.2:
version "5.4.5"
resolved "https://registry.yarnpkg.com/typescript/-/typescript-5.4.5.tgz#42ccef2c571fdbd0f6718b1d1f5e6e5ef006f611"
Expand Down

0 comments on commit b9f3936

Please sign in to comment.