Skip to content

Commit d4fa24a

Browse files
authored
Merge branch 'master' into patch-2
2 parents bdf0bf1 + b6684ec commit d4fa24a

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

43 files changed

+2420
-1290
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ With coremltools, you can do the following:
2929
To get the latest version of coremltools:
3030

3131
```shell
32-
pip install coremltools==4.0b4
32+
pip install --upgrade coremltools
3333
```
3434

3535
For the latest changes please see the [release notes](https://github.com/apple/coremltools/releases/).

coremltools/_deps/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __get_sklearn_version(version):
8989
_TF_1_MIN_VERSION = "1.12.0"
9090
_TF_1_MAX_VERSION = "1.15.0"
9191
_TF_2_MIN_VERSION = "2.1.0"
92-
_TF_2_MAX_VERSION = "2.3.0"
92+
_TF_2_MAX_VERSION = "2.3.1"
9393

9494
try:
9595
import tensorflow

coremltools/converters/_converters_entry.py

+127-114
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import collections
77

88
from coremltools.converters.mil.input_types import InputType, ClassifierConfig
9-
from coremltools.converters.mil.converter import _convert
9+
from coremltools.converters.mil.converter import mil_convert
1010
from coremltools.converters.mil.mil import Program
1111
from coremltools._deps import _HAS_TORCH, _HAS_TF_1, _HAS_TF_2
1212
from coremltools.converters._profile_utils import _profile
@@ -39,6 +39,7 @@ def convert(
3939
outputs=None,
4040
classifier_config=None,
4141
minimum_deployment_target=None,
42+
convert_to='nn_proto',
4243
**kwargs
4344
):
4445
"""
@@ -68,7 +69,7 @@ def convert(
6869
- Path to a `.pt` file
6970
7071
source: str (optional)
71-
One of `auto`, `tensorflow`, or `pytorch`. `auto` determines the
72+
One of [`auto`, `tensorflow`, `pytorch`, `mil`]. `auto` determines the
7273
framework automatically for most cases. Raise ValueError if it fails
7374
to determine the source framework.
7475
@@ -108,12 +109,22 @@ def convert(
108109
109110
minimum_deployment_target: coremltools.target enumeration (optional)
110111
- one of the members of enum "coremltools.target."
111-
- When not-specified or None, converter aims for as minimum of a deployment target as possible
112+
- When not-specified or None, converter aims for as minimum of a
113+
deployment target as possible
114+
115+
convert_to: str (optional)
116+
- Must be one of ['nn_proto', 'mil'].
117+
- 'nn_proto': Returns MLModel containing a NeuralNetwork
118+
proto
119+
- 'mil': Returns MIL program object. MIL program is primarily used
120+
for debugging purpose and currently cannot be compiled to
121+
executable.
112122
113123
Returns
114124
-------
115-
model: MLModel
116-
A Core ML MLModel object
125+
model: `coremltools.models.MLModel` or
126+
`coremltools.converters.mil.Program`
127+
A Core ML MLModel object or MIL Program object (see `convert_to`)
117128
118129
Examples
119130
--------
@@ -157,24 +168,55 @@ def convert(
157168
See `here <https://coremltools.readme.io/docs/neural-network-conversion>`_ for
158169
more advanced options
159170
"""
160-
if minimum_deployment_target is not None and not isinstance(
161-
minimum_deployment_target, AvailableTarget
162-
):
171+
_check_deployment_target(minimum_deployment_target)
172+
exact_source = _determine_source(model, source, outputs)
173+
_validate_inputs(model, exact_source, inputs, outputs, classifier_config,
174+
**kwargs)
175+
176+
mlmodel = mil_convert(
177+
model,
178+
convert_from=exact_source,
179+
convert_to=convert_to,
180+
inputs=inputs,
181+
outputs=outputs,
182+
classifier_config=classifier_config,
183+
**kwargs
184+
)
185+
186+
if convert_to == 'mil':
187+
return mlmodel # Returns the MIL program
188+
189+
if minimum_deployment_target is not None:
190+
check_deployment_compatibility(
191+
spec=mlmodel.get_spec(),
192+
representation=convert_to,
193+
deployment_target=minimum_deployment_target,
194+
)
195+
196+
gc.collect()
197+
198+
mlmodel = _record_src_version(mlmodel, exact_source)
199+
mlmodel.user_defined_metadata[_METADATA_VERSION] = ct_version
200+
201+
return mlmodel
202+
203+
204+
def _check_deployment_target(minimum_deployment_target):
205+
if minimum_deployment_target is not None and \
206+
not isinstance(minimum_deployment_target, AvailableTarget):
163207
msg = (
164208
"Unrecognized value of argument 'minimum_deployment_target': {}. "
165209
"It needs to be a member of 'coremltools.target' enumeration. "
166210
"For example, coremltools.target.iOS13"
167211
)
168212
raise TypeError(msg.format(minimum_deployment_target))
169213

170-
source = source.lower()
171-
if source not in {"auto", "tensorflow", "pytorch"}:
172-
msg = (
173-
'Unrecognized value of argument "source": {}. '
174-
'It must be one of ["auto", "tensorflow", "pytorch"].'
175-
)
176-
raise ValueError(msg.format(source))
177-
214+
def _validate_inputs(model, exact_source, inputs, outputs, classifier_config,
215+
**kwargs):
216+
"""
217+
Validate and process model, inputs, outputs, classifier_config based on
218+
`exact_source` (which cannot be `auto`)
219+
"""
178220
def raise_if_duplicated(input_list):
179221
# Detect duplicated inputs
180222
input_names = [t.name for t in input_list if t.name is not None]
@@ -196,56 +238,11 @@ def raise_if_duplicated(input_list):
196238
msg = '"classifier_config" must be of type ClassifierConfig'
197239
raise ValueError(msg)
198240

199-
if source == "tensorflow" and _HAS_TF_2:
200-
source = "tensorflow2"
201-
202-
if source == "auto" and _HAS_TF_1:
203-
try:
204-
loader = TF1Loader(model, outputs=outputs)
205-
loader._graph_def_from_model(outputs=outputs)
206-
source = "tensorflow"
207-
except:
208-
pass
209-
210-
if source == "auto" and _HAS_TF_2:
211-
try:
212-
loader = TF2Loader(model, outputs=outputs)
213-
loader._graph_def_from_model(outputs=outputs)
214-
source = "tensorflow2"
215-
except:
216-
pass
217-
218-
if source == "auto" and _HAS_TORCH:
219-
try:
220-
pytorch_load(model)
221-
source = "pytorch"
222-
except:
223-
pass
224-
225-
if source == "auto" and isinstance(model, Program):
226-
source = "mil"
227-
228-
convert_to = kwargs.get("convert_to", "nn_proto")
229-
kwargs.pop("convert_to", None)
230-
231-
if source == "auto":
232-
msg = (
233-
"Unable to determine the type of the model, i.e. the source framework. "
234-
'Please provide the value of argument "source", from one of '
235-
'["tensorflow", "pytorch"]. Note that model conversion requires the '
236-
"source package that generates the model. Please make sure you have "
237-
"the appropriate version of source package installed. E.g., if you're "
238-
"converting model originally trained with TensorFlow 1.14, make sure "
239-
"you have `tensorflow==1.14` installed."
240-
)
241-
raise ValueError(msg)
242-
243-
elif source in {"tensorflow", "tensorflow2"}:
244-
245-
if source == "tensorflow" and not _HAS_TF_1:
246-
raise ValueError(
247-
'Converter was called with source="tensorflow", but missing tensorflow package'
248-
)
241+
if exact_source in {"tensorflow", "tensorflow2"}:
242+
if exact_source == "tensorflow" and not _HAS_TF_1:
243+
msg = 'Converter was called with source="tensorflow", ' +\
244+
'but missing tensorflow package'
245+
raise ValueError(msg)
249246

250247
if inputs is not None:
251248
raise_if_duplicated(inputs)
@@ -255,17 +252,7 @@ def raise_if_duplicated(input_list):
255252
):
256253
raise ValueError("Input should be a list of TensorType or ImageType")
257254

