Skip to content

Commit

Permalink
Merge pull request #60 from mynhardtburger/case-insensitive-route-info
Browse files Browse the repository at this point in the history
Bug fix: Make get_route_info() case insensitive
  • Loading branch information
evaline-ju authored Aug 1, 2024
2 parents b7bb118 + 3fe4942 commit 057b5fb
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 18 deletions.
47 changes: 44 additions & 3 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Standard
from copy import deepcopy
from threading import Lock
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union

# Third Party
import grpc
Expand Down Expand Up @@ -196,6 +196,11 @@ def handle_runtime_context(
{"hostname": route_info},
fill_with_defaults=True,
)
else:
log.debug(
"<TGB32948346D> No %s context override found",
self.ROUTE_INFO_HEADER_KEY,
)

## Backend user interface ##

Expand Down Expand Up @@ -351,6 +356,7 @@ def get_route_info(
context: Optional[RuntimeServerContextType],
) -> Optional[str]:
"""Get the string value of the x-route-info header/metadata if present
in a case insensitive manner.
Args:
context (Optional[RuntimeServerContextType]): The grpc or fastapi
Expand All @@ -363,9 +369,12 @@ def get_route_info(
if context is None:
return context
if isinstance(context, grpc.ServicerContext):
return dict(context.invocation_metadata()).get(cls.ROUTE_INFO_HEADER_KEY)
return TGISBackend._request_metadata_get(
context.invocation_metadata(), cls.ROUTE_INFO_HEADER_KEY
)

if HAVE_FASTAPI and isinstance(context, fastapi.Request):
return context.headers.get(cls.ROUTE_INFO_HEADER_KEY)
return TGISBackend._request_header_get(context, cls.ROUTE_INFO_HEADER_KEY)
error.log_raise(
"<TGB92615097E>",
TypeError(f"context is of an unsupported type: {type(context)}"),
Expand Down Expand Up @@ -415,6 +424,38 @@ def _safely_update_state(
if remote_models_cfg:
self._remote_models_cfg.setdefault(model_id, remote_models_cfg)

@classmethod
def _request_header_get(cls, request: fastapi.Request, key: str) -> Optional[str]:
"""
Returns the first matching value for the header key (case insensitive).
If no matching header was found return None.
"""
# https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/datastructures.py#L543
items: list[Tuple[str, str]] = request.headers.items()
get_header_key = key.lower()

for header_key, header_value in items:
if header_key.lower() == get_header_key:
return header_value

@classmethod
def _request_metadata_get(
cls, metadata: Tuple[str, Union[str, bytes]], key: str
) -> Optional[str]:
"""
Returns the first matching value for the metadata key (case insensitive).
If no matching metadata was found return None.
"""
# https://grpc.github.io/grpc/python/glossary.html#term-metadatum
get_metadata_key = key.lower()

for metadata_key, metadata_value in metadata:
if str(metadata_key).lower() == get_metadata_key:
if isinstance(metadata_value, str):
return metadata_value
if isinstance(metadata_value, bytes):
return metadata_value.decode()


# Register local backend
register_backend_type(TGISBackend)
126 changes: 111 additions & 15 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Standard
from copy import deepcopy
from dataclasses import asdict
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from unittest import mock
import os
import tempfile
Expand Down Expand Up @@ -97,17 +97,60 @@ def mock_tgis_fixture():
mock_tgis.stop()


class TestServicerContext:
class TestServicerContext(grpc.ServicerContext):
"""
A dummy class for mimicking ServicerContext invocation metadata storage.
"""

def __init__(self, metadata: Dict[str, Union[str, bytes]]):
self.metadata = metadata

def invocation_metadata(self):
def invocation_metadata(self) -> Sequence[Tuple[str, Union[str, bytes]]]:
# https://grpc.github.io/grpc/python/glossary.html#term-metadata
return list(self.metadata.items())

def is_active(self):
raise NotImplementedError

def time_remaining(self):
raise NotImplementedError

def cancel(self):
raise NotImplementedError

def add_callback(self, callback):
raise NotImplementedError

def peer(self):
raise NotImplementedError

def peer_identities(self):
raise NotImplementedError

def peer_identity_key(self):
raise NotImplementedError

def auth_context(self):
raise NotImplementedError

def send_initial_metadata(self, initial_metadata):
raise NotImplementedError

def set_trailing_metadata(self, trailing_metadata):
raise NotImplementedError

def abort(self, code, details):
raise NotImplementedError

def abort_with_status(self, status):
raise NotImplementedError

def set_code(self, code):
raise NotImplementedError

def set_details(self, details):
raise NotImplementedError


## Conn Config #################################################################

Expand Down Expand Up @@ -927,34 +970,84 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure):
{
"type": "http",
"headers": [
(TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), b"sometext")
(
TGISBackend.ROUTE_INFO_HEADER_KEY.encode("latin-1"),
"http exact".encode("latin-1"),
)
],
}
),
"http exact",
),
(
fastapi.Request(
{
"type": "http",
"headers": [
(
TGISBackend.ROUTE_INFO_HEADER_KEY.upper().encode("latin-1"),
"http upper-case".encode("latin-1"),
)
],
}
),
"sometext",
"http upper-case",
),
(
fastapi.Request(
{"type": "http", "headers": [(b"route-info", b"sometext")]}
{
"type": "http",
"headers": [
(
TGISBackend.ROUTE_INFO_HEADER_KEY.title().encode("latin-1"),
"http title-case".encode("latin-1"),
)
],
}
),
"http title-case",
),
(
fastapi.Request(
{
"type": "http",
"headers": [
(
"route-info".encode("latin-1"),
"http not-found".encode("latin-1"),
)
],
}
),
None,
),
(
TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}),
"sometext",
TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "grpc exact"}),
"grpc exact",
),
(
TestServicerContext({"route-info": "sometext"}),
TestServicerContext(
{TGISBackend.ROUTE_INFO_HEADER_KEY.upper(): "grpc upper-case"}
),
"grpc upper-case",
),
(
TestServicerContext(
{TGISBackend.ROUTE_INFO_HEADER_KEY.title(): "grpc title-case"}
),
"grpc title-case",
),
(
TestServicerContext({"route-info": "grpc not found"}),
None,
),
("should raise TypeError", None),
("should raise TypeError", TypeError()),
(None, None),
# Uncertain how to create a grpc.ServicerContext object
],
)
def test_get_route_info(context, route_info: Optional[str]):
if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
with pytest.raises(TypeError):
def test_get_route_info(context, route_info: Union[str, None, Exception]):
if isinstance(route_info, Exception):
with pytest.raises(type(route_info)):
TGISBackend.get_route_info(context)
else:
actual_route_info = TGISBackend.get_route_info(context)
Expand All @@ -970,7 +1063,10 @@ def test_handle_runtime_context_with_route_info():
{
"type": "http",
"headers": [
(TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), route_info.encode("utf-8"))
(
TGISBackend.ROUTE_INFO_HEADER_KEY.encode("latin-1"),
route_info.encode("latin-1"),
)
],
}
)
Expand Down

0 comments on commit 057b5fb

Please sign in to comment.