-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Joe Runde <[email protected]>
- Loading branch information
Showing
2 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Provides a grpc client which: | ||
- Sets client-side load-balancing options | ||
- Polls DNS and triggers channel re-connection when new endpoints are detected | ||
""" | ||
# Standard | ||
from threading import RLock | ||
from typing import Generic, List, Optional, Set, Tuple, Type, TypeVar | ||
import socket | ||
import threading | ||
|
||
# Third Party | ||
import grpc | ||
|
||
# First Party | ||
from caikit.core.exceptions import error_handler | ||
import alog | ||
|
||
# Type variable standing in for the grpc client class wrapped by GRPCLoadBalancer
T = TypeVar("T") | ||
|
||
# Module log channel, and the error handler obtained from error_handler for that channel
log = alog.use_channel("TGCONN") | ||
error = error_handler.get(log) | ||
|
||
|
||
# pylint: disable=too-many-instance-attributes | ||
class GRPCLoadBalancer(Generic[T]):
    """Wraps a grpc client class T, rebuilding the client when new IPs are available.

    On construction, and then every ``poll_interval_s`` seconds on a daemon
    timer thread, the target host is resolved with ``socket.getaddrinfo``.
    Whenever the resolved set contains an address that was not present in the
    previous poll, a new channel and a new client are created so that the new
    endpoint is picked up.
    """

    def __init__(
        self,
        client_class: Type[T],
        target: str,
        policy: str = "round_robin",
        poll_interval_s: float = 10,
        credentials: Optional[grpc.ChannelCredentials] = None,
        channel_options: Optional[List[Tuple[str, str]]] = None,
    ):
        """Build the wrapper and start polling DNS.

        Args:
            client_class: Callable invoked with a grpc channel to build a client.
            target: Service address in ``{host}:{port}`` format.
            policy: Value set for the ``grpc.lb_policy_name`` channel option.
            poll_interval_s: Seconds between DNS polls.
            credentials: When truthy, a secure channel is created with these
                credentials; otherwise an insecure channel is used.
            channel_options: Extra grpc channel options. The list is copied and
                never modified.

        Raises:
            ValueError: If ``target`` does not contain exactly one ``:``.
        """
        error.value_check(
            "<TGB54435438E>",
            target.count(":") == 1,
            "Target must be provided in {host}:{port} format",
        )
        self.client_class = client_class
        self.target = target

        # Copy before appending: the caller's list must not be mutated, or a
        # reused options list would accumulate lb-policy entries.
        self.options = list(channel_options or [])
        self.options.append(("grpc.lb_policy_name", policy))
        self.credentials = credentials
        self._client = None
        self._client_lock = RLock()

        # Starts empty so that the first successful poll always builds a client
        self._ip_set: Set[Tuple[str, int]] = set()

        self.poll_interval = poll_interval_s
        self._timer: Optional[threading.Timer] = None
        self._poll_for_ips()

    def __del__(self):
        # getattr guards against a partially-constructed instance (e.g. when
        # target validation raised inside __init__)
        timer = getattr(self, "_timer", None)
        if timer is not None:
            timer.cancel()

    def get_client(self) -> T:
        """Returns the client. The result should not be cached as the client will be rebuilt
        periodically"""
        with self._client_lock:
            if self._client is None:
                self._rebuild_client()
            return self._client

    def _poll_for_ips(self):
        """Resolve the target, rebuild the client if new addresses appeared, and
        schedule the next poll. Failures are logged and never propagate."""
        try:
            log.debug3("Polling DNS for updates to service: %s", self.target)
            new_ip_set = self._get_ip_set()

            # Create a new client only if new IP/port pairs are found. Addresses
            # that merely disappeared do not trigger a rebuild.
            if new_ip_set - self._ip_set:
                self._rebuild_client()

            self._ip_set = new_ip_set
        except Exception:  # pylint: disable=broad-exception-caught
            log.warning("Failed to poll DNS for updates", exc_info=True)

        # Cancel any duplicate timers (cancel() is a no-op on a finished timer)
        if self._timer is not None:
            self._timer.cancel()

        self._schedule_next_poll()

    def _schedule_next_poll(self):
        """Start a daemon timer that runs the next poll after ``poll_interval`` seconds.

        The timer callback holds only a weak reference to this instance. A
        bound method would keep the instance alive for as long as a timer is
        pending, which is always, so it could never be garbage collected and
        ``__del__`` could never stop the polling.
        """
        # Standard
        import weakref  # pylint: disable=import-outside-toplevel

        self_ref = weakref.ref(self)

        def _poll_if_alive():
            instance = self_ref()
            if instance is not None:
                instance._poll_for_ips()  # pylint: disable=protected-access

        log.debug3("Scheduling next DNS poll in %s seconds", self.poll_interval)
        self._timer = threading.Timer(self.poll_interval, _poll_if_alive)
        self._timer.daemon = True
        self._timer.start()

    def _rebuild_client(self):
        """Create a fresh channel to the target and swap in a client built on it.

        The previous channel is not explicitly closed here, so callers still
        holding the old client can keep using it.
        """
        log.debug3("Rebuilding client for service: %s", self.target)
        if self.credentials:
            channel = grpc.secure_channel(
                target=self.target, credentials=self.credentials, options=self.options
            )
        else:
            channel = grpc.insecure_channel(target=self.target, options=self.options)
        with self._client_lock:
            self._client = self.client_class(channel)

    def _get_ip_set(self) -> Set[Tuple[str, int]]:
        """Return the set of socket addresses ``getaddrinfo`` reports for the target."""
        host, port = self.target.split(":")
        addr_infos = socket.getaddrinfo(host, port)
        # Element 4 of each getaddrinfo record is the socket address tuple
        ip_set = {info[4] for info in addr_infos}
        log.debug3("IPs for target: %s, %s", self.target, ip_set)
        return ip_set
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Test the load-balancing client wrapper""" | ||
import datetime | ||
# Standard | ||
from concurrent import futures | ||
from socket import AddressFamily, SocketKind | ||
from typing import List | ||
from unittest import mock | ||
import contextlib | ||
|
||
# Third Party | ||
import grpc | ||
import pytest | ||
|
||
# Local | ||
from caikit_tgis_backend.load_balancing_client import GRPCLoadBalancer | ||
from caikit_tgis_backend.protobufs import generation_pb2, generation_pb2_grpc | ||
|
||
|
||
# 🌶️🌶️🌶️ These tests don't actually flex the real grpc load balancing between remotes. | ||
# It may be possible to run a local DNS server during testing, but it seems very difficult | ||
# to spin up multiple servers on localhost and somehow return DNS records that mimic what | ||
# kubedns does while still routing all traffic back to the local mocks. | ||
|
||
class TGISTestServer(generation_pb2_grpc.GenerationServiceServicer):
    """Minimal servicer that answers every Tokenize call with one canned response"""

    def Tokenize(self, request, context):
        """Return a fixed five-token response regardless of the request"""
        canned_tokens = ["hello ", "world ", "I ", "am ", "Zod."]
        canned_response = generation_pb2.TokenizeResponse(
            token_count=5, tokens=canned_tokens
        )
        return generation_pb2.BatchedTokenizeResponse(responses=[canned_response])
