
PyTorch __ior__ op is not implemented for conversion #2584

@kamisori-daijin

Description


Hello,

I encountered an issue while trying to convert a PyTorch model (gemma-3-1b-it) to Core ML format. The conversion fails with the following error: PyTorch convert function for op '__ior__' not implemented.

I understand there was a recent fix related to an '__ior__' issue in RangeDim. I have confirmed that I am on the latest code by installing directly from the main branch of this repository.
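
For context, '__ior__' is the hook behind PyTorch's in-place |= operator, so it shows up in a traced graph whenever a model updates a mask in place, which matches the 'attention_mask.19' location in the traceback below. A minimal sketch of the kind of pattern that should trigger it (a hypothetical module, not taken from gemma-3-1b-it):

import torch

class MaskOr(torch.nn.Module):
    def forward(self, a, b):
        mask = a > 0
        mask |= (b > 0)  # in-place or on bool tensors; traced as aten::__ior__
        return mask

traced = torch.jit.trace(MaskOr(), (torch.randn(4), torch.randn(4)))
print(traced.graph)  # the graph should contain an aten::__ior__ node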

Steps to Reproduce

Install coremltools from the main branch: pip install git+https://github.com/apple/coremltools.git

Run the provided Python script with the gemma-3-1b-it model.

The conversion fails with the traceback shown below.

Environment

CoreMLTools Version: coremltools @ git+https://github.com/apple/coremltools.git@0f4244215c1f293f9b822b194fede05ad0e93851

PyTorch Version: 2.2.2

Python Version: 3.12.10

Additional Information

Here is the traceback I received:

Failed to load _MLModelProxy: No module named 'coremltools.libcoremlpython'
...
(Omitted - 'coremltools.libcoremlpython' related errors)
...
Fail to import BlobReader from libmilstoragepython. No module named 'coremltools.libmilstoragepython'
...
(Omitted - 'coremltools.libmilstoragepython' related errors)
...
Failed to load '_MLCPUComputeDeviceRemoteProxy'. Remote device functionality for retrieving the compute plan is unavailable.
...
(Omitted - 'RemoteProxy' related errors)
...
CoreMLTools Version: 9.0b1
PyTorch Version: 2.2.2
Numpy Version: 1.26.4
Loading Hugging Face model from '.../gemma-3-1b-it' into memory...
Model loaded and configured.
Wrapper model prepared.
Tracing wrapper model to TorchScript...
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
/Users/.../transformers/masking_utils.py:190: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
...
(Omitted - TracerWarning)
...
TorchScript tracing complete.
Converting TorchScript model to Core ML...
Model is not in eval mode. Consider calling '.eval()' on your model prior to conversion
Converting PyTorch Frontend ==> MIL Ops:   0%| | 0/4901 [00:00
Core ML embedding (gather) layer does not support any inputs besides the weights and indices. Those given will be ignored.
Converting PyTorch Frontend ==> MIL Ops:   0%| | 23/4901 [00:00

ERROR - converting '__ior__' op (located at: 'model/model/attention_mask.19'):

Converting PyTorch Frontend ==> MIL Ops:   1%| | 57/4901 [00:00
Conversion to CoreML failed: PyTorch convert function for op '__ior__' not implemented.
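
As a possible stopgap until the op is supported, I believe a custom converter can be registered before calling ct.convert. Below is an untested sketch that assumes the '__ior__' operands here are boolean masks; it lowers the in-place or to MIL's logical_or (note that _get_inputs is a private coremltools helper, so this may break between versions):

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def __ior__(context, node):
    # Lower in-place `|=` as its out-of-place equivalent; assumes bool inputs.
    x, y = _get_inputs(context, node, expected=2)
    context.add(mb.logical_or(x=x, y=y, name=node.name))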

Here is my convert.py script:

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import os


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Convert a Hugging Face model to Core ML.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the downloaded Hugging Face model directory (e.g., 'gemma-3-1b-it')."
    )
    args = parser.parse_args()

    # Use the path received as a command line argument
    downloaded_hf_model_dir = args.model

    print(f"CoreMLTools Version: {ct.__version__}")
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Numpy Version: {np.__version__}")

    try:
        # 1. Hugging Face model loading
        print(f"Loading Hugging Face model from '{downloaded_hf_model_dir}' into memory...")
        model = AutoModelForCausalLM.from_pretrained(downloaded_hf_model_dir, torch_dtype=torch.float16)
        model.eval()
        model.config.use_cache = False
        print("Model loaded and configured.")

        # 2. Create a wrapper model for Core ML conversion
        class GemmaCoreMLWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.model.config.use_cache = False

            def forward(self, input_ids, attention_mask):
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False,
                    return_dict=False,
                    output_attentions=False,
                    output_hidden_states=False
                )
                logits = outputs[0]
                return logits

        wrapped_model = GemmaCoreMLWrapper(model)
        print("Wrapper model prepared.")

        # 3. Prepare dummy inputs for TorchScript tracing
        max_seq_length = 1024
        tokenizer = AutoTokenizer.from_pretrained(downloaded_hf_model_dir)

        dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (1, 10), dtype=torch.long)
        dummy_attention_mask = torch.ones(1, 10, dtype=torch.long)

        # 4. Trace the wrapper model to TorchScript
        print("Tracing wrapper model to TorchScript...")
        traced_model = torch.jit.trace(wrapped_model, (dummy_input_ids, dummy_attention_mask))
        print("TorchScript tracing complete.")

        # 5. Convert TorchScript model to Core ML
        print("Converting TorchScript model to Core ML...")
        coreml_model = ct.convert(
            traced_model,
            inputs=[
                ct.TensorType(name="input_ids", shape=(1, ct.RangeDim(upper_bound=max_seq_length)), dtype=np.int32),
                ct.TensorType(name="attention_mask", shape=(1, ct.RangeDim(upper_bound=max_seq_length)), dtype=np.int32)
            ],
            source="pytorch",
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16
        )

        model_name = os.path.basename(downloaded_hf_model_dir)
        output_filename = f"{model_name}-coreml.mlpackage"
        coreml_model.save(output_filename)
        print(f"CoreML model saved successfully to {output_filename}.")

    except Exception as e:
        print(f"Conversion to CoreML failed: {e}")


if __name__ == "__main__":
    main()
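
If it helps with triage, the offending op can be confirmed in the TorchScript graph before conversion; for example, a quick check added after the torch.jit.trace call above:

op_kinds = {n.kind() for n in traced_model.inlined_graph.nodes()}
print("aten::__ior__" in op_kinds)  # True when the in-place or op was traced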

How to Use convert.py

This script is designed to be run from the command line. You need to provide the path to the model you want to convert using the --model argument.
Basic Command
python convert.py --model "[path_to_your_model]"

Example

If your model is located at /path/your/directory/gemma-3-1b-it, the command would be:
python convert.py --model "/path/your/directory/gemma-3-1b-it"

Important Notes

Current Directory: Make sure you are in the project's root directory when you run the command.
