6
6
import collections
7
7
8
8
from coremltools .converters .mil .input_types import InputType , ClassifierConfig
9
- from coremltools .converters .mil .converter import _convert
9
+ from coremltools .converters .mil .converter import mil_convert
10
10
from coremltools .converters .mil .mil import Program
11
11
from coremltools ._deps import _HAS_TORCH , _HAS_TF_1 , _HAS_TF_2
12
12
from coremltools .converters ._profile_utils import _profile
@@ -39,6 +39,7 @@ def convert(
39
39
outputs = None ,
40
40
classifier_config = None ,
41
41
minimum_deployment_target = None ,
42
+ convert_to = 'nn_proto' ,
42
43
** kwargs
43
44
):
44
45
"""
@@ -68,7 +69,7 @@ def convert(
68
69
- Path to a `.pt` file
69
70
70
71
source: str (optional)
71
- One of `auto`, `tensorflow`, or `pytorch`. `auto` determines the
72
+ One of [ `auto`, `tensorflow`, `pytorch`, `mil`] . `auto` determines the
72
73
framework automatically for most cases. Raise ValueError if it fails
73
74
to determine the source framework.
74
75
@@ -108,12 +109,22 @@ def convert(
108
109
109
110
minimum_deployment_target: coremltools.target enumeration (optional)
110
111
- one of the members of enum "coremltools.target."
111
- - When not-specified or None, converter aims for as minimum of a deployment target as possible
112
+ - When not-specified or None, converter aims for as minimum of a
113
+ deployment target as possible
114
+
115
+ convert_to: str (optional)
116
+ - Must be one of ['nn_proto', 'mil'].
117
+ - 'nn_proto': Returns MLModel containing a NeuralNetwork
118
+ proto
119
+ - 'mil': Returns MIL program object. MIL program is primarily used
120
+ for debugging purpose and currently cannot be compiled to
121
+ executable.
112
122
113
123
Returns
114
124
-------
115
- model: MLModel
116
- A Core ML MLModel object
125
+ model: `coremltools.models.MLModel` or
126
+ `coremltools.converters.mil.Program`
127
+ A Core ML MLModel object or MIL Program object (see `convert_to`)
117
128
118
129
Examples
119
130
--------
@@ -157,24 +168,55 @@ def convert(
157
168
See `here <https://coremltools.readme.io/docs/neural-network-conversion>`_ for
158
169
more advanced options
159
170
"""
160
- if minimum_deployment_target is not None and not isinstance (
161
- minimum_deployment_target , AvailableTarget
162
- ):
171
+ _check_deployment_target (minimum_deployment_target )
172
+ exact_source = _determine_source (model , source , outputs )
173
+ _validate_inputs (model , exact_source , inputs , outputs , classifier_config ,
174
+ ** kwargs )
175
+
176
+ mlmodel = mil_convert (
177
+ model ,
178
+ convert_from = exact_source ,
179
+ convert_to = convert_to ,
180
+ inputs = inputs ,
181
+ outputs = outputs ,
182
+ classifier_config = classifier_config ,
183
+ ** kwargs
184
+ )
185
+
186
+ if convert_to == 'mil' :
187
+ return mlmodel # Returns the MIL program
188
+
189
+ if minimum_deployment_target is not None :
190
+ check_deployment_compatibility (
191
+ spec = mlmodel .get_spec (),
192
+ representation = convert_to ,
193
+ deployment_target = minimum_deployment_target ,
194
+ )
195
+
196
+ gc .collect ()
197
+
198
+ mlmodel = _record_src_version (mlmodel , exact_source )
199
+ mlmodel .user_defined_metadata [_METADATA_VERSION ] = ct_version
200
+
201
+ return mlmodel
202
+
203
+
204
+ def _check_deployment_target (minimum_deployment_target ):
205
+ if minimum_deployment_target is not None and \
206
+ not isinstance (minimum_deployment_target , AvailableTarget ):
163
207
msg = (
164
208
"Unrecognized value of argument 'minimum_deployment_target': {}. "
165
209
"It needs to be a member of 'coremltools.target' enumeration. "
166
210
"For example, coremltools.target.iOS13"
167
211
)
168
212
raise TypeError (msg .format (minimum_deployment_target ))
169
213
170
- source = source .lower ()
171
- if source not in {"auto" , "tensorflow" , "pytorch" }:
172
- msg = (
173
- 'Unrecognized value of argument "source": {}. '
174
- 'It must be one of ["auto", "tensorflow", "pytorch"].'
175
- )
176
- raise ValueError (msg .format (source ))
177
-
214
+ def _validate_inputs (model , exact_source , inputs , outputs , classifier_config ,
215
+ ** kwargs ):
216
+ """
217
+ Validate and process model, inputs, outputs, classifier_config based on
218
+ `exact_source` (which cannot be `auto`)
219
+ """
178
220
def raise_if_duplicated (input_list ):
179
221
# Detect duplicated inputs
180
222
input_names = [t .name for t in input_list if t .name is not None ]
@@ -196,56 +238,11 @@ def raise_if_duplicated(input_list):
196
238
msg = '"classifier_config" must be of type ClassifierConfig'
197
239
raise ValueError (msg )
198
240
199
- if source == "tensorflow" and _HAS_TF_2 :
200
- source = "tensorflow2"
201
-
202
- if source == "auto" and _HAS_TF_1 :
203
- try :
204
- loader = TF1Loader (model , outputs = outputs )
205
- loader ._graph_def_from_model (outputs = outputs )
206
- source = "tensorflow"
207
- except :
208
- pass
209
-
210
- if source == "auto" and _HAS_TF_2 :
211
- try :
212
- loader = TF2Loader (model , outputs = outputs )
213
- loader ._graph_def_from_model (outputs = outputs )
214
- source = "tensorflow2"
215
- except :
216
- pass
217
-
218
- if source == "auto" and _HAS_TORCH :
219
- try :
220
- pytorch_load (model )
221
- source = "pytorch"
222
- except :
223
- pass
224
-
225
- if source == "auto" and isinstance (model , Program ):
226
- source = "mil"
227
-
228
- convert_to = kwargs .get ("convert_to" , "nn_proto" )
229
- kwargs .pop ("convert_to" , None )
230
-
231
- if source == "auto" :
232
- msg = (
233
- "Unable to determine the type of the model, i.e. the source framework. "
234
- 'Please provide the value of argument "source", from one of '
235
- '["tensorflow", "pytorch"]. Note that model conversion requires the '
236
- "source package that generates the model. Please make sure you have "
237
- "the appropriate version of source package installed. E.g., if you're "
238
- "converting model originally trained with TensorFlow 1.14, make sure "
239
- "you have `tensorflow==1.14` installed."
240
- )
241
- raise ValueError (msg )
242
-
243
- elif source in {"tensorflow" , "tensorflow2" }:
244
-
245
- if source == "tensorflow" and not _HAS_TF_1 :
246
- raise ValueError (
247
- 'Converter was called with source="tensorflow", but missing tensorflow package'
248
- )
241
+ if exact_source in {"tensorflow" , "tensorflow2" }:
242
+ if exact_source == "tensorflow" and not _HAS_TF_1 :
243
+ msg = 'Converter was called with source="tensorflow", ' + \
244
+ 'but missing tensorflow package'
245
+ raise ValueError (msg )
249
246
250
247
if inputs is not None :
251
248
raise_if_duplicated (inputs )
@@ -255,17 +252,7 @@ def raise_if_duplicated(input_list):
255
252
):
256
253
raise ValueError ("Input should be a list of TensorType or ImageType" )
257
254
258
- proto_spec = _convert (
259
- model ,
260
- convert_from = source ,
261
- convert_to = convert_to ,
262
- inputs = inputs ,
263
- outputs = outputs ,
264
- classifier_config = classifier_config ,
265
- ** kwargs
266
- )
267
-
268
- elif source == "pytorch" :
255
+ elif exact_source == "pytorch" :
269
256
if "example_inputs" in kwargs :
270
257
msg = 'Unexpected argument "example_inputs" found'
271
258
raise ValueError (msg )
@@ -300,55 +287,81 @@ def _flatten_list(_inputs):
300
287
if outputs is not None :
301
288
raise ValueError ("outputs must not be specified for PyTorch" )
302
289
303
- proto_spec = _convert (
304
- model ,
305
- convert_from = "torch" ,
306
- convert_to = convert_to ,
307
- inputs = inputs ,
308
- outputs = outputs ,
309
- classifier_config = classifier_config ,
310
- ** kwargs
311
- )
312
-
313
- elif source == "mil" :
290
+ elif exact_source == "mil" :
314
291
if not isinstance (model , Program ):
315
292
msg = "Converter was asked to convert MIL input, but input is not a MIL program!"
316
293
raise ValueError (msg )
317
294
318
- proto_spec = _convert (
319
- model ,
320
- convert_from = "mil" ,
321
- convert_to = convert_to ,
322
- example_inputs = inputs ,
323
- classifier_config = classifier_config ,
324
- ** kwargs
295
+
296
+ def _determine_source (model , source , outputs ):
297
+ """
298
+ Infer source (which can be auto) to the precise framework.
299
+ """
300
+ source = source .lower ()
301
+ if source not in {"auto" , "tensorflow" , "pytorch" , "mil" }:
302
+ msg = (
303
+ 'Unrecognized value of argument "source": {}. '
304
+ 'It must be one of ["auto", "tensorflow", "pytorch"].'
325
305
)
306
+ raise ValueError (msg .format (source ))
326
307
327
- if convert_to == 'mil' :
328
- return proto_spec # Returns the MIL program
329
308
330
- useCPUOnly = kwargs .get ("useCPUOnly" , True )
331
- model = coremltools .models .MLModel (proto_spec , useCPUOnly = useCPUOnly )
309
+ # Determine tensorflow version
310
+ if source == "tensorflow" and _HAS_TF_2 :
311
+ return "tensorflow2"
332
312
333
- if minimum_deployment_target is not None :
334
- check_deployment_compatibility (
335
- spec = proto_spec ,
336
- representation = convert_to ,
337
- deployment_target = minimum_deployment_target ,
338
- )
313
+ if source != 'auto' :
314
+ return source
339
315
340
- del proto_spec
341
- gc .collect ()
316
+ # Determine `auto` source
317
+ if source == "auto" and _HAS_TF_1 :
318
+ try :
319
+ loader = TF1Loader (model , outputs = outputs )
320
+ loader ._graph_def_from_model (outputs = outputs )
321
+ return "tensorflow"
322
+ except :
323
+ pass
342
324
325
+ if source == "auto" and _HAS_TF_2 :
326
+ try :
327
+ loader = TF2Loader (model , outputs = outputs )
328
+ loader ._graph_def_from_model (outputs = outputs )
329
+ return "tensorflow2"
330
+ except :
331
+ pass
332
+
333
+ if source == "auto" and _HAS_TORCH :
334
+ try :
335
+ pytorch_load (model )
336
+ return "pytorch"
337
+ except :
338
+ pass
339
+
340
+ if source == "auto" and isinstance (model , Program ):
341
+ return "mil"
342
+
343
+ msg = (
344
+ "Unable to determine the type of the model, i.e. the source framework. "
345
+ 'Please provide the value of argument "source", from one of '
346
+ '["tensorflow", "pytorch", "mil"]. Note that model conversion requires the '
347
+ "source package that generates the model. Please make sure you have "
348
+ "the appropriate version of source package installed. E.g., if you're "
349
+ "converting model originally trained with TensorFlow 1.14, make sure "
350
+ "you have `tensorflow==1.14` installed."
351
+ )
352
+ raise ValueError (msg )
353
+
354
+
355
+ def _record_src_version (mlmodel , exact_source ):
343
356
# recording metadata: coremltools version, source framework and version
344
- if source in {"tensorflow" , "tensorflow2" } and (_HAS_TF_1 or _HAS_TF_2 ):
357
+ if exact_source in {"tensorflow" , "tensorflow2" } and (_HAS_TF_1 or _HAS_TF_2 ):
345
358
src_pkg_version = "tensorflow=={0}" .format (tf .__version__ )
346
- elif source == "pytorch" and _HAS_TORCH :
359
+ elif exact_source == "pytorch" and _HAS_TORCH :
347
360
src_pkg_version = "torch=={0}" .format (torch .__version__ )
361
+ elif exact_source == 'mil' :
362
+ src_pkg_version = "mil"
348
363
else :
349
- src_pkg_version = "unknown"
350
-
351
- model .user_defined_metadata [_METADATA_VERSION ] = ct_version
352
- model .user_defined_metadata [_METADATA_SOURCE ] = src_pkg_version
364
+ raise ValueError ('Unsupported source {}' .format (exact_source ))
353
365
354
- return model
366
+ mlmodel .user_defined_metadata [_METADATA_SOURCE ] = src_pkg_version
367
+ return mlmodel
0 commit comments