Skip to content

Commit

Permalink
Merge pull request #71 from OpenMined/vvm/update-serde
Browse files Browse the repository at this point in the history
Update serde, fix with-node example
  • Loading branch information
cereallarceny authored Dec 29, 2019
2 parents e1e30fb + e5964f6 commit 45bb344
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 10 deletions.
5 changes: 3 additions & 2 deletions examples/with-grid/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@

<title>syft.js Example</title>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<!-- NOTE: TFJS version must match with one in package-lock.json -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://webrtc.github.io/adapter/adapter-latest.js"></script>
</head>
<body>
Expand Down Expand Up @@ -71,7 +72,7 @@ <h1>syft.js/grid.js testing</h1>
>.
</p>
<input type="text" id="grid-server" value="ws://localhost:3000" />
<input type="text" id="protocol" value="5259950754" />
<input type="text" id="protocol" value="50801316202" />
<button id="connect">Connect to grid.js server</button>
<div id="app">
<button id="disconnect">Disconnect</button>
Expand Down
17 changes: 10 additions & 7 deletions jest-globals.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
global.tf = {
tensor: (value, shape, type) => {
console.log('Created a tensor!', value, shape, type);
},
add: jest.fn(),
abs: jest.fn()
};
import * as tf from '@tensorflow/tfjs';
const tensor = tf.tensor;
jest.spyOn(tf, 'tensor').mockImplementation((values, shape, dtype) => {
let t = tensor(values, shape, dtype);
// override id to always return same value
Object.defineProperty(t, 'id', {
value: 42
});
return t;
});
1 change: 1 addition & 0 deletions src/_helpers.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { TorchTensor } from './types/torch';
import PointerTensor from './types/pointer-tensor';
import { CANNOT_FIND_COMMAND } from './_errors';
import * as tf from '@tensorflow/tfjs';

export const pickTensors = tree => {
const objects = {};
Expand Down
2 changes: 1 addition & 1 deletion src/serde.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export const detail = data => {
[proto['torch.Size']]: d => new TorchSize(d),
[proto['syft.messaging.plan.plan.Plan']]: d => new Plan(...d.map(i => parse(i))),
[proto['syft.messaging.plan.state.State']]: d => new State(...d.map(i => parse(i))),
[proto['syft.messaging.plan.procedure.Procedure']]: d => new Procedure(d[0].map(i => parse(i)), ...d.slice(1).map(i => parse(i))),
[proto['syft.messaging.plan.procedure.Procedure']]: d => new Procedure(...d.map(i => parse(i))),
[proto['syft.messaging.protocol.Protocol']]: d => new Protocol(...d.map(i => parse(i))),
[proto['syft.generic.pointers.pointer_tensor.PointerTensor']]: d => new PointerTensor(...d.map(i => parse(i))),
[proto['syft.messaging.message.Message']]: d => new Message(...d.map(i => parse(i))),
Expand Down
1 change: 1 addition & 0 deletions src/types/torch.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { default as proto } from '../proto';
import * as tf from '@tensorflow/tfjs';

export class TorchTensor {
constructor(id, bin, chain, gradChain, tags, description, serializer) {
Expand Down

0 comments on commit 45bb344

Please sign in to comment.