Skip to content

Commit 1e6b10c

Browse files
committed
Fix .tensor() method; add tests
1 parent bf84a41 commit 1e6b10c

File tree

6 files changed

+39
-3
lines changed

6 files changed

+39
-3
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "libtorchjs",
3-
"version": "1.0.0-alpha.3",
3+
"version": "1.0.0-alpha.4",
44
"description": "Node.js N-API wrapper for Libtorch",
55
"main": "lib/index.js",
66
"author": "Vova Manannikov <vova@promail.spb.ru>",

src/libtorchjs.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ namespace libtorchjs {
3434
Napi::Float32Array arr = info[0].As<Napi::Float32Array>();
3535
size_t elements = arr.ElementLength();
3636
// make torch tensor
37-
at::Tensor tensor = torch::tensor(at::ArrayRef<float>(arr.Data(), elements), torch::requires_grad(false));
37+
torch::TensorOptions options;
38+
at::Tensor tensor = torch::tensor(at::ArrayRef<float>(arr.Data(), elements), options);
3839
// napi tensor
3940
auto napiTensor = Tensor::NewInstance();
4041
Napi::ObjectWrap<Tensor>::Unwrap(napiTensor)->setTensor(tensor);

src/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ namespace libtorchjs {
5656
uint64_t size = this->tensor.numel();
5757
// make unit8 type tensor
5858
auto byteTensor = this->tensor.clamp(0, 255).to(at::ScalarType::Byte);
59-
auto byteData = byteTensor.contiguous().data<uint8_t>();
59+
auto byteData = byteTensor.contiguous().data_ptr<uint8_t>();
6060
// wrap in napi unit8 array
6161
auto arr = Napi::Uint8Array::New(env, size);
6262
for (uint64_t i = 0; i < size; i++) {

tests/data/make_jit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import torch
2+
3+
def mul2(x):
4+
return x * 2
5+
6+
torch.jit.trace(mul2, torch.randn(3, 3)).save("mul2.pt")

tests/data/mul2.pt

1.69 KB
Binary file not shown.

tests/tensor.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
const expect = require('chai').expect;
2+
const torch = require('../lib/');
3+
const path = require('path');
4+
5+
describe('LibtorchJS', function() {
6+
7+
it('ones', function() {
8+
const ones = torch.ones([2, 2]);
9+
expect([...ones.toUint8Array()]).to.deep.equal([1, 1, 1, 1]);
10+
});
11+
12+
it('tensor', function() {
13+
const arr = new Float32Array([1, 2, 3.3, 4]);
14+
const tensor = torch.tensor(arr);
15+
expect([...tensor.toUint8Array()]).to.deep.equal([1, 2, 3, 4]);
16+
});
17+
18+
it('load', function(done) {
19+
const input = torch.tensor(new Float32Array([2.5, 3.5]));
20+
torch.load(path.join(__dirname, 'data', 'mul2.pt'), function(err, model) {
21+
model.forward(input, function(err, result) {
22+
const output = result.toUint8Array();
23+
expect([...output]).to.deep.equal([5, 7]);
24+
done();
25+
});
26+
});
27+
});
28+
29+
});

0 commit comments

Comments
 (0)