From bb2b05f9fcf3939f6ebd21b42c5661d652bb1619 Mon Sep 17 00:00:00 2001 From: Gibson Chikafa Date: Wed, 31 Jan 2024 19:00:55 +0100 Subject: [PATCH] Remove tritonclient dependancy --- python/kserve/kserve/protocol/infer_type.py | 98 ++++++++++++++++++++- python/kserve/pyproject.toml | 1 - 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/python/kserve/kserve/protocol/infer_type.py b/python/kserve/kserve/protocol/infer_type.py index cea572a00f3..80f149405be 100644 --- a/python/kserve/kserve/protocol/infer_type.py +++ b/python/kserve/kserve/protocol/infer_type.py @@ -17,7 +17,6 @@ import numpy import numpy as np import pandas as pd -from tritonclient.utils import raise_error, serialize_byte_tensor from ..constants.constants import GRPC_CONTENT_DATATYPE_MAPPINGS from ..errors import InvalidInput @@ -25,6 +24,103 @@ from ..utils.numpy_codec import to_np_dtype, from_np_dtype +def raise_error(msg): + """ + Raise error with the provided message + """ + raise InferenceServerException(msg=msg) from None + + +def serialized_byte_size(tensor_value): + """ + Get the underlying number of bytes for a numpy ndarray. + + Parameters + ---------- + tensor_value : numpy.ndarray + Numpy array to calculate the number of bytes for. + + Returns + ------- + int + Number of bytes present in this tensor + """ + + if tensor_value.dtype != np.object_: + raise_error("The tensor_value dtype must be np.object_") + + if tensor_value.size > 0: + total_bytes = 0 + # 'C' order is row-major. + for obj in np.nditer(tensor_value, flags=["refs_ok"], order="C"): + total_bytes += len(obj.item()) + return total_bytes + else: + return 0 + + +class InferenceServerException(Exception): + """Exception indicating non-Success status. + + Parameters + ---------- + msg : str + A brief description of error + + status : str + The error code + + debug_details : str + The additional details on the error + + """ + + def __init__(self, msg, status=None, debug_details=None): + self._msg = msg + self._status = status + self._debug_details = debug_details + + def __str__(self): + msg = super().__str__() if self._msg is None else self._msg + if self._status is not None: + msg = "[" + self._status + "] " + msg + return msg + + def message(self): + """Get the exception message. + + Returns + ------- + str + The message associated with this exception, or None if no message. + + """ + return self._msg + + def status(self): + """Get the status of the exception. + + Returns + ------- + str + Returns the status of the exception + + """ + return self._status + + def debug_details(self): + """Get the detailed information about the exception + for debugging purposes + + Returns + ------- + str + Returns the exception details + + """ + return self._debug_details + + class InferInput: _name: str _shape: List[int] diff --git a/python/kserve/pyproject.toml b/python/kserve/pyproject.toml index 6ec052074e8..44a0fa07927 100644 --- a/python/kserve/pyproject.toml +++ b/python/kserve/pyproject.toml @@ -49,7 +49,6 @@ prometheus-client = "^0.13.1" orjson = "^3.8.0" httpx = "^0.23.0" timing-asgi = "^0.3.0" -tritonclient = "^2.18.0" tabulate = "^0.9.0" pandas = ">=1.3.5"