ENH: Refine request log and add optional request_id (#2173)
frostyplanet authored Sep 6, 2024
1 parent c60e8fd commit e2618be
Showing 6 changed files with 156 additions and 43 deletions.
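
This commit makes three refinements to request logging, visible in the diffs below: every logged model call now carries a per-request ID (the caller's optional request_id kwarg where the backend supports one, otherwise a freshly generated UUID), argument values in log lines are truncated to XINFERENCE_LOG_ARG_MAX_LENGTH characters, and failures are logged together with the elapsed time. As a rough illustration (the model handle, ID, and timing are hypothetical, not taken from this commit), a call such as

    await model_ref.generate("Hello", request_id="req-42", max_tokens=16)

would now be logged roughly as:

    [request req-42] Enter generate, args: Hello, kwargs: request_id=req-42,max_tokens=16
    [request req-42] Leave generate, elapsed time: 3 s
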
1 change: 1 addition & 0 deletions xinference/constants.py
@@ -63,6 +63,7 @@ def get_xinference_home() -> str:
XINFERENCE_DEFAULT_LOG_FILE_NAME = "xinference.log"
XINFERENCE_LOG_MAX_BYTES = 100 * 1024 * 1024
XINFERENCE_LOG_BACKUP_COUNT = 30
XINFERENCE_LOG_ARG_MAX_LENGTH = 100
XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD = int(
os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD, 5)
)
63 changes: 51 additions & 12 deletions xinference/core/model.py
@@ -19,6 +19,7 @@
import os
import time
import types
import uuid
import weakref
from asyncio.queues import Queue
from asyncio.tasks import wait_for
@@ -444,18 +445,30 @@ async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
@log_async(logger=logger)
async def generate(self, prompt: str, *args, **kwargs):
if self.allow_batching():
# request_id is not supported here
kwargs.pop("request_id", None)
return await self.handle_batching_request(
prompt, "generate", *args, **kwargs
)
else:
kwargs.pop("raw_params", None)
if hasattr(self._model, "generate"):
# request_id is not supported here
kwargs.pop("request_id", None)
return await self._call_wrapper_json(
self._model.generate, prompt, *args, **kwargs
)
if hasattr(self._model, "async_generate"):
if "request_id" not in kwargs:
kwargs["request_id"] = str(uuid.uuid1())
else:
# the model only accepts a string request_id
kwargs["request_id"] = str(kwargs["request_id"])
return await self._call_wrapper_json(
self._model.async_generate,
prompt,
*args,
**kwargs,
)
raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

@@ -534,17 +547,26 @@ async def chat(self, messages: List[Dict], *args, **kwargs):
response = None
try:
if self.allow_batching():
# request_id is not supported here
kwargs.pop("request_id", None)
return await self.handle_batching_request(
messages, "chat", *args, **kwargs
)
else:
kwargs.pop("raw_params", None)
if hasattr(self._model, "chat"):
# request_id is not supported here
kwargs.pop("request_id", None)
response = await self._call_wrapper_json(
self._model.chat, messages, *args, **kwargs
)
return response
if hasattr(self._model, "async_chat"):
if "request_id" not in kwargs:
kwargs["request_id"] = str(uuid.uuid1())
else:
# the model only accepts a string request_id
kwargs["request_id"] = str(kwargs["request_id"])
response = await self._call_wrapper_json(
self._model.async_chat, messages, *args, **kwargs
)
@@ -577,9 +599,10 @@ async def abort_request(self, request_id: str) -> str:
return await self._scheduler_ref.abort_request(request_id)
return AbortRequestMessage.NO_OP.name

@request_limit
@log_async(logger=logger)
async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
kwargs.pop("request_id", None)
if hasattr(self._model, "create_embedding"):
return await self._call_wrapper_json(
self._model.create_embedding, input, *args, **kwargs
@@ -589,8 +612,8 @@ def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
f"Model {self._model.model_spec} is not for creating embedding."
)

@request_limit
@log_async(logger=logger)
async def rerank(
self,
documents: List[str],
@@ -602,6 +625,7 @@ async def rerank(
*args,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "rerank"):
return await self._call_wrapper_json(
self._model.rerank,
@@ -616,8 +640,8 @@ async def rerank(
)
raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")

@request_limit
@log_async(logger=logger, ignore_kwargs=["audio"])
async def transcriptions(
self,
audio: bytes,
@@ -626,7 +650,9 @@ async def transcriptions(
response_format: str = "json",
temperature: float = 0,
timestamp_granularities: Optional[List[str]] = None,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "transcriptions"):
return await self._call_wrapper_json(
self._model.transcriptions,
@@ -641,8 +667,8 @@ async def transcriptions(
f"Model {self._model.model_spec} is not for creating transcriptions."
)

@request_limit
@log_async(logger=logger, ignore_kwargs=["audio"])
async def translations(
self,
audio: bytes,
@@ -651,7 +677,9 @@ async def translations(
response_format: str = "json",
temperature: float = 0,
timestamp_granularities: Optional[List[str]] = None,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "translations"):
return await self._call_wrapper_json(
self._model.translations,
@@ -668,10 +696,7 @@ async def translations(

@request_limit
@xo.generator
@log_async(logger=logger, ignore_kwargs=["prompt_speech"])
async def speech(
self,
input: str,
@@ -681,6 +706,7 @@ async def speech(
stream: bool = False,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "speech"):
return await self._call_wrapper_binary(
self._model.speech,
@@ -695,8 +721,8 @@ async def speech(
f"Model {self._model.model_spec} is not for creating speech."
)

@request_limit
@log_async(logger=logger)
async def text_to_image(
self,
prompt: str,
@@ -706,6 +732,7 @@ async def text_to_image(
*args,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "text_to_image"):
return await self._call_wrapper_json(
self._model.text_to_image,
@@ -720,6 +747,10 @@ async def text_to_image(
f"Model {self._model.model_spec} is not for creating image."
)

@log_async(
logger=logger,
ignore_kwargs=["image"],
)
async def image_to_image(
self,
image: "PIL.Image",
@@ -731,6 +762,7 @@ async def image_to_image(
*args,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "image_to_image"):
return await self._call_wrapper_json(
self._model.image_to_image,
@@ -747,6 +779,10 @@ async def image_to_image(
f"Model {self._model.model_spec} is not for creating image."
)

@log_async(
logger=logger,
ignore_kwargs=["image"],
)
async def inpainting(
self,
image: "PIL.Image",
@@ -759,6 +795,7 @@ async def inpainting(
*args,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "inpainting"):
return await self._call_wrapper_json(
self._model.inpainting,
@@ -776,12 +813,13 @@ async def inpainting(
f"Model {self._model.model_spec} is not for creating image."
)

@request_limit
@log_async(logger=logger, ignore_kwargs=["image"])
async def infer(
self,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "infer"):
return await self._call_wrapper_json(
self._model.infer,
@@ -791,15 +829,16 @@ async def infer(
f"Model {self._model.model_spec} is not for flexible infer."
)

@request_limit
@log_async(logger=logger)
async def text_to_video(
self,
prompt: str,
n: int = 1,
*args,
**kwargs,
):
kwargs.pop("request_id", None)
if hasattr(self._model, "text_to_video"):
return await self._call_wrapper_json(
self._model.text_to_video,
102 changes: 80 additions & 22 deletions xinference/core/utils.py
@@ -11,62 +11,120 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import string
import uuid
from typing import Dict, Generator, List, Optional, Tuple, Union

import orjson
from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown

from .._compat import BaseModel
from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH

logger = logging.getLogger(__name__)


def truncate_log_arg(arg) -> str:
s = str(arg)
if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH:
s = s[0:XINFERENCE_LOG_ARG_MAX_LENGTH] + "..."
return s
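
A quick, hypothetical demonstration of the new helper with the default 100-character cap:

    assert truncate_log_arg("short") == "short"
    assert truncate_log_arg("x" * 150) == "x" * 100 + "..."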


def log_async(
logger,
level=logging.DEBUG,
ignore_kwargs: Optional[List[str]] = None,
log_exception=True,
):
import time
from functools import wraps

def decorator(func):
func_name = func.__name__

@wraps(func)
async def wrapped(*args, **kwargs):
request_id_str = kwargs.get("request_id", "")
if not request_id_str:
request_id_str = uuid.uuid1()
request_id_str = f"[request {request_id_str}]"
formatted_args = ",".join(map(truncate_log_arg, args))
formatted_kwargs = ",".join(
[
"%s=%s" % (k, truncate_log_arg(v))
for k, v in kwargs.items()
if ignore_kwargs is None or k not in ignore_kwargs
]
)
logger.log(
level,
f"{request_id_str} Enter {func_name}, args: {formatted_args}, kwargs: {formatted_kwargs}",
)
start = time.time()
try:
ret = await func(*args, **kwargs)
logger.log(
level,
f"{request_id_str} Leave {func_name}, elapsed time: {int(time.time() - start)} s",
)
return ret
except Exception as e:
if log_exception:
logger.error(
f"{request_id_str} Leave {func_name}, error: {e}, elapsed time: {int(time.time() - start)} s",
exc_info=True,
)
else:
logger.log(
level,
f"{request_id_str} Leave {func_name}, error: {e}, elapsed time: {int(time.time() - start)} s",
)
raise

return wrapped

return decorator
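
A usage sketch of the reworked decorator (the coroutine and logger below are hypothetical; ignore_kwargs keeps bulky payloads such as raw audio bytes out of the log line):

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("demo")

    @log_async(demo_logger, level=logging.INFO, ignore_kwargs=["audio"])
    async def transcribe(audio: bytes = b"", **kwargs):
        return {"text": "..."}

    asyncio.run(transcribe(audio=b"\x00" * 10_000, request_id="req-1"))
    # -> [request req-1] Enter transcribe, args: , kwargs: request_id=req-1
    # -> [request req-1] Leave transcribe, elapsed time: 0 s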


def log_sync(logger, level=logging.DEBUG, log_exception=True):
import time
from functools import wraps

def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
logger.debug(f"Enter {func.__name__}, args: {args}, kwargs: {kwargs}")
start = time.time()
ret = func(*args, **kwargs)
logger.debug(
f"Leave {func.__name__}, elapsed time: {int(time.time() - start)} s"
formatted_args = ",".join(map(truncate_log_arg, args))
formatted_kwargs = ",".join(
map(lambda x: "%s=%s" % (x[0], truncate_log_arg(x[1])), kwargs.items())
)
logger.log(
level,
f"Enter {func.__name__}, args: {formatted_args}, kwargs: {formatted_kwargs}",
)
start = time.time()
try:
ret = func(*args, **kwargs)
logger.log(
level,
f"Leave {func.__name__}, elapsed time: {int(time.time() - start)} s",
)
return ret
except Exception as e:
if log_exception:
logger.error(
f"Leave {func.__name__}, error: {e}, elapsed time: {int(time.time() - start)} s",
exc_info=True,
)
else:
logger.log(
level,
f"Leave {func.__name__}, error: {e}, elapsed time: {int(time.time() - start)} s",
)
raise

return wrapped
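
And a matching sketch for the synchronous variant (again with hypothetical names; note that log_sync does not attach request IDs):

    import logging

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("demo")

    @log_sync(demo_logger, level=logging.INFO)
    def resize(path: str, scale: float = 0.5):
        return path  # stand-in for real work

    resize("/tmp/cat.png", scale=0.25)
    # -> Enter resize, args: /tmp/cat.png, kwargs: scale=0.25
    # -> Leave resize, elapsed time: 0 s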
