Skip to content

Commit ba9cfb9

Browse files
committed
expose minimal required versions for ir and opset
Up to now, we set the opset import versions from the values supported by the onnx library. This is an incorrect behavior. We now store the maximum opset version required per domain, based on the operators used in the graph. We then use the onnx library to compute the minimal required IR version to execute the model, based on these opset versions. Fixes #22 Fixes #23
1 parent 3b33a46 commit ba9cfb9

File tree

10 files changed

+87
-84
lines changed

10 files changed

+87
-84
lines changed

ebm2onnx/convert.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
from enum import Enum
44
from copy import deepcopy
5-
from .utils import get_latest_opset_version
65
from ebm2onnx import graph
76
from ebm2onnx import ebm
87
import ebm2onnx.operators as ops
@@ -114,12 +113,13 @@ def to_graph(model, dtype, name="ebm",
114113
name: [Optional] The name of the model
115114
predict_proba: [Optional] For classification models, output prediction probabilities instead of class
116115
explain: [Optional] Adds an additional output with the score per feature per class
117-
target_opset: [Optional] The target onnx opset version to use
116+
target_opset: [Optional][Deprecated] The target onnx opset version to use
118117
119118
Returns:
120119
An ONNX model.
121120
"""
122-
target_opset = target_opset or get_latest_opset_version()
121+
if target_opset:
122+
logging.warning("to_graph: target_opset argument is deprecated")
123123
root = graph.create_graph(context=context)
124124

125125
inputs = [None for _ in model.feature_names_in_]
@@ -296,23 +296,24 @@ def to_onnx(model, dtype, name="ebm",
296296
name: [Optional] The name of the model
297297
predict_proba: [Optional] For classification models, output prediction probabilities instead of class
298298
explain: [Optional] Adds an additional output with the score per feature per class
299-
target_opset: [Optional] The target onnx opset version to use
299+
target_opset: [Optional][Deprecated] The target onnx opset version to use
300300
301301
Returns:
302302
An ONNX model.
303303
"""
304+
if target_opset:
305+
logging.warning("to_onnx: target_opset argument is deprecated")
304306
g = to_graph(
305307
model=model,
306308
dtype=dtype,
307309
name=name,
308310
predict_proba=predict_proba,
309311
explain=explain,
310-
target_opset=target_opset,
311312
prediction_name=prediction_name,
312313
probabilities_name=probabilities_name,
313314
explain_name=explain_name,
314315
context=context,
315316
)
316317

317-
model = graph.to_onnx(g, target_opset, name=name)
318+
model = graph.to_onnx(g, name=name)
318319
return model

ebm2onnx/graph.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import NamedTuple, Callable, Optional, List, Dict, Union
1+
import logging
2+
from typing import NamedTuple, Optional, List, Dict, Union
23

34
import onnx
4-
from ebm2onnx import __version__
5-
from .utils import get_latest_opset_version
5+
from onnx.helper import make_opsetid
66

77
from . import context as _context
88

@@ -14,6 +14,7 @@ class Graph(NamedTuple):
1414
transients: List[onnx.ValueInfoProto] = []
1515
nodes: List[onnx.NodeProto] = []
1616
initializers: List[onnx.TensorProto] = []
17+
opsets: Dict[str, int] = {}
1718

1819

1920
def extend(i, val):
@@ -52,12 +53,18 @@ def from_onnx(model) -> Graph:
5253
Returns:
5354
A Graph object.
5455
"""
56+
opsets = {
57+
op.domain: op.version
58+
for op in model.opset_import
59+
}
60+
5561
return Graph(
5662
context=_context.create(),
5763
inputs=[n for n in model.graph.input],
5864
outputs=[n for n in model.graph.output],
5965
nodes=[n for n in model.graph.node],
6066
initializers=[n for n in model.graph.initializer],
67+
opsets=opsets,
6168
)
6269

6370

@@ -74,13 +81,19 @@ def to_onnx(
7481
7582
Args:
7683
graph: The graph object
77-
target_opset: the target opset to use when converting ot onnx, can be an int or a dict
84+
target_opset: [Optional][Deprecated] the target opset to use when converting ot onnx, can be an int or a dict
7885
name: [Optional] An existing ONNX model
7986
8087
Returns:
8188
A Graph object.
8289
"""
83-
#outputs = graph.transients
90+
if target_opset:
91+
logging.warning("to_onnx: target_opset argument is deprecated")
92+
93+
opset_imports = [
94+
make_opsetid(domain=domain, version=version)
95+
for domain,version in graph.opsets.items()
96+
]
8497

8598
graph = onnx.helper.make_graph(
8699
nodes=graph.nodes,
@@ -89,32 +102,15 @@ def to_onnx(
89102
outputs=graph.outputs,
90103
initializer=graph.initializers,
91104
)
92-
model = onnx.helper.make_model(graph, producer_name='ebm2onnx')
93-
94-
#producer_name = "interpretml/ebm2onnx"
95-
#producer_version = __version__
96-
97-
#domain
98-
#model_version
99-
#doc_string
100105

101-
#metadata_props
102-
103-
# set opset versions
104-
if target_opset is not None:
105-
if type(target_opset) is int:
106-
model.opset_import[0].version = target_opset
107-
elif type(target_opset) is dict:
108-
del model.opset_import[:]
109-
110-
for k, v in target_opset.items():
111-
opset = model.opset_import.add()
112-
opset.domain = k
113-
opset.version = v
114-
else:
115-
raise ValueError(f"ebm2onnx.graph.to_onnx: invalid type for target_opset: {type(target_opset)}.")
116-
else:
117-
model.opset_import[0].version = get_latest_opset_version()
106+
# create the onnx model from the graph.
107+
# The onnx library will set the ir version to the minimal required ir that
108+
# is compatible with the opset_imports provided.
109+
model = onnx.helper.make_model_gen_version(
110+
graph,
111+
producer_name='ebm2onnx',
112+
opset_imports=opset_imports,
113+
)
118114

119115
return model
120116

@@ -210,4 +206,10 @@ def merge(*args):
210206
nodes=extend(g.nodes, graph.nodes),
211207
)
212208

209+
# merge opsets, keep higher version for each domain
210+
for domain,version in graph.opsets.items():
211+
cur_version = g.opsets.get(domain, -1)
212+
if version > cur_version:
213+
g.opsets[domain] = version
214+
213215
return g

ebm2onnx/operators.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import onnx
22
import ebm2onnx.graph as graph
3+
from .utils import opset
34

45

56
def add():
7+
@opset(version=14)
68
def _add(g):
79
add_result_name = g.context.generate_variable_name('add_result')
810
nodes = [
@@ -25,6 +27,7 @@ def _add(g):
2527

2628

2729
def argmax(axis=0, keepdims=1, select_last_index=0):
30+
@opset(version=13)
2831
def _argmax(g):
2932
argmax_result_name = g.context.generate_variable_name('argmax_result')
3033
nodes = [
@@ -48,6 +51,7 @@ def _argmax(g):
4851

4952

5053
def cast(to):
54+
@opset(version=13)
5155
def _cast(g):
5256
cast_result_name = g.context.generate_variable_name('cast_result')
5357
nodes = [
@@ -71,6 +75,7 @@ def _cast(g):
7175

7276

7377
def concat(axis):
78+
@opset(version=13)
7479
def _concat(g):
7580
concat_result_name = g.context.generate_variable_name('concat_result')
7681

@@ -96,6 +101,7 @@ def _concat(g):
96101

97102

98103
def expand():
104+
@opset(version=13)
99105
def _expand(g):
100106
expand_result_name = g.context.generate_variable_name('expand_result')
101107
nodes = [
@@ -118,6 +124,7 @@ def _expand(g):
118124

119125

120126
def flatten(axis=1):
127+
@opset(version=13)
121128
def _flatten(g):
122129
flatten_result_name = g.context.generate_variable_name('flatten_result')
123130
nodes = [
@@ -146,6 +153,7 @@ def gather(axis=0):
146153
- data
147154
- indices
148155
"""
156+
@opset(version=13)
149157
def _gather(g):
150158
gather_result_name = g.context.generate_variable_name('gather_result')
151159
nodes = [
@@ -169,6 +177,7 @@ def _gather(g):
169177

170178

171179
def gather_elements(axis=0):
180+
@opset(version=13)
172181
def _gather_elements(g):
173182
gather_elements_result_name = g.context.generate_variable_name('gather_elements_result')
174183
nodes = [
@@ -197,6 +206,7 @@ def gather_nd():
197206
- scores, as a 2D matrix
198207
- indices, as a [None, 2] matrix
199208
"""
209+
@opset(version=13)
200210
def _gather_nd(g):
201211
gather_nd_result_name = g.context.generate_variable_name('gather_nd_result')
202212
nodes = [
@@ -219,6 +229,7 @@ def _gather_nd(g):
219229

220230

221231
def greater_or_equal():
232+
@opset(version=16)
222233
def _greater_or_equal(g):
223234
greater_or_equal_result_name = g.context.generate_variable_name('greater_or_equal_result')
224235
nodes = [
@@ -241,6 +252,7 @@ def _greater_or_equal(g):
241252

242253

243254
def identity(name, suffix=True):
255+
@opset(version=13)
244256
def _identity(g):
245257
identity_name = g.context.generate_variable_name(name) if suffix else name
246258
nodes = [
@@ -263,6 +275,7 @@ def _identity(g):
263275

264276

265277
def less():
278+
@opset(version=13)
266279
def _less(g):
267280
less_result_name = g.context.generate_variable_name('less_result')
268281
nodes = [
@@ -285,6 +298,7 @@ def _less(g):
285298

286299

287300
def less_or_equal():
301+
@opset(version=16)
288302
def _less_or_equal(g):
289303
less_or_equal_result_name = g.context.generate_variable_name('less_or_equal_result')
290304
nodes = [
@@ -307,6 +321,7 @@ def _less_or_equal(g):
307321

308322

309323
def mul():
324+
@opset(version=13)
310325
def _mul(g):
311326
mul_result_name = g.context.generate_variable_name('mul_result')
312327
nodes = [
@@ -329,6 +344,7 @@ def _mul(g):
329344

330345

331346
def reduce_sum(keepdims=1, noop_with_empty_axes=0):
347+
@opset(version=13)
332348
def _reduce_sum(g):
333349
reduce_sum_result_name = g.context.generate_variable_name('reduce_sum_result')
334350
nodes = [
@@ -354,6 +370,7 @@ def _reduce_sum(g):
354370

355371

356372
def reshape(allowzero=0):
373+
@opset(version=13)
357374
def _reshape(g):
358375
reshape_result_name = g.context.generate_variable_name('reshape_result')
359376
nodes = [
@@ -376,6 +393,7 @@ def _reshape(g):
376393

377394

378395
def softmax(axis=-1):
396+
@opset(version=13)
379397
def _softmax(g):
380398
softmax_result_name = g.context.generate_variable_name('softmax_result')
381399
nodes = [
@@ -399,6 +417,7 @@ def _softmax(g):
399417

400418

401419
def split(axis=0):
420+
@opset(version=18)
402421
def _split(g):
403422
split_result_name = [
404423
g.context.generate_variable_name('split_result')

ebm2onnx/operators_ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import onnx
22
import ebm2onnx.graph as graph
3+
from .utils import opset
34

45

56
def category_mapper(cats_int64s, cats_strings, default_int64=-1, default_string="_Unused"):
7+
@opset(version=1, domain="ai.onnx.ml")
68
def _category_mapper(g):
79
category_mapper_result_name = g.context.generate_variable_name('category_mapper_result')
810
nodes = [

ebm2onnx/sklearn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ def convert_ebm_classifier(scope, operator, container):
3535
context=ctx
3636
)
3737

38-
for node in g.nodes:
39-
v = container._get_op_version(node.domain, node.op_type)
40-
container.node_domain_version_pair_sets.add((node.domain, v))
38+
for domain, version in g.opsets.items():
39+
container.node_domain_version_pair_sets.add(
40+
(domain, version)
41+
)
4142

4243
container.nodes.extend(g.nodes)
4344

ebm2onnx/utils.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from onnx import defs
2-
3-
4-
def get_latest_opset_version():
5-
"""
6-
This module relies on *onnxruntime* to test every
7-
converter. The function returns the most recent
8-
target opset tested with *onnxruntime* or the opset
9-
version specified by *onnx* package if this one is lower
10-
(return by `onnx.defs.onnx_opset_version()`).
11-
"""
12-
return min(21, defs.onnx_opset_version())
1+
2+
3+
def opset(version: int, domain: str=""):
4+
def _operator(op):
5+
def _call(g):
6+
g = op(g)
7+
old_version = g.opsets.get(domain, -1)
8+
if version > old_version:
9+
g.opsets[domain] = version
10+
return g
11+
12+
return _call
13+
14+
return _operator

0 commit comments

Comments
 (0)