Skip to content

Commit 6761a15

Browse files
authored
#3103 Fix error trying to load model with .onnx_data file (#84)
<!-- Thank you so much for your contribution! Please fill in all the sections below. Please open the PR as a draft initially. Once it is reviewed and approved, we will ask you to add documentation and examples. Please note that PRs with breaking changes or without tests will be rejected. Please note that PRs will be reviewed based on the priority of the issues they address. We ask for your patience. We are doing our best to review your PR as quickly as possible. Please refrain from pinging and asking when it will be reviewed. Thank you for understanding! --> ## Issue Closes #3103 ## Change The OnnxBertBiEncoder now accepts Paths directly (instead of only InputStreams) ## General checklist <!-- Please double-check the following points and mark them like this: [X] --> - [x] There are no breaking changes - [ ] I have added unit and/or integration tests for my change - [ ] The tests cover both positive and negative cases - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green <!-- Before adding documentation and example(s) (below), please wait until the PR is reviewed and approved. --> - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
1 parent c721970 commit 6761a15

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

langchain4j-embeddings/src/main/java/dev/langchain4j/model/embedding/onnx/AbstractInProcessEmbeddingModel.java

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import dev.langchain4j.model.output.Response;
88
import dev.langchain4j.model.output.TokenUsage;
99

10-
import java.io.IOException;
1110
import java.io.InputStream;
1211
import java.nio.file.Path;
1312
import java.util.ArrayList;
@@ -16,7 +15,6 @@
1615

1716
import static dev.langchain4j.internal.Utils.getOrDefault;
1817
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
19-
import static java.nio.file.Files.newInputStream;
2018
import static java.util.Collections.singletonList;
2119
import static java.util.concurrent.CompletableFuture.supplyAsync;
2220
import static java.util.concurrent.TimeUnit.SECONDS;
@@ -48,19 +46,7 @@ protected static OnnxBertBiEncoder loadFromJar(String modelFileName, String toke
4846
}
4947

5048
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, Path pathToTokenizer, PoolingMode poolingMode) {
51-
try {
52-
return new OnnxBertBiEncoder(newInputStream(pathToModel), newInputStream(pathToTokenizer), poolingMode);
53-
} catch (IOException e) {
54-
throw new RuntimeException(e);
55-
}
56-
}
57-
58-
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, InputStream tokenizer, PoolingMode poolingMode) {
59-
try {
60-
return new OnnxBertBiEncoder(newInputStream(pathToModel), tokenizer, poolingMode);
61-
} catch (IOException e) {
62-
throw new RuntimeException(e);
63-
}
49+
return new OnnxBertBiEncoder(pathToModel, pathToTokenizer, poolingMode);
6450
}
6551

6652
protected abstract OnnxBertBiEncoder model();

langchain4j-embeddings/src/main/java/dev/langchain4j/model/embedding/onnx/OnnxBertBiEncoder.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.io.ByteArrayOutputStream;
1212
import java.io.IOException;
1313
import java.io.InputStream;
14+
import java.nio.file.Path;
1415
import java.util.*;
1516

1617
import static ai.onnxruntime.OnnxTensor.createTensor;
@@ -30,6 +31,18 @@ public class OnnxBertBiEncoder {
3031
private final HuggingFaceTokenizer tokenizer;
3132
private final PoolingMode poolingMode;
3233

34+
public OnnxBertBiEncoder(Path pathToModel, Path pathToTokenizer, PoolingMode poolingMode) {
35+
try {
36+
this.environment = OrtEnvironment.getEnvironment();
37+
this.session = environment.createSession(pathToModel.toString());
38+
this.expectedInputs = session.getInputNames();
39+
this.tokenizer = HuggingFaceTokenizer.newInstance(pathToTokenizer, singletonMap("padding", "false"));
40+
this.poolingMode = ensureNotNull(poolingMode, "poolingMode");
41+
} catch (Exception e) {
42+
throw new RuntimeException(e);
43+
}
44+
}
45+
3346
public OnnxBertBiEncoder(InputStream model, InputStream tokenizer, PoolingMode poolingMode) {
3447
try {
3548
this.environment = OrtEnvironment.getEnvironment();

0 commit comments

Comments
 (0)