Skip to content

Commit

Permalink
Merge pull request #55 from mynhardtburger/register_connection
Browse files Browse the repository at this point in the history
Add `TGISBackend.register_model_connection()` method
  • Loading branch information
gabe-l-hart authored Jun 7, 2024
2 parents 1214b0b + b87c958 commit ddbccd3
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 47 deletions.
128 changes: 95 additions & 33 deletions caikit_tgis_backend/tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements a TGIS backend configuration
"""
"""This module implements a TGIS backend configuration"""

# Standard
from copy import deepcopy
from threading import Lock
from typing import Dict, Optional
from typing import Any, Dict, Optional

# Third Party
import grpc
Expand All @@ -35,6 +35,7 @@
log = alog.use_channel("TGISBKND")
error = error_handler.get(log)


# pylint: disable=too-many-instance-attributes
class TGISBackend(BackendBase):
"""Caikit backend with a connection to the TGIS server. If no connection
Expand Down Expand Up @@ -67,15 +68,17 @@ def __init__(self, config: Optional[dict] = None):
self._mutex = Lock()
self._local_tgis = None
self._managed_tgis = None
self._model_connections = {}
self._model_connections: Dict[str, TGISConnection] = {}
self._test_connections = self.config.get("test_connections", False)
self._connect_timeout = self.config.get("connect_timeout", None)

# Parse the config to see if we're managing a connection to a remote
# TGIS instance or running a local copy
connection_cfg = self.config.get("connection") or {}
error.type_check("<TGB20235229E>", dict, connection=connection_cfg)
self._remote_models_cfg = self.config.get("remote_models") or {}
self._remote_models_cfg: Dict[str, dict] = (
self.config.get("remote_models") or {}
)
error.type_check("<TGB20235338E>", dict, connection=self._remote_models_cfg)
local_cfg = self.config.get("local") or {}
error.type_check("<TGB20235225E>", dict, local=local_cfg)
Expand Down Expand Up @@ -114,19 +117,9 @@ def __init__(self, config: Optional[dict] = None):
model_id,
)
if self._test_connections:
try:
model_conn.test_connection(timeout=self._connect_timeout)
except grpc.RpcError as err:
log.warning(
"<TGB95244222W>",
"Unable to connect to model %s: %s",
model_id,
err,
exc_info=True,
)
model_conn = None
model_conn = self._test_connection(model_conn, self._connect_timeout)
if model_conn is not None:
self._model_connections[model_id] = model_conn
self._safely_update_state(model_id, model_conn)

# We manage a local TGIS instance if there are no remote connections
# specified as either a valid base connection or remote_connections
Expand Down Expand Up @@ -182,25 +175,52 @@ def get_connection(
if not model_conn and create and not self.local_tgis and conn_cfg:
model_conn = TGISConnection.from_config(model_id, conn_cfg)
if self._test_connections:
try:
model_conn.test_connection()
except grpc.RpcError as err:
log.warning(
"<TGB50048960W>",
"Unable to connect to model %s: %s",
model_id,
err,
exc_info=True,
)
model_conn = None
model_conn = self._test_connection(model_conn)
if model_conn is not None:
# NOTE: setdefault used here to avoid the need to hold the mutex
# when running the connection test. It's possible that two
# threads would stimulate the creation of the connection
# concurrently, so just keep whichever dict update lands first
self._model_connections.setdefault(model_id, model_conn)
self._safely_update_state(model_id, model_conn)

return model_conn

def register_model_connection(
    self,
    model_id: str,
    conn_cfg: Optional[Dict[str, Any]] = None,
    fill_with_defaults: bool = True,
) -> None:
    """Register a remote model connection for ``model_id``.

    If a connection for ``model_id`` is already registered, do nothing.
    Otherwise create and register the model connection using the
    TGISBackend's base connection config, or ``conn_cfg`` if provided.

    Args:
        model_id: Identifier of the model the connection is registered
            under.
        conn_cfg: Optional connection config for this model. When
            ``None``, a copy of the backend's base connection config is
            used regardless of ``fill_with_defaults``.
        fill_with_defaults: When ``True``, top-level keys missing from
            ``conn_cfg`` are populated from the backend's base connection
            config. Keys present in ``conn_cfg`` always win.

    Raises:
        Whatever ``error.value_check`` raises when no connection could be
        created from the resolved config (presumably ``ValueError`` --
        confirm against the error handler).
    """
    if model_id in self._model_connections:
        return  # Model connection exists --> do nothing

    # Craft the new connection config out of private copies so that neither
    # the backend's base config nor the caller's dict is aliased by the
    # config that ends up stored in the backend's state.
    use_defaults = conn_cfg is None or fill_with_defaults
    new_conn_cfg = deepcopy(self._base_connection_cfg) if use_defaults else {}
    if conn_cfg is not None:
        new_conn_cfg.update(deepcopy(conn_cfg))

    # Create model connection
    model_conn = TGISConnection.from_config(model_id, new_conn_cfg)
    error.value_check(
        "<TGB81270235E>",
        model_conn is not None,
        "Unable to create a model connection from the resolved connection config",
    )

    # Register model connection. A failed connection test yields None, in
    # which case nothing is registered.
    if self._test_connections:
        model_conn = self._test_connection(model_conn)
    if model_conn is not None:
        self._safely_update_state(model_id, model_conn, new_conn_cfg)

def get_client(self, model_id: str) -> generation_pb2_grpc.GenerationServiceStub:
model_conn = self.get_connection(model_id)
if model_conn is None and self.local_tgis:
Expand Down Expand Up @@ -269,6 +289,48 @@ def model_loaded(self) -> bool:
self._managed_tgis is not None and self._managed_tgis.is_ready()
)