|
||
|
||
@contextlib.contextmanager
def mock_tgis_server(port):
    """Run a TGISTestServer on the given local port for the duration of the context.

    The server is stopped immediately (no grace period) when the context
    exits, including when the body raises, so a failing test does not leave
    the port bound for the tests that follow.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    generation_pb2_grpc.add_GenerationServiceServicer_to_server(
        TGISTestServer(), server
    )
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    try:
        yield
    finally:
        server.stop(grace=0)
|
||
|
||
@contextlib.contextmanager
def mock_ip_set(ports: List[int]):
    """Patch ``socket.getaddrinfo`` to report 127.0.0.1 on each of the given ports.

    For every port, three records are returned (stream, datagram and raw
    socket kinds), all with the same ("127.0.0.1", port) address.
    """
    # (socket kind, protocol number) for the three records emitted per port
    socket_flavors = (
        (SocketKind.SOCK_STREAM, 6),
        (SocketKind.SOCK_DGRAM, 17),
        (SocketKind.SOCK_RAW, 0),
    )
    fake_records = [
        (AddressFamily.AF_INET, kind, proto, "", ("127.0.0.1", port))
        for port in ports
        for kind, proto in socket_flavors
    ]
    with mock.patch("socket.getaddrinfo") as getaddrinfo_mock:
        getaddrinfo_mock.return_value = fake_records
        yield
|
||
|
||
def test_client_works():
    """Smoke test: a client obtained from the wrapper can call the mock server"""
    with mock_tgis_server(9000):
        balancer = GRPCLoadBalancer(
            client_class=generation_pb2_grpc.GenerationServiceStub,
            target="localhost:9000",
        )
        stub = balancer.get_client()

        reply = stub.Tokenize(request=generation_pb2.BatchedTokenizeRequest())
        assert reply.responses[0].token_count == 5
|
||
|
||
def test_target_validation():
    """Anything other than a host:port target is rejected with a ValueError"""
    bad_targets = (
        "localhost",
        "9001",
        # NB: dns targets not supported
        "dns://foo.bar/localhost:9001",
    )
    for bad_target in bad_targets:
        with pytest.raises(ValueError, match="Target must be provided in .* format"):
            GRPCLoadBalancer(
                client_class=generation_pb2_grpc.GenerationServiceStub,
                target=bad_target,
            )
|
||
|
||
def test_client_rebuilds_on_ip_change():
    """When a new pod joins the target service, the grpc load balancer has nothing prompting it
    to re-query DNS. Swapping in a new client on a new channel is what picks up the new pod."""

    poll_interval = 0.0001  # 0.1 ms
    with mock_tgis_server(9000):
        with mock_ip_set([8080]):
            balancer = GRPCLoadBalancer(
                client_class=generation_pb2_grpc.GenerationServiceStub,
                target="localhost:9000",
                poll_interval_s=poll_interval,
            )
            original_client = balancer.get_client()
            original_client.Tokenize(request=generation_pb2.BatchedTokenizeRequest())

        with mock_ip_set([8080, 9090]):
            # Spin until the background poll swaps the client, failing after 100 ms
            deadline = datetime.datetime.now() + datetime.timedelta(milliseconds=100)
            while balancer.get_client() is original_client:
                assert datetime.datetime.now() < deadline, "Client did not update"

        # new client still works
        replacement_client = balancer.get_client()
        replacement_client.Tokenize(request=generation_pb2.BatchedTokenizeRequest())
|
||
|
||
def test_client_does_not_rebuild_when_ips_do_not_change():
    """An unchanged address set must not cause client churn"""
    with mock_tgis_server(9000), mock_ip_set([8080, 9090]):
        balancer = GRPCLoadBalancer(
            client_class=generation_pb2_grpc.GenerationServiceStub,
            target="localhost:9000",
        )
        first_client = balancer.get_client()
        first_client.Tokenize(request=generation_pb2.BatchedTokenizeRequest())

        # Force poll which would update the client
        balancer._poll_for_ips()
        assert balancer.get_client() is first_client
|
||
|
||
def test_client_does_not_rebuild_when_ips_drop_out():
    """A pod leaving the target service is no reason to rebuild the client: the grpc load
    balancing policy should close the sub-channel and re-query DNS anyway."""
    poll_interval = 0.0001  # 0.1 ms
    with mock_tgis_server(9000):
        with mock_ip_set([8080, 9090]):
            balancer = GRPCLoadBalancer(
                client_class=generation_pb2_grpc.GenerationServiceStub,
                target="localhost:9000",
                poll_interval_s=poll_interval,
            )
            first_client = balancer.get_client()
            first_client.Tokenize(request=generation_pb2.BatchedTokenizeRequest())

        with mock_ip_set([8080]):
            # Force poll which would update the client
            balancer._poll_for_ips()
            assert balancer.get_client() is first_client