Skip to content

Commit

Permalink
feat: handle API rate-limiting
Browse files Browse the repository at this point in the history
Automatically retry API calls when we hit a rate limit. Notify the caller
through a callback when this happens.
  • Loading branch information
agateau-gg committed Dec 21, 2023
1 parent 7cf0356 commit b52f1fd
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Added

- GGClient now obeys rate-limits and can notify callers when hitting one.
4 changes: 2 additions & 2 deletions pygitguardian/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""PyGitGuardian API Client"""
from .client import ContentTooLarge, GGClient
from .client import ContentTooLarge, GGClient, GGClientCallbacks


__version__ = "1.11.0"
GGClient._version = __version__

__all__ = ["GGClient", "ContentTooLarge"]
__all__ = ["GGClient", "GGClientCallbacks", "ContentTooLarge"]
57 changes: 45 additions & 12 deletions pygitguardian/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tarfile
import time
import urllib.parse
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast
Expand Down Expand Up @@ -46,6 +47,8 @@
# max files size to create a tar from
MAX_TAR_CONTENT_SIZE = 30 * 1024 * 1024

HTTP_TOO_MANY_REQUESTS = 429


class ContentTooLarge(Exception):
"""
Expand Down Expand Up @@ -124,6 +127,15 @@ def _create_tar(root_path: Path, filenames: List[str]) -> bytes:
return tar_stream.getvalue()


class GGClientCallbacks(ABC):
"""Abstract class used to notify GGClient users of events"""

@abstractmethod
def on_rate_limited(self, delay: int) -> None:
"""Called when GGClient hits a rate-limit."""
... # pragma: no cover


class GGClient:
_version = "undefined"
session: Session
Expand All @@ -133,6 +145,7 @@ class GGClient:
user_agent: str
extra_headers: Dict
secret_scan_preferences: SecretScanPreferences
callbacks: Optional[GGClientCallbacks]

def __init__(
self,
Expand All @@ -141,13 +154,15 @@ def __init__(
session: Optional[Session] = None,
user_agent: Optional[str] = None,
timeout: Optional[float] = DEFAULT_TIMEOUT,
callbacks: Optional[GGClientCallbacks] = None,
):
"""
:param api_key: API Key to be added to requests
:param base_uri: Base URI for the API, defaults to "https://api.gitguardian.com"
:param session: custom requests session, defaults to requests.Session()
:param user_agent: user agent to identify requests, defaults to ""
:param timeout: request timeout, defaults to 20s
:param callbacks: object used to receive callbacks from the client, defaults to None
:raises ValueError: if the protocol or the api_key is invalid
"""
Expand Down Expand Up @@ -177,6 +192,7 @@ def __init__(
self.api_key = api_key
self.session = session if isinstance(session, Session) else Session()
self.timeout = timeout
self.callbacks = callbacks
self.user_agent = "pygitguardian/{} ({};py{})".format(
self._version, platform.system(), platform.python_version()
)
Expand Down Expand Up @@ -207,18 +223,35 @@ def request(
if extra_headers
else self.session.headers
)
start = time.time()
response: Response = self.session.request(
method=method, url=url, timeout=self.timeout, headers=headers, **kwargs
)
duration = time.time() - start
logger.debug(
"method=%s endpoint=%s status_code=%s duration=%f",
method,
endpoint,
response.status_code,
duration,
)
while True:
start = time.time()
response: Response = self.session.request(
method=method, url=url, timeout=self.timeout, headers=headers, **kwargs
)
duration = time.time() - start
logger.debug(
"method=%s endpoint=%s status_code=%s duration=%f",
method,
endpoint,
response.status_code,
duration,
)
if response.status_code == HTTP_TOO_MANY_REQUESTS:
logger.warning("Rate-limit hit")
try:
delay = int(response.headers["Retry-After"])
except (ValueError, KeyError):
# We failed to parse the Retry-After header, return the response as
# is so the caller handles it as an error
logger.error("Could not get the retry-after value")
return response

if self.callbacks:
self.callbacks.on_rate_limited(delay)
logger.warning("Waiting for %d seconds before retrying", delay)
time.sleep(delay)
else:
break

self.app_version = response.headers.get("X-App-Version", self.app_version)
self.secrets_engine_version = response.headers.get(
Expand Down
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from os.path import dirname, join, realpath
from typing import Any

import pytest
import vcr
Expand All @@ -19,7 +20,12 @@
)


def create_client(**kwargs: Any) -> GGClient:
"""Create a GGClient using $GITGUARDIAN_API_KEY"""
api_key = os.environ["GITGUARDIAN_API_KEY"]
return GGClient(api_key=api_key, **kwargs)


@pytest.fixture
def client():
api_key = os.environ["GITGUARDIAN_API_KEY"]
return GGClient(api_key=api_key)
return create_client()
75 changes: 72 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from datetime import date
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Type
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
import responses
from marshmallow import ValidationError
from responses import matchers

from pygitguardian import GGClient
from pygitguardian.client import is_ok, load_detail
from pygitguardian.client import GGClientCallbacks, is_ok, load_detail
from pygitguardian.config import (
DEFAULT_BASE_URI,
DOCUMENT_SIZE_THRESHOLD_BYTES,
Expand All @@ -36,7 +36,7 @@
SCAVulnerability,
)

from .conftest import my_vcr
from .conftest import create_client, my_vcr


FILENAME = ".env"
Expand Down Expand Up @@ -612,6 +612,75 @@ def test_multiscan_parameters(
assert mock_response.call_count == 1


@responses.activate
def test_rate_limit():
"""
GIVEN a GGClient instance with callbacks
WHEN calling an API endpoint and we hit a rate-limit
THEN the client retries after the delay
AND the `on_rate_limited()` method of the callback is called
"""
callbacks = Mock(spec=GGClientCallbacks)

client = create_client(callbacks=callbacks)
multiscan_url = client._url_from_endpoint("multiscan", "v1")

rate_limit_response = responses.post(
url=multiscan_url,
status=429,
headers={"Retry-After": "1"},
)
normal_response = responses.post(
url=multiscan_url,
status=200,
json=[
{
"policy_break_count": 0,
"policies": ["pol"],
"policy_breaks": [],
}
],
)

result = client.multi_content_scan(
[{"filename": FILENAME, "document": DOCUMENT}],
)

assert rate_limit_response.call_count == 1
assert normal_response.call_count == 1
assert isinstance(result, MultiScanResult)
callbacks.on_rate_limited.assert_called_once_with(1)


@responses.activate
def test_bogus_rate_limit():
"""
GIVEN a GGClient instance with callbacks
WHEN calling an API endpoint and we hit a rate-limit
AND we can't parse the rate-limit value
THEN the client just returns the error
AND the `on_rate_limited()` method of the callback is not called
"""
callbacks = Mock(spec=GGClientCallbacks)

client = create_client(callbacks=callbacks)
multiscan_url = client._url_from_endpoint("multiscan", "v1")

rate_limit_response = responses.post(
url=multiscan_url,
status=429,
headers={"Retry-After": "later"},
)

result = client.multi_content_scan(
[{"filename": FILENAME, "document": DOCUMENT}],
)

assert rate_limit_response.call_count == 1
assert isinstance(result, Detail)
callbacks.on_rate_limited.assert_not_called()


def test_quota_overview(client: GGClient):
with my_vcr.use_cassette("quota.yaml"):
quota_response = client.quota_overview()
Expand Down

0 comments on commit b52f1fd

Please sign in to comment.