Skip to content

Commit

Permalink
Merge pull request #34 from joerunde/load-balancing
Browse files Browse the repository at this point in the history
✨ add LB config to tgis connection
  • Loading branch information
gabe-l-hart authored Oct 5, 2023
2 parents 8e41e0c + 47863aa commit ba4dcee
Show file tree
Hide file tree
Showing 4 changed files with 566 additions and 4 deletions.
191 changes: 191 additions & 0 deletions caikit_tgis_backend/load_balancing_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides a grpc client which:
- Sets client-side load-balancing options
- Polls DNS and triggers channel re-connection when new endpoints are detected
"""
# Standard
from functools import partial
from typing import Generic, List, Optional, Set, Tuple, Type, TypeVar
import socket
import threading

# Third Party
import grpc

# First Party
from caikit.core.exceptions import error_handler
import alog

T = TypeVar("T")

log = alog.use_channel("TGCONN")
error = error_handler.get(log)


# pylint: disable=too-many-instance-attributes
class GRPCLoadBalancerProxy(Generic[T]):
"""Proxies a grpc client class T, reconnecting the client when new IPs are available"""

def __init__(
self,
client_class: Type[T],
target: str,
policy: str = "round_robin",
poll_interval_s: Optional[float] = 10,
credentials: Optional[str] = None,
channel_options: Optional[List[Tuple[str, str]]] = None,
):
# Ensure that self._client always exists. It is required by the __getattr__ proxying
self._client = None
self.client_class = client_class
self.target = target

error.value_check(
"<TGB54435438E>",
target.count(":") == 1,
"Target must be provided in {host}:{port} format",
)

error.value_check(
"<TGB01133969E>",
poll_interval_s is None or poll_interval_s >= 0,
"poll_interval_s should be > 0",
)

channel_options = channel_options or []
# pylint: disable=line-too-long
# Cite: https://grpc.github.io/grpc/core/group__grpc__arg__keys.html#ga72c2b475e218ecfd36bb7d3551d0295b
channel_options.append(("grpc.lb_policy_name", policy))

# Save a partial for re-constructing channels later
if credentials:
log.debug3("Creating load-balancing client with secure channel")
self.channel_partial = partial(
grpc.secure_channel,
target=self.target,
options=channel_options,
credentials=credentials,
)
else:
log.debug3("Creating load-balancing client with insecure channel")
self.channel_partial = partial(
grpc.insecure_channel, target=self.target, options=channel_options
)

# Build the client once
self._client = self.client_class(self.channel_partial())

# Get initial IP set
self._ip_set: Set[Tuple[str, int]] = set()

self.poll_interval = poll_interval_s
self._timer: Optional[threading.Timer] = None
self._poll_lock = threading.Lock()
self._shutdown = False
if self.poll_interval:
log.debug2(
"Enabling DNS poll interval every %f seconds", self.poll_interval
)
self._dns_poll()

def __del__(self):
"""Attempt a bit of cleanup on GC"""
self.shutdown_dns_poll()

def __getattr__(self, item):
"""Proxies self._client so that self is the grpc client"""
return getattr(self._client, item)

@property
def client(self) -> T:
"""Syntactic sugar to assert that we are in fact a type T.
Returns the client instance (self). The channel that this client holds will periodically be
replaced when DNS polling indicates new hosts are available."""
return self

def shutdown_dns_poll(self):
"""Shuts down the internal DNS poll.
This should happen on garbage collection, and is exposed here to explicitly control the
polling lifecycle if needed."""
self._shutdown = True
if (
hasattr(self, "_timer")
and self._timer is not None
and self._timer.is_alive()
):
self._timer.cancel()

def _dns_poll(self):
"""Run the internal DNS poll. This method re-schedules itself until shutdown_dns_poll
is called."""
if self._shutdown:
return
# Lock for both _ip_set and _timer
with self._poll_lock:
try:
log.debug3("Polling DNS for updates to service: %s", self.target)

new_ip_set = self._get_ip_set()

# Create a new client only if new IP/port pairs are found
if new_ip_set - self._ip_set:
self._reconnect()

self._ip_set = new_ip_set
except (socket.gaierror, socket.herror):
log.warning("Failed to poll DNS for updates", exc_info=True)

except Exception as ex: # pylint: disable=broad-exception-caught
log.warning(
"<TGB58023131W>",
"Unhandled exception caught during polling DNS for updates: %s",
ex,
exc_info=True,
)

# Cancel any duplicate timers
if self._timer is not None and self._timer.is_alive():
self._timer.cancel()

# Schedule next poll
log.debug3("Scheduling next DNS poll in %s seconds", self.poll_interval)
self._timer = threading.Timer(self.poll_interval, self._dns_poll)
self._timer.daemon = True
self._timer.start()

def _reconnect(self):
"""Force-reconnect the client by re-invoking the initializer with a new channel"""
log.debug3("Reconnecting channel for service: %s", self.target)
# 🌶️🌶️🌶️ We don't want to rebuild a new client, since that would require that all users
# update any client references that they're holding.
# This __init__ call re-initializes the client instance that many things may be holding.
# This should be safe since the grpc client classes are "dumb" wrappers around channels.
# pylint: disable=unnecessary-dunder-call
self.client_class.__init__(self=self._client, channel=self.channel_partial())

def _get_ip_set(self) -> Set[Tuple[str, int]]:
"""Uses `socket` to attempt a DNS lookup.
Returns a set of (ip address, port) tuples that self.target resolves to
"""
host, port = self.target.split(":")
hosts = socket.getaddrinfo(host, port)
# socket.getaddrinfo returns a tuple containing information
# about socket, where 4th index contains sockaddr containing
# ip address and port
ip_set = {host[4] for host in hosts}
log.debug3("IPs for target: %s, %s", self.target, ip_set)
return ip_set
47 changes: 44 additions & 3 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import alog

# Local
from .load_balancing_proxy import GRPCLoadBalancerProxy
from .protobufs import generation_pb2, generation_pb2_grpc

log = alog.use_channel("TGCONN")
Expand All @@ -39,6 +40,7 @@ class TLSFilePair:
key_file: str


# pylint: disable=too-many-instance-attributes
@dataclass
class TGISConnection:

Expand All @@ -56,6 +58,10 @@ class TGISConnection:
client_tls: Optional[TLSFilePair] = None
# Mounted directory where TGIS will look for prompt vector artifacts
prompt_dir: Optional[str] = None
# Load balancing policy
lb_policy: Optional[str] = None
# DNS poll interval (seconds) for LB updates
lb_poll_interval_s: Optional[float] = None
# Private member to hold the client once created
_client: Optional[generation_pb2_grpc.GenerationServiceStub] = None

Expand All @@ -69,6 +75,8 @@ class TGISConnection:
CLIENT_CERT_FILE_KEY = "client_cert_file"
CLIENT_KEY_FILE_KEY = "client_key_file"
PROMPT_DIR_KEY = "prompt_dir"
LB_POLICY_KEY = "grpc_lb_policy_name"
LB_POLL_INTERVAL_KEY = "grpc_lb_poll_interval_s"

@classmethod
def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
Expand All @@ -82,6 +90,23 @@ def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
)
log.debug("Resolved hostname [%s] for model %s", hostname, model_id)

lb_policy = config.get(cls.LB_POLICY_KEY) or None
error.type_check(
"<TGB17223790E>",
str,
allow_none=True,
**{cls.LB_POLICY_KEY: lb_policy},
)

lb_poll_interval_s = config.get(cls.LB_POLL_INTERVAL_KEY) or None
error.type_check(
"<TGB17223790E>",
float,
int,
allow_none=True,
**{cls.LB_POLL_INTERVAL_KEY: lb_poll_interval_s},
)

# Look for the prompt dir
prompt_dir = config.get(cls.PROMPT_DIR_KEY) or None
error.type_check(
Expand Down Expand Up @@ -153,6 +178,8 @@ def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
ca_cert_file=ca_cert,
client_tls=client_tls,
prompt_dir=prompt_dir,
lb_policy=lb_policy,
lb_poll_interval_s=lb_poll_interval_s,
)

@property
Expand Down Expand Up @@ -227,6 +254,7 @@ def unload_prompt_artifacts(self, *prompt_ids: List[str]):

def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
"""Get a grpc client for the connection"""
load_balancer_kwargs = {}
if self._client is None:
log.info(
"<TGB20236231I>",
Expand All @@ -237,7 +265,6 @@ def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
)
if not self.tls_enabled:
log.debug("Connecting to TGIS at [%s] INSECURE", self.hostname)
channel = grpc.insecure_channel(self.hostname)
else:
log.debug("Connecting to TGIS at [%s] SECURE", self.hostname)
creds_kwargs = {
Expand All @@ -252,8 +279,22 @@ def get_client(self) -> generation_pb2_grpc.GenerationServiceStub:
self.client_tls.key_file
)
credentials = grpc.ssl_channel_credentials(**creds_kwargs)
channel = grpc.secure_channel(self.hostname, credentials=credentials)
self._client = generation_pb2_grpc.GenerationServiceStub(channel)
load_balancer_kwargs["credentials"] = credentials

# Override the lb policy, if configured.
# The load balancer will always provide a default if we don't pass one.
if self.lb_policy:
load_balancer_kwargs["policy"] = self.lb_policy

if self.lb_poll_interval_s:
load_balancer_kwargs["poll_interval_s"] = self.lb_poll_interval_s

load_balancer = GRPCLoadBalancerProxy(
client_class=generation_pb2_grpc.GenerationServiceStub,
target=self.hostname,
**load_balancer_kwargs,
)
self._client = load_balancer.client
return self._client

def test_connection(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit>=0.16.0,<0.20.0", # Core abstractions
"caikit>=0.16.0,<0.21.0", # Core abstractions
"grpcio>=1.35.0,<2.0", # Client calls to TGIS
"requests>=2.28.2,<3", # Health check calls to TGIS
]
Expand Down
Loading

0 comments on commit ba4dcee

Please sign in to comment.