|
22 | 22 | import uuid |
23 | 23 | from functools import lru_cache |
24 | 24 | from http import HTTPStatus |
25 | | -from typing import Callable, Optional, Tuple, Type, Union |
| 25 | +from shlex import quote |
| 26 | +from typing import Any, Callable, List, Optional, Tuple, Type, Union |
26 | 27 |
|
27 | 28 | import requests |
28 | 29 | from requests import HTTPError, Response |
@@ -82,13 +83,15 @@ def add_headers(self, request, **kwargs): |
82 | 83 | request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) |
83 | 84 |
|
84 | 85 | # Add debug log |
85 | | - has_token = str(request.headers.get("authorization", "")).startswith("Bearer hf_") |
| 86 | + has_token = len(str(request.headers.get("authorization", ""))) > 0 |
86 | 87 | logger.debug( |
87 | 88 | f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" |
88 | 89 | ) |
89 | 90 |
|
90 | 91 | def send(self, request: PreparedRequest, *args, **kwargs) -> Response: |
91 | 92 | """Catch any RequestException to append request id to the error message for debugging.""" |
| 93 | + if constants.HF_DEBUG: |
| 94 | + logger.debug(f"Send: {_curlify(request)}") |
92 | 95 | try: |
93 | 96 | return super().send(request, *args, **kwargs) |
94 | 97 | except requests.RequestException as e: |
@@ -549,3 +552,41 @@ def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Res |
549 | 552 |
|
550 | 553 | # Return |
551 | 554 | return error_type(final_error_message.strip(), response=response, server_message=server_message or None) |
| 555 | + |
| 556 | + |
| 557 | +def _curlify(request: requests.PreparedRequest) -> str: |
| 558 | + """Convert a `requests.PreparedRequest` into a curl command (str). |
| 559 | +
|
| 560 | + Used for debug purposes only. |
| 561 | +
|
| 562 | + Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py. |
| 563 | + MIT License Copyright (c) 2016 Egor. |
| 564 | + """ |
| 565 | + parts: List[Tuple[Any, Any]] = [ |
| 566 | + ("curl", None), |
| 567 | + ("-X", request.method), |
| 568 | + ] |
| 569 | + |
| 570 | + for k, v in sorted(request.headers.items()): |
| 571 | + if k.lower() == "authorization": |
| 572 | + v = "<TOKEN>" # Hide authorization header, no matter its value (can be Bearer, Key, etc.) |
| 573 | + parts += [("-H", "{0}: {1}".format(k, v))] |
| 574 | + |
| 575 | + if request.body: |
| 576 | + body = request.body |
| 577 | + if isinstance(body, bytes): |
| 578 | + body = body.decode("utf-8") |
| 579 | + if len(body) > 1000: |
| 580 | + body = body[:1000] + " ... [truncated]" |
| 581 | + parts += [("-d", body)] |
| 582 | + |
| 583 | + parts += [(None, request.url)] |
| 584 | + |
| 585 | + flat_parts = [] |
| 586 | + for k, v in parts: |
| 587 | + if k: |
| 588 | + flat_parts.append(quote(k)) |
| 589 | + if v: |
| 590 | + flat_parts.append(quote(v)) |
| 591 | + |
| 592 | + return " ".join(flat_parts) |
0 commit comments