258-
proto_spec = _convert(
259-
model,
260-
convert_from=source,
261-
convert_to=convert_to,
262-
inputs=inputs,
263-
outputs=outputs,
264-
classifier_config=classifier_config,
265-
**kwargs
266-
)
267-
268-
elif source == "pytorch":
255+
elif exact_source == "pytorch":
269256
if "example_inputs" in kwargs:
270257
msg = 'Unexpected argument "example_inputs" found'
271258
raise ValueError(msg)
@@ -300,55 +287,81 @@ def _flatten_list(_inputs):
300287
if outputs is not None:
301288
raise ValueError("outputs must not be specified for PyTorch")
302289

303-
proto_spec = _convert(
304-
model,
305-
convert_from="torch",
306-
convert_to=convert_to,
307-
inputs=inputs,
308-
outputs=outputs,
309-
classifier_config=classifier_config,
310-
**kwargs
311-
)
312-
313-
elif source == "mil":
290+
elif exact_source == "mil":
314291
if not isinstance(model, Program):
315292
msg = "Converter was asked to convert MIL input, but input is not a MIL program!"
316293
raise ValueError(msg)
317294

318-
proto_spec = _convert(
319-
model,
320-
convert_from="mil",
321-
convert_to=convert_to,
322-
example_inputs=inputs,
323-
classifier_config=classifier_config,
324-
**kwargs
295+
296+
def _determine_source(model, source, outputs):
297+
"""
298+
Infer source (which can be auto) to the precise framework.
299+
"""
300+
source = source.lower()
301+
if source not in {"auto", "tensorflow", "pytorch", "mil"}:
302+
msg = (
303+
'Unrecognized value of argument "source": {}. '
304+
'It must be one of ["auto", "tensorflow", "pytorch", "mil"].'
325305
)
306+
raise ValueError(msg.format(source))
326307

327-
if convert_to == 'mil':
328-
return proto_spec # Returns the MIL program
329308

330-
useCPUOnly = kwargs.get("useCPUOnly", True)
331-
model = coremltools.models.MLModel(proto_spec, useCPUOnly=useCPUOnly)
309+
# Determine tensorflow version
310+
if source == "tensorflow" and _HAS_TF_2:
311+
return "tensorflow2"
332312

333-
if minimum_deployment_target is not None:
334-
check_deployment_compatibility(
335-
spec=proto_spec,
336-
representation=convert_to,
337-
deployment_target=minimum_deployment_target,
338-
)
313+
if source != 'auto':
314+
return source
339315

340-
del proto_spec
341-
gc.collect()
316+
# Determine `auto` source
317+
if source == "auto" and _HAS_TF_1:
318+
try:
319+
loader = TF1Loader(model, outputs=outputs)
320+
loader._graph_def_from_model(outputs=outputs)
321+
return "tensorflow"
322+
except:
323+
pass
342324

325+
if source == "auto" and _HAS_TF_2:
326+
try:
327+
loader = TF2Loader(model, outputs=outputs)
328+
loader._graph_def_from_model(outputs=outputs)
329+
return "tensorflow2"
330+
except:
331+
pass
332+
333+
if source == "auto" and _HAS_TORCH:
334+
try:
335+
pytorch_load(model)
336+
return "pytorch"
337+
except:
338+
pass
339+
340+
if source == "auto" and isinstance(model, Program):
341+
return "mil"
342+
343+
msg = (
344+
"Unable to determine the type of the model, i.e. the source framework. "
345+
'Please provide the value of argument "source", from one of '
346+
'["tensorflow", "pytorch", "mil"]. Note that model conversion requires the '
347+
"source package that generates the model. Please make sure you have "
348+
"the appropriate version of source package installed. E.g., if you're "
349+
"converting model originally trained with TensorFlow 1.14, make sure "
350+
"you have `tensorflow==1.14` installed."
351+
)
352+
raise ValueError(msg)
353+
354+
355+
def _record_src_version(mlmodel, exact_source):
343356
# recording metadata: coremltools version, source framework and version
344-
if source in {"tensorflow", "tensorflow2"} and (_HAS_TF_1 or _HAS_TF_2):
357+
if exact_source in {"tensorflow", "tensorflow2"} and (_HAS_TF_1 or _HAS_TF_2):
345358
src_pkg_version = "tensorflow=={0}".format(tf.__version__)
346-
elif source == "pytorch" and _HAS_TORCH:
359+
elif exact_source == "pytorch" and _HAS_TORCH:
347360
src_pkg_version = "torch=={0}".format(torch.__version__)
361+
elif exact_source == 'mil':
362+
src_pkg_version = "mil"
348363
else:
349-
src_pkg_version = "unknown"
350-
351-
model.user_defined_metadata[_METADATA_VERSION] = ct_version
352-
model.user_defined_metadata[_METADATA_SOURCE] = src_pkg_version
364+
raise ValueError('Unsupported source {}'.format(exact_source))
353365

354-
return model
366+
mlmodel.user_defined_metadata[_METADATA_SOURCE] = src_pkg_version
367+
return mlmodel

coremltools/converters/mil/backend/nn/op_mapping.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1366,11 +1366,11 @@ def l2_pool(const_context, builder, op):
13661366
def linear(const_context, builder, op):
13671367
out_channels, in_channels = op.weight.shape
13681368
if op.x.rank and op.x.rank <= 3 and op.x.rank > 0:
1369-
has_bias = op.bias.val is not None
1369+
has_bias = op.bias is not None and op.bias.val is not None
13701370
builder.add_inner_product(
13711371
name=op.name,
13721372
W=op.weight.val,
1373-
b=op.bias.val,
1373+
b=op.bias.val if has_bias else None,
13741374
input_channels=in_channels,
13751375
output_channels=out_channels,
13761376
has_bias=has_bias,

0 commit comments

Comments
 (0)