def _test_connection(
    self, model_conn: Optional[TGISConnection], timeout: Optional[float] = None
) -> Optional[TGISConnection]:
    """Probe a model connection and hand it back only if the probe succeeds.

    Args:
        model_conn: The connection to test. ``None`` is passed straight
            through as ``None``.
        timeout: Forwarded to ``model_conn.test_connection``.

    Returns:
        ``model_conn`` when the test call completes without raising
        ``grpc.RpcError``; otherwise ``None`` (the failure is logged as a
        warning rather than propagated).
    """
    # Nothing to test
    if model_conn is None:
        return None

    try:
        model_conn.test_connection(timeout)
    except grpc.RpcError as err:
        log.warning(
            "<TGB10601575W>",
            "Unable to connect to model %s: %s",
            model_conn.model_id,
            err,
            exc_info=True,
        )
        return None

    return model_conn

def _safely_update_state(
    self,
    model_id: str,
    model_connections: Optional[TGISConnection] = None,
    remote_models_cfg: Optional[Dict[str, Any]] = None,
):
    """Record a model's connection and/or connection config in the backend's
    `_model_connections` and `_remote_models_cfg` state dictionaries.

    Existing entries for ``model_id`` are never overwritten: whichever
    value was stored first is kept.

    Args:
        model_id: Key under which the values are stored.
        model_connections: Connection to store for ``model_id``; skipped
            when ``None``.
        remote_models_cfg: Connection config to store for ``model_id``;
            skipped when ``None``.
    """
    # NOTE: setdefault used here to avoid the need to hold the mutex
    #   when running the connection test. It's possible that two
    #   threads would stimulate the creation of the connection
    #   concurrently, so just keep whichever dict update lands first
    #
    # Explicit None checks (rather than truthiness) so that a value which
    # was provided but happens to be falsy is still recorded.
    if model_connections is not None:
        self._model_connections.setdefault(model_id, model_connections)
    if remote_models_cfg is not None:
        self._remote_models_cfg.setdefault(model_id, remote_models_cfg)


# Register local backend
register_backend_type(TGISBackend)
13 changes: 8 additions & 5 deletions caikit_tgis_backend/tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Container
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
import os
import shutil

Expand Down Expand Up @@ -90,19 +90,22 @@ class TGISConnection:
TLS_HN_OVERRIDE_KEY = "hostname_override"

