OP_REQUIRES failed at xla_compile_on_demand_op.cc:290 : UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND #2217

Open
rb-23 opened this issue Apr 12, 2024 · 5 comments

Comments


rb-23 commented Apr 12, 2024

Bug Report


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    • Ubuntu 20.04
  • TensorFlow Serving installed from (source or binary):
    • Singularity container built from the TF Serving image on Docker Hub
  • TensorFlow Serving version:
    • TensorFlow ModelServer: 2.14.0-rc1
    • TensorFlow Library: 2.14.0

Describe the problem

Although CUDA and all other relevant libraries are linked in, the CUDA compiler is not found when running inference on the model server. This does not happen when I run other models with the same containers.

Exact Steps to Reproduce

  1. Download and build the relevant Singularity containers from Docker Hub with `sudo singularity build`. The definition files for the base TensorFlow container (used to run inference) and for the TensorFlow Serving container are:

base_tensorflow_container.def:

```
Bootstrap: docker
From: tensorflow/tensorflow:2.14.0-gpu

%environment
	export PATH=${PATH}:/cm/local/apps/cuda/libs/current/bin
	export MODEL_NAME=model
	export MODEL_BASE_PATH=/models

%files
    model /models/model
```

tensorflow_container.def:

```
Bootstrap: docker
From: tensorflow/serving:2.14.0-gpu

%environment
	export PATH=${PATH}:/cm/local/apps/cuda/libs/current/bin
	export MODEL_NAME=model
	export MODEL_BASE_PATH=/models

%files
    model /models/model
```
  2. Build the model to be served.
    save_model.py:

```python
import tensorflow as tf
from transformers import TFBartForConditionalGeneration


class MyOwnModel(tf.Module):
    def __init__(self, model_path="facebook/bart-large-cnn"):
        super().__init__()
        # Load pretrained BART; no_repeat_ngram_size=None is passed through as a config override.
        self.model = TFBartForConditionalGeneration.from_pretrained(model_path, no_repeat_ngram_size=None)

    # Fixed input signature: one sequence of exactly 1024 int32 token ids.
    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 1024], dtype=tf.int32, name="input_ids")])
    def serving(self, input_ids):
        # generate() runs the autoregressive decoding loop inside the traced graph.
        return self.model.generate(input_ids=input_ids)


model = MyOwnModel()
export_dir = "./shaped_input_model"

# Export the SavedModel with the serving function as the default signature.
tf.saved_model.save(model, export_dir, signatures={"serving_default": model.serving})
```
  3. Run the TensorFlow Serving Singularity container with `singularity run --nv -B shaped_input_model:/models/model/1 -B /usr/local/cuda-11.8:/usr/local/cuda-11.8 tensorflow_container.sif --per_process_gpu_memory_fraction=0.5`.
  4. Enter the base TensorFlow container with `singularity run --nv base_tensorflow_container.sif` and run inference using the following Python script:

infer.py:

```python
import json

import numpy as np
import requests
from transformers import BartTokenizer

article = "At least 14 people were killed and 60 others wounded Thursday when a bomb ripped through a crowd waiting to see Algeria's president in Batna, east of the capital of Algiers, the Algerie Presse Service reported. A wounded person gets first aid shortly after Thursday's attack in Batna, Algeria. The explosion occurred at 5 p.m. about 20 meters (65 feet) from a mosque in Batna, a town about 450 kilometers (280 miles) east of Algiers, security officials in Batna told the state-run news agency. The bomb went off 15 minutes before the expected arrival of President Abdel-Aziz Bouteflika. It wasn't clear if the bomb was caused by a suicide bomber or if it was planted, the officials said. Later Thursday, Algeria's Interior Minister Noureddine Yazid Zerhouni said \"a suspect person who was among the crowd attempted to go beyond the security cordon,\" but the person escaped \"immediately after the bomb exploded,\" the press service reported. Bouteflika made his visit to Batna as planned, adding a stop at a hospital to visit the wounded before he returned to the capital. There was no immediate claim of responsibility for the bombing. Algeria faces a continuing Islamic insurgency, according to the CIA. In July, 33 people were killed in apparent suicide bombings in Algiers that were claimed by an al Qaeda-affiliated group. Bouteflika said terrorist acts have nothing in common with the noble values of Islam, the press service reported. E-mail to a friend . CNN's Mohammed Tawfeeq contributed to this report."

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
url = "http://localhost:8501/v1/models/model:predict"
MAX_SHAPE = 1024

# Tokenize, truncating to the fixed length the serving signature expects.
inputs = tokenizer.encode(article, return_tensors="np", truncation=True, max_length=MAX_SHAPE)
# Right-pad with zeros up to shape [1, MAX_SHAPE].
padding_length = MAX_SHAPE - inputs.shape[1]
inputs = np.pad(inputs, ((0, 0), (0, padding_length)), mode="constant")
print(inputs.shape)  # expected: (1, 1024)

json_data = json.dumps(
    {
        "signature_name": "serving_default",
        "inputs": inputs.tolist(),
    }
)

json_response = requests.post(url, data=json_data)
response = json.loads(json_response.text)
print(f"Summary: {response}")
```
  5. The error occurs at this point.
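The fixed `[1, 1024]` input signature means every request must carry exactly 1024 token ids. A minimal, dependency-free sketch of that client-side shaping step, assuming the token ids arrive as a plain list of ints; `pad_or_truncate` and `build_predict_payload` are hypothetical helper names, and the default pad id of 0 only mirrors the constant padding used in infer.py (the tokenizer's own `tokenizer.pad_token_id` may be the more appropriate value):

```python
import json
from typing import List

MAX_SHAPE = 1024  # must match the input_signature shape [1, 1024]


def pad_or_truncate(token_ids: List[int], length: int = MAX_SHAPE, pad_id: int = 0) -> List[int]:
    """Return exactly `length` ids: truncate long inputs, right-pad short ones."""
    clipped = token_ids[:length]
    return clipped + [pad_id] * (length - len(clipped))


def build_predict_payload(token_ids: List[int], signature: str = "serving_default") -> str:
    """Build a columnar ("inputs") request body for the TF Serving REST predict endpoint."""
    batch = [pad_or_truncate(token_ids)]  # batch dimension of 1
    return json.dumps({"signature_name": signature, "inputs": batch})


body = json.loads(build_predict_payload([101, 102, 103]))  # arbitrary example ids
print(len(body["inputs"]), len(body["inputs"][0]))  # 1 1024
```

Truncating first also avoids a negative `padding_length`, which `np.pad` rejects, for articles longer than 1024 tokens.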

Source code / logs

Output of infer.py:

```
Summary: {'error': '2 root error(s) found.\n (0) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?\n\t [[{{function_node while_body_26758}}{{node while/XlaDynamicUpdateSlice}}]]\n\t [[StatefulPartitionedCall/StatefulPartitionedCall/while/body/_1058/while/tf_bart_for_conditional_generation/model/decoder/assert_less/Assert/Const_1/_1674]]\n (1) UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?\n\t [[{{function_node while_body_26758}}{{node while/XlaDynamicUpdateSlice}}]]\n0 successful operations.\n0 derived errors ignored.'}
```
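The aggregated error is easier to read once split into its parts. A small sketch, assuming the ` (N) CODE: message` and `{{node NAME}}` layout seen in the output above holds in general; `summarize_tf_error` is a hypothetical helper meant to be applied to `response["error"]` from infer.py:

```python
import re
from typing import List, Tuple

# Root errors appear as " (N) CODE: message" on their own line.
_ROOT_ERROR = re.compile(r"\((\d+)\) ([A-Z_]+): (.*)")
# Failing nodes appear as [[{{function_node F}}{{node NODE}}]].
_NODE_REF = re.compile(r"\{\{node ([^}]+)\}\}")


def summarize_tf_error(message: str) -> Tuple[List[Tuple[str, str]], List[str]]:
    """Return (status code, message) per root error and the distinct node names mentioned."""
    roots = [(code, text) for _, code, text in _ROOT_ERROR.findall(message)]
    nodes = sorted(set(_NODE_REF.findall(message)))
    return roots, nodes
```

On the output above it reports two UNIMPLEMENTED root errors that both point at the same node, while/XlaDynamicUpdateSlice.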

nvcc and model server version inside the TensorFlow Serving Singularity container:

```
Singularity> nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

Singularity> tensorflow_model_server --version
TensorFlow ModelServer: 2.14.0-rc1
TensorFlow Library: 2.14.0
```
rb-23 changed the title from "OP_REQUIRES failed at xla_compile_on_demand_op.cc:290 : UNIMPLEMENTED: Could not find compiler for platform CUDA:" to "OP_REQUIRES failed at xla_compile_on_demand_op.cc:290 : UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND" on Apr 12, 2024
@kmkolasinski

Hi, the error suggests an issue with the dynamic loops implemented in the generate function (node while/XlaDynamicUpdateSlice). I believe generate creates some dynamic tensors inside the loop, which is not supported. XLA errors can be hard to read.
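For readers unfamiliar with the op named in the error: XLA's DynamicUpdateSlice writes an update into a fixed-size operand at a start index known only at runtime, clamping that index so the update stays in bounds. A pure-Python sketch of those semantics in one dimension, plus a decoding-style loop of the kind in which such an op can appear; this illustrates the pattern only, is not the transformers or XLA implementation, and `dynamic_update_slice_1d` / `fill_buffer` are hypothetical names:

```python
from typing import List


def dynamic_update_slice_1d(operand: List[int], update: List[int], start: int) -> List[int]:
    """1-D model of DynamicUpdateSlice: output has operand's shape, with `update`
    written at `start`, clamped to [0, len(operand) - len(update)]."""
    if len(update) > len(operand):
        raise ValueError("update must not be larger than operand")
    start = max(0, min(start, len(operand) - len(update)))
    result = list(operand)  # the output keeps the operand's static shape
    result[start:start + len(update)] = update
    return result


def fill_buffer(tokens: List[int], max_len: int, pad_id: int = 0) -> List[int]:
    """Decoding-style loop: the buffer shape never changes, only its contents."""
    buffer = [pad_id] * max_len
    for step, tok in enumerate(tokens[:max_len]):
        buffer = dynamic_update_slice_1d(buffer, [tok], step)
    return buffer


print(fill_buffer([11, 12, 13], max_len=5))  # [11, 12, 13, 0, 0]
```

Note that the op's output always has the operand's shape, so the op by itself does not produce a dynamically sized tensor.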

singhniraj08 self-assigned this on Apr 22, 2024
@singhniraj08

@rb-23, can you try passing the `--xla_cpu_compilation_enabled=true` parameter as an additional argument when running the TF Serving Docker image, as shown here, and see whether model inference works? Please let us know if you face any issues. Thank you!
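In the Singularity setup from the reproduction steps, model-server arguments already follow the image name, so the flag would be appended there. A small sketch that assembles that command line with Python's standard `shlex` module; it only builds and prints the argument list, `build_serving_cmd` is a hypothetical helper, and every path and argument except `--xla_cpu_compilation_enabled=true` is copied from the report:

```python
import shlex
from typing import List, Sequence


def build_serving_cmd(image: str, binds: Sequence[str], server_args: Sequence[str]) -> List[str]:
    """Assemble a `singularity run --nv` command; arguments after the image go to the model server."""
    cmd = ["singularity", "run", "--nv"]
    for bind in binds:
        cmd += ["-B", bind]
    cmd.append(image)
    cmd.extend(server_args)
    return cmd


cmd = build_serving_cmd(
    "tensorflow_container.sif",
    binds=["shaped_input_model:/models/model/1", "/usr/local/cuda-11.8:/usr/local/cuda-11.8"],
    server_args=["--per_process_gpu_memory_fraction=0.5", "--xla_cpu_compilation_enabled=true"],
)
print(shlex.join(cmd))
```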


This issue has been marked stale because it has had no activity for 7 days. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label on Apr 30, 2024

rb-23 commented Apr 30, 2024


Hi @singhniraj08, I tried what you suggested. Unfortunately, the same error still occurs:

```
2024-04-30 14:37:45.277728: W external/org_tensorflow/tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_compile_on_demand_op.cc:290 : UNIMPLEMENTED: Could not find compiler for platform CUDA: NOT_FOUND: could not find registered compiler for platform CUDA -- was support for that platform linked in?
```
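When comparing runs with and without a flag, it can help to pull the severity, source location, and message out of server log lines. A minimal sketch, assuming the `timestamp: SEVERITY file:line] message` prefix of the log line above; `parse_tf_log_line` and `LogRecord` are hypothetical names:

```python
import re
from typing import NamedTuple, Optional

# Prefix layout: "<date> <time>: <I|W|E|F> <file>:<line>] <message>"
_LOG_LINE = re.compile(
    r"^(?P<ts>\S+ \S+): (?P<sev>[IWEF]) (?P<file>[^:\s]+):(?P<line>\d+)\] (?P<msg>.*)$"
)


class LogRecord(NamedTuple):
    timestamp: str
    severity: str
    source: str
    line: int
    message: str


def parse_tf_log_line(text: str) -> Optional[LogRecord]:
    """Split one log line into its parts; return None if it lacks the expected prefix."""
    m = _LOG_LINE.match(text.strip())
    if m is None:
        return None
    return LogRecord(m["ts"], m["sev"], m["file"], int(m["line"]), m["msg"])
```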

google-ml-butler bot removed the stale and stat:awaiting response labels on Apr 30, 2024
@janasangeetha

Tagging a similar past issue: #2214.
@gharibian, requesting your help on this. Thanks!
