
Commit 4b8f25b

Fix the compile and run execution session in Python (onnx#2373)

Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>

1 parent 6bc651e · commit 4b8f25b

16 files changed: +330 −134 lines

docs/doc_example/main.cpp  (+6 −2)

@@ -18,8 +18,8 @@ std::string readArgs(int argc, char *argv[]) {
 int main(int argc, char *argv[]) {
   // Read compiler options from command line and compile the doc example into a
   // model library.
-  const char *errorMessage = NULL;
-  const char *compiledFilename;
+  char *errorMessage = nullptr;
+  char *compiledFilename = nullptr;
   std::string flags = readArgs(argc, argv);
   flags += "-o add-cpp-interface";
   std::cout << "Compile with options \"" << flags << "\"\n";
@@ -30,11 +30,15 @@ int main(int argc, char *argv[]) {
     if (errorMessage)
       std::cerr << " and message \"" << errorMessage << "\"";
     std::cerr << "." << std::endl;
+    free(compiledFilename);
+    free(errorMessage);
     return rc;
   }
   std::string libFilename(compiledFilename);
   std::cout << "Compiled succeeded with results in file: " << libFilename
             << std::endl;
+  free(compiledFilename);
+  free(errorMessage);

   // Prepare the execution session.
   onnx_mlir::ExecutionSession *session;

docs/mnist_example/README.md  (+19 −4)

@@ -279,8 +279,7 @@ The runtime use an `OMExecutionSession` object to hold a specific model and entr

 ```Python
 #Load the model mnist.so compiled with onnx-mlir.
-model = 'mnist.so'
-session = OMExecutionSession(model)
+session = OMExecutionSession('mnist.so')
 #Print the models input / output signature, for display.
 #If there are problems with the signature functions, \
 they can be simply commented out.
@@ -295,10 +294,10 @@ outputs = session.run([input])

 The outputs can then be analyzed by inspecting the values inside the `output` list of numpy arrays.

-The full code is available [here](mnist.py). It finds that `0` is the most likely digit for the given input. The command is:
+The full code is available [here](mnist-runPyRuntime.py). It finds that `0` is the most likely digit for the given input. The command is:

 ```shell
-./mnist.py
+./mnist-runPyRuntime.py
 ```

 and produces an output similar to the following (you may see slightly different prediction numbers if you train the model yourself):
@@ -321,6 +320,22 @@ prediction 9 = 8.650948e-15
 The digit is 0
 ```

+We provide two additional Python interfaces.
+The second interface extends the above execution session by compiling a model before loading it for execution (see [here](mnist-runPyCompileAndRuntime.py)).
+The user simply passes the `.onnx` model and the flags needed to compile it.
+Unless explicitly disabled by `reuse_compiled_model=0`, the execution session reuses a previously compiled model whose name matches the name of the output file generated by the compiler.
+Note that the execution session does not check whether the cached version was compiled with identical compiler flags; it is the user's responsibility to clear the cached version or to disable reuse via the provided optional flag.
+
+For example, the code below compiles and loads the `mnist.onnx` model, compiling only when the `mnist2.so` binary file cannot be located. Model inference can then proceed using the `session.run(...)` command.
+
+```Python
+# Load the onnx model and create a CompileExecutionSession object,
+# by first compiling the mnist.onnx model with the "-O3" options.
+session = OMCompileExecutionSession("./mnist.onnx", "-O3 -o=mnist2")
+```
+
+The third interface provides a simple way to explicitly compile an onnx model (see [here](mnist-compile.py)).
+
 ## Write a Java Driver Code

 Inference APIs and data structures for Java closely mirror those for C/C++. Documentation of the APIs are found [here](https://onnx.ai/onnx-mlir/doxygen_html/OMModel_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_model.html), with the Java interface for Tensor [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensor_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor.html) and TensorList [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensorList_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor_list.html).
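
The README addition above stops after creating the session. The following is a minimal end-to-end sketch (not part of this commit) that compiles and runs the MNIST model in one script; the 1x1x28x28 float32 input shape and the `session.run(...)` call follow the MNIST example in this README and are assumptions for any other model.

```Python
# Minimal compile-and-run sketch, assuming PyCompileAndRuntime is on the
# PYTHONPATH and that the model takes a 1x1x28x28 float32 input (as the
# MNIST example in this README does).
import numpy as np
from PyCompileAndRuntime import OMCompileExecutionSession

# Compile mnist.onnx with -O3, reusing mnist2.so if it already exists.
session = OMCompileExecutionSession("./mnist.onnx", "-O3 -o=mnist2")
if session.get_compiled_result():
    print("error with: " + session.get_error_message())
    exit(1)

# Run one inference on a random image and report the most likely digit.
input = np.random.rand(1, 1, 28, 28).astype(np.float32)
outputs = session.run([input])
print("The digit is", np.argmax(outputs[0]))
```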

docs/mnist_example/mnist-compile.py  (+1 −1)

@@ -7,7 +7,7 @@
 file = './mnist.onnx'
 compiler = OMCompileSession(file)
 # Generate the library file. Success when rc == 0 while set the opt as "-O3"
-rc = compiler.compile("-O3")
+rc = compiler.compile("-O3 -o mnist")
 # Get the output file name
 model = compiler.get_compiled_file_name()
 if rc:
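
For reference, a compile-only sketch built from the calls visible in the diff above; the `PyCompile` import is an assumption based on the full mnist-compile.py script and may need to be adjusted for your install.

```Python
# Compile-only sketch (not part of this commit). Only compile() and
# get_compiled_file_name() appear in the diff above; the import is assumed.
from PyCompile import OMCompileSession

compiler = OMCompileSession('./mnist.onnx')
# Optimize at -O3 and name the output "mnist", yielding mnist.so for -EmitLib.
rc = compiler.compile("-O3 -o mnist")
if rc:
    # A non-zero return code signals a compilation failure.
    print("compile failed with return code", rc)
    exit(1)
# The compiler reports the file name it actually produced.
model = compiler.get_compiled_file_name()
print("compiled model written to", model)
```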

docs/mnist_example/mnist-runPyCompileAndRuntime.py  (+7 −7)

@@ -3,13 +3,13 @@
 import numpy as np
 from PyCompileAndRuntime import OMCompileExecutionSession

-# Load onnx model and create CompileExecutionSession object.
-inputFileName = './mnist.onnx'
-# Set the full name of compiled model
-sharedLibPath = './mnist.so'
-# Set the compile option as "-O3"
-session = OMCompileExecutionSession(inputFileName,sharedLibPath,"-O3")
-
+# Load the onnx model and create a CompileExecutionSession object,
+# by first compiling the mnist.onnx model with the "-O3" options.
+session = OMCompileExecutionSession("./mnist.onnx", "-O3 -o=mnist2",
+    reuse_compiled_model=1)
+if session.get_compiled_result():
+    print("error with: " + session.get_error_message())
+    exit(1)
 # Print the models input/output signature, for display.
 # Signature functions for info only, commented out if they cause problems.
 print("input signature in json", session.input_signature())

docs/mnist_example/mnist-runPyRuntime.py  (+1 −2)

@@ -4,8 +4,7 @@
 from PyRuntime import OMExecutionSession

 # Load the model mnist.so compiled with onnx-mlir.
-model = './mnist.so'
-session = OMExecutionSession(model)
+session = OMExecutionSession('./mnist.so')
 # Print the models input/output signature, for display.
 # Signature functions for info only, commented out if they cause problems.
 print("input signature in json", session.input_signature())

include/OnnxMlirCompiler.h  (+25 −10)

@@ -60,18 +60,18 @@ namespace onnx_mlir {
  * Name may include a path, and must include the file name and its extention.
  *
  * @param outputFilename Output file name of the compiled output for the given
- * emission target. User is responsible for freeing the string.
+ * emission target. User is responsible for freeing the string.
  *
  * @param flags A char * contains all the options provided to compile the
- * model.
+ * model.
  *
  * @param errorMessage Output error message, if any. User is responsible for
- * freeing the string.
+ * freeing the string.
  *
  * @return 0 on success or OnnxMlirCompilerErrorCodes on failure.
  */
 ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
-    const char *flags, const char **outputFilename, const char **errorMessage);
+    const char *flags, char **outputFilename, char **errorMessage);

 /*!
  * Compile an onnx model from an ONNX protobuf array. This method is not thread
@@ -85,18 +85,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
  * @param bufferSize Size of ONNX protobuf array.
  * @param outputBaseName File name without extension to write output.
  * Name may include a path, must include the file name, and should not include
- * an extention.
+ * an extention.
  * @param emissionTarget Target format to compile to.
  * @param outputFilename Output file name of the compiled output for the given
- * emission target. User is responsible for freeing the string.
- * @param errorMessage Error message.
+ * emission target. User is responsible for freeing the string.
+ * @param errorMessage Error message, if any. User is responsible for freeing
+ * the string.
  * @return 0 on success or OnnxMlirCompilerErrorCodes failure. User is
- * responsible for freeing the string.
+ * responsible for freeing the string.
  */
 ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
     int64_t bufferSize, const char *outputBaseName,
-    EmissionTargetType emissionTarget, const char **outputFilename,
-    const char **errorMessage);
+    EmissionTargetType emissionTarget, char **outputFilename,
+    char **errorMessage);
+
+/*!
+ * Compute the file name of the compiled output for the given
+ * emission target. User is responsible for freeing the string.
+ *
+ * @param inputFilename File name pointing onnx model protobuf or MLIR.
+ * Name may include a path, and must include the file name and its extention.
+ * @param flags A char * contains all the options provided to compile the
+ * model.
+ * @return string containing the file name. User is responsible for freeing the
+ * string.
+ */
+ONNX_MLIR_EXPORT char *omCompileOutputFileName(
+    const char *inputFilename, const char *flags);

 #ifdef __cplusplus
 } // namespace onnx_mlir

src/Compiler/OnnxMlirCompiler.cpp  (+62 −15)

@@ -18,16 +18,27 @@ using namespace onnx_mlir;

 namespace onnx_mlir {

+// Derive the name; base name is either given by a "-o" option, or is taken as
+// the model name. The extention depends on the target; e.g. -EmitLib will
+// generate a .so, other targets may generate a .mlir.
 static std::string deriveOutputFileName(
     std::vector<std::string> &flagVect, std::string inputFilename) {
   // Get output file name.
   std::string outputBasename;
   int num = flagVect.size();
-  for (int i = 0; i < num - 1;
-       ++i) { // Skip last as need 2 consecutive entries.
-    if (flagVect[i].find("-o") == 0) {
-      outputBasename = flagVect[i + 1];
-      break;
+  for (int i = 0; i < num; ++i) {
+    if (flagVect[i].find("-o=", 0, 3) == 0) {
+      if (flagVect[i].length() > 3) {
+        outputBasename = flagVect[i].substr(3);
+        break;
+      } else
+        llvm::errs() << "Parsing `-o=` option, expected a name. Use default.\n";
+    } else if (flagVect[i].find("-o") == 0) {
+      if (i < num - 1) {
+        outputBasename = flagVect[i + 1];
+        break;
+      } else
+        llvm::errs() << "Parsing `-o` option, expected a name. Use default.\n";
     }
   }
   // If no output file name, derive it from input file name
@@ -56,12 +67,7 @@
   return getTargetFilename(outputBasename, emissionTarget);
 }

-extern "C" {
-
-ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
-    const char *flags, const char **outputFilename, const char **errorMessage) {
-  // Process the flags, saving each space-separated text in a separate
-  // entry in the string vector flagVect.
+static std::vector<std::string> parseFlags(const char *flags) {
   std::vector<std::string> flagVect;
   const char *str = flags;
   do {
@@ -76,6 +82,23 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
     if (begin != str)
       flagVect.push_back(std::string(begin, str));
   } while (*str);
+  return flagVect;
+}
+
+extern "C" {
+
+ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
+    const char *flags, char **outputFilename, char **errorMessage) {
+  // Ensure known values in filename and error message if provided.
+  if (outputFilename)
+    *outputFilename = nullptr;
+  if (errorMessage)
+    *errorMessage = nullptr;
+
+  // Process the flags, saving each space-separated text in a separate
+  // entry in the string vector flagVect.
+  std::vector<std::string> flagVect = parseFlags(flags);
+
   // Use 'onnx-mlir' command to compile the model.
   std::string onnxMlirPath;
   const auto &envDir = getEnvVar("ONNX_MLIR_BIN_PATH");
@@ -90,17 +113,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
   onnxMlirCompile.appendStr(inputFilenameStr);
   // Run command.
   int rc = onnxMlirCompile.exec();
-  if (rc == CompilerSuccess && outputFilename) {
+  if (rc != CompilerSuccess) {
+    // Failure to compile.
+    if (errorMessage) {
+      std::string errorStr =
+          "Compiler failed with error code " + std::to_string(rc);
+      *errorMessage = strdup(errorStr.c_str());
+    }
+    return CompilerFailureInLLVMOpt;
+  }
+  // Success.
+  if (outputFilename) {
     std::string name = deriveOutputFileName(flagVect, inputFilenameStr);
     *outputFilename = strdup(name.c_str());
   }
-  return rc != 0 ? CompilerFailureInLLVMOpt : CompilerSuccess;
+  return CompilerSuccess;
 }

 ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
     int64_t bufferSize, const char *outputBaseName,
-    EmissionTargetType emissionTarget, const char **outputFilename,
-    const char **errorMessage) {
+    EmissionTargetType emissionTarget, char **outputFilename,
+    char **errorMessage) {
+  // Ensure known values in filename and error message if provided.
+  if (outputFilename)
+    *outputFilename = nullptr;
+  if (errorMessage)
+    *errorMessage = nullptr;
+
   mlir::OwningOpRef<mlir::ModuleOp> module;
   mlir::MLIRContext context;
   registerDialects(context);
@@ -124,5 +163,13 @@ ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
   return rc;
 }

+ONNX_MLIR_EXPORT char *omCompileOutputFileName(
+    const char *inputFilename, const char *flags) {
+  std::vector<std::string> flagVect = parseFlags(flags);
+  std::string inputFilenameStr(inputFilename);
+  std::string name = deriveOutputFileName(flagVect, inputFilenameStr);
+  return strdup(name.c_str());
+}
+
 } // extern C
 } // namespace onnx_mlir
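
The rewritten loop accepts both `-o name` and `-o=name` and warns when the name is missing. For illustration only (this is not the library API), a stand-alone Python model of that derivation logic:

```Python
# Illustration of the patched flag scan in deriveOutputFileName: return the
# output base name from "-o name" or "-o=name", else fall back to a default.
def output_basename(flags, default):
    flag_vect = flags.split()
    for i, flag in enumerate(flag_vect):
        if flag.startswith("-o="):
            if len(flag) > 3:
                return flag[3:]
            print("Parsing `-o=` option, expected a name. Use default.")
        elif flag.startswith("-o"):
            if i < len(flag_vect) - 1:
                return flag_vect[i + 1]
            print("Parsing `-o` option, expected a name. Use default.")
    return default

print(output_basename("-O3 -o=mnist2", "mnist"))  # -> mnist2
print(output_basename("-O3 -o mnist2", "mnist"))  # -> mnist2
print(output_basename("-O3 -o", "mnist"))         # -> mnist (with a warning)
```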

src/Compiler/PyOMCompileSession.cpp  (+8 −2)

@@ -35,7 +35,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) {
         "No OMCompileSession was created with the input file name specified.";
     return -1;
   }
-  const char *outputName, *errorMsg;
+  char *outputName = nullptr;
+  char *errorMsg = nullptr;
   int64_t rc;
   rc = omCompileFromFile(
       inputFileName.c_str(), flags.c_str(), &outputName, &errorMsg);
@@ -50,6 +51,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) {
     // Empty output file name.
     outputFileName = std::string();
   }
+  free(outputName);
+  free(errorMsg);
   return rc;
 }

@@ -60,7 +63,8 @@ int64_t PyOMCompileSession::pyCompileFromArray(
         "No OMCompileSession was created with the input buffer specified.";
     return -1;
   }
-  const char *outputName, *errorMsg;
+  char *outputName = nullptr;
+  char *errorMsg = nullptr;
   int64_t rc;
   rc = omCompileFromArray(inputBuffer, inputBufferSize, outputBaseName.c_str(),
       emissionTarget, &outputName, &errorMsg);
@@ -75,6 +79,8 @@ int64_t PyOMCompileSession::pyCompileFromArray(
     // Empty output file name.
     outputFileName = std::string();
   }
+  free(outputName);
+  free(errorMsg);
   return rc;
 }