@classmethod
def from_config(cls, model_id: str, config: dict) -> Optional["TGISConnection"]:
def from_config(
cls, model_id: str, config: Dict[str, Any]
) -> Optional["TGISConnection"]:
"""Create an instance from a connection template and a model_id"""
hostname = config.get(cls.HOSTNAME_KEY)
if hostname:
hostname = hostname.format(
**{
error.type_check("<TGB57775870E>", str, hostname=hostname)

hostname = hostname.format_map(
{
cls.HOSTNAME_TEMPLATE_MODEL_ID: model_id,
}
)
log.debug("Resolved hostname [%s] for model %s", hostname, model_id)

tls_hostname_override = config.get(cls.TLS_HN_OVERRIDE_KEY)

lb_policy = config.get(cls.LB_POLICY_KEY) or None
error.type_check(
"<TGB17223790E>",
Expand Down
106 changes: 98 additions & 8 deletions tests/test_tgis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
"""

# Standard
from copy import deepcopy
from dataclasses import asdict
from typing import Any, Dict, Optional
from unittest import mock
import os
import tempfile
Expand All @@ -32,16 +35,16 @@
# Local
from caikit_tgis_backend import TGISBackend
from caikit_tgis_backend.protobufs import generation_pb2
from tests.tgis_mock import (
TGISMock,
tgis_mock_insecure,
tgis_mock_insecure_health_delay,
tgis_mock_mtls,
tgis_mock_tls,
)
from caikit_tgis_backend.tgis_connection import TGISConnection
from tests.tgis_mock import tgis_mock_insecure # noqa
from tests.tgis_mock import tgis_mock_insecure_health_delay # noqa
from tests.tgis_mock import tgis_mock_mtls # noqa
from tests.tgis_mock import tgis_mock_tls # noqa
from tests.tgis_mock import TGISMock

## Helpers #####################################################################


# for convenience in managing the multiple parts of the fixture
class MockTGISFixture:
def __init__(
Expand Down Expand Up @@ -575,7 +578,6 @@ def test_tgis_backend_config_load_prompt_artifacts():
"""Make sure that loading prompt artifacts behaves as expected"""
with tempfile.TemporaryDirectory() as source_dir:
with tempfile.TemporaryDirectory() as prompt_dir:

# Make some source files
source_fnames = ["prompt1.pt", "prompt2.pt"]
source_files = [os.path.join(source_dir, fname) for fname in source_fnames]
Expand Down Expand Up @@ -681,6 +683,94 @@ def test_tgis_backend_config_load_prompt_artifacts():
tgis_be.load_prompt_artifacts("buz", prompt_id1, source_files[0])


@pytest.mark.parametrize(
    argnames=["model_id", "conn_cfg", "fill", "expected_conn_cfg"],
    argvalues=[
        # No explicit config, no fill -> base connection config is used
        (
            "model1",
            None,
            False,
            {
                "hostname": "localhost:1234",
                "model_id": "model1",
                "lb_policy": "abc",
            },
        ),
        # No explicit config, fill -> base connection config is used
        (
            "model1",
            None,
            True,
            {
                "hostname": "localhost:1234",
                "model_id": "model1",
                "lb_policy": "abc",
            },
        ),
        # Explicit config, no fill -> only the explicit keys are used
        (
            "model1",
            {"hostname": "myhost"},
            False,
            {"hostname": "myhost", "model_id": "model1"},
        ),
        # Explicit config, fill -> missing keys come from the base config
        (
            "model1",
            {"hostname": "myhost"},
            True,
            {"hostname": "myhost", "model_id": "model1", "lb_policy": "abc"},
        ),
    ],
)
def test_tgis_backend_register_model_connection(
    model_id: str,
    conn_cfg: Optional[dict],
    fill: bool,
    expected_conn_cfg: Dict[str, Any],
):
    """Test that register_model_connection correctly adds a TGISConnection to
    the _model_connections dictionary, leaves an existing registration
    untouched, and does not mutate the backend's base connection config
    """

    def non_null_fields(connection: TGISConnection) -> Dict[str, Any]:
        """Dataclass fields of the connection that are not None"""
        return {k: v for k, v in asdict(connection).items() if v is not None}

    tgis_be = TGISBackend(
        {
            "connection": {"hostname": "localhost:1234", "grpc_lb_policy_name": "abc"},
            "remote_models": {},
        }
    )

    # Assert new model is not in backend
    assert model_id not in tgis_be._remote_models_cfg
    assert model_id not in tgis_be._model_connections
    backup_base_cfg = deepcopy(tgis_be._base_connection_cfg)

    # Register model
    tgis_be.register_model_connection(model_id, conn_cfg, fill_with_defaults=fill)
    assert model_id in tgis_be._remote_models_cfg
    assert model_id in tgis_be._model_connections
    assert isinstance(tgis_be._model_connections[model_id], TGISConnection)
    assert non_null_fields(tgis_be._model_connections[model_id]) == expected_conn_cfg

    # Re-register -> no change to existing model
    tgis_be.register_model_connection(model_id, {"hostname": "{model_id}.mycluster"})
    assert non_null_fields(tgis_be._model_connections[model_id]) == expected_conn_cfg

    # Confirm get_connection works. Check the connection that was actually
    # returned rather than re-reading the backend's internal dict.
    conn = tgis_be.get_connection(model_id, create=False)
    assert isinstance(conn, TGISConnection)
    assert non_null_fields(conn) == expected_conn_cfg

    # Confirm that the source _base_connection_cfg wasn't mutated
    assert tgis_be._base_connection_cfg == backup_base_cfg


## Failure Tests ###############################################################


Expand Down
3 changes: 2 additions & 1 deletion tests/test_tgis_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
Unit tests for the TGISConnection class
"""

# Standard
from contextlib import contextmanager
from pathlib import Path
Expand All @@ -27,7 +28,7 @@

# Local
from caikit_tgis_backend.tgis_connection import TGISConnection
from tests.tgis_mock import tgis_mock_insecure
from tests.tgis_mock import tgis_mock_insecure # noqa


@contextmanager
Expand Down

0 comments on commit ddbccd3

Please sign in to comment.