Skip to content

feat: Add Resource to tracer provider in Agent Engine templates #5357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from typing import Dict, Iterable, List, MutableSequence, Optional
from unittest import mock

from google.api_core import operation as ga_operation
import vertexai
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1 import types as types_v1
from google.cloud.aiplatform_v1.services import (
prediction_service as prediction_service_v1,
)
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
from google.cloud.aiplatform_v1beta1.services import endpoint_service
from vertexai import generative_models
from vertexai.preview import (
generative_models as preview_generative_models,
Expand All @@ -48,6 +50,7 @@
)
from vertexai.generative_models import _function_calling_utils
from vertexai.caching import CachedContent
from google.protobuf import field_mask_pb2


_TEST_PROJECT = "test-project"
Expand Down Expand Up @@ -1710,6 +1713,115 @@ def test_defs_ref_renaming(self):
_fix_schema_dict_for_gapic_in_place(actual)
assert actual == expected

@pytest.mark.parametrize(
    "generative_models",
    [preview_generative_models],  # Only preview supports set_logging_config
)
@mock.patch.object(endpoint_service.EndpointServiceClient, "update_endpoint")
def test_set_logging_config_for_endpoint(
    self, mock_update_endpoint, generative_models: generative_models
):
    """Checks that configuring request/response logging on a dedicated
    endpoint issues a single UpdateEndpoint call carrying the logging
    config and a field mask restricted to that one field."""
    endpoint_name = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/12345"
    )
    mock_update_endpoint.return_value = types_v1beta1.Endpoint(name=endpoint_name)
    model = generative_models.GenerativeModel(endpoint_name)

    # Exercise the public API with a representative configuration.
    logging_kwargs = dict(
        enabled=True,
        sampling_rate=0.5,
        bigquery_destination=f"bq://{_TEST_PROJECT}.my_dataset.my_table",
        enable_otel_logging=True,
    )
    model.set_request_response_logging_config(**logging_kwargs)

    # The request must wrap the raw BQ URI in a BigQueryDestination and
    # only touch predict_request_response_logging_config on the endpoint.
    expected_config = types_v1beta1.PredictRequestResponseLoggingConfig(
        enabled=logging_kwargs["enabled"],
        sampling_rate=logging_kwargs["sampling_rate"],
        bigquery_destination=types_v1beta1.BigQueryDestination(
            output_uri=logging_kwargs["bigquery_destination"]
        ),
        enable_otel_logging=logging_kwargs["enable_otel_logging"],
    )
    mock_update_endpoint.assert_called_once_with(
        types_v1beta1.UpdateEndpointRequest(
            endpoint=types_v1beta1.Endpoint(
                name=endpoint_name,
                predict_request_response_logging_config=expected_config,
            ),
            update_mask=field_mask_pb2.FieldMask(
                paths=["predict_request_response_logging_config"]
            ),
        )
    )

@pytest.mark.parametrize(
    "generative_models",
    [preview_generative_models],  # Only preview supports set_logging_config
)
@mock.patch.object(
    endpoint_service.EndpointServiceClient, "set_publisher_model_config"
)
def test_set_logging_config_for_publisher_model(
    self, mock_set_publisher_model_config, generative_models: generative_models
):
    """Checks that configuring request/response logging on a publisher
    model issues SetPublisherModelConfig against the fully-qualified
    model resource and waits on the returned long-running operation."""
    model_name = "gemini-pro"
    full_model_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/{model_name}"
    model = generative_models.GenerativeModel(model_name)

    config_kwargs = dict(
        enabled=False,
        sampling_rate=1.0,
        bigquery_destination=f"bq://{_TEST_PROJECT}.another_dataset",
        enable_otel_logging=False,
    )
    expected_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig(
        enabled=config_kwargs["enabled"],
        sampling_rate=config_kwargs["sampling_rate"],
        bigquery_destination=types_v1beta1.BigQueryDestination(
            output_uri=config_kwargs["bigquery_destination"]
        ),
        enable_otel_logging=config_kwargs["enable_otel_logging"],
    )

    # The RPC returns an LRO; the SDK is expected to block on .result().
    lro = mock.Mock(spec=ga_operation.Operation)
    lro.result.return_value = types_v1beta1.PublisherModelConfig(
        logging_config=expected_logging_config
    )
    mock_set_publisher_model_config.return_value = lro

    model.set_request_response_logging_config(**config_kwargs)

    mock_set_publisher_model_config.assert_called_once_with(
        types_v1beta1.SetPublisherModelConfigRequest(
            name=full_model_name,
            publisher_model_config=types_v1beta1.PublisherModelConfig(
                logging_config=expected_logging_config
            ),
        )
    )
    lro.result.assert_called_once()


EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
"title": "get_current_weather",
Expand Down
14 changes: 14 additions & 0 deletions vertexai/agent_engines/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,20 @@ def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
return None


def _import_opentelemetry_sdk_resources_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry.sdk.resources module.

    Returns:
        The ``opentelemetry.sdk.resources`` module if it is importable,
        otherwise ``None`` (after logging a warning with install guidance).
    """
    try:
        import opentelemetry.sdk.resources  # noqa:F401

        return opentelemetry.sdk.resources
    except ImportError:
        # Mirror the other optional-dependency importers in this module:
        # warn rather than raise so tracing degrades gracefully when the
        # agent_engines extra is not installed.
        LOGGER.warning(
            "Failed to import opentelemetry.sdk.resources. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        return None


def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.sdk.trace module."""
try:
Expand Down
23 changes: 20 additions & 3 deletions vertexai/agent_engines/templates/ag2.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,23 @@ def _default_instrumentor_builder(project_id: str):
openinference_autogen = _utils._import_openinference_autogen_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_autogen,
opentelemetry,
opentelemetry_sdk_trace,
opentelemetry_sdk_resources,
)
):
import google.auth
import os

SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
AGENT_ENGINE_ID = os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID", "")

credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
Expand All @@ -119,6 +126,12 @@ def _default_instrumentor_builder(project_id: str):
span_exporter=span_exporter,
)
)
resource = opentelemetry_sdk_resources.Resource.create(
attributes={
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
SERVICE_INSTANCE_ID: AGENT_ENGINE_ID,
}
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
Expand All @@ -127,7 +140,7 @@ def _default_instrumentor_builder(project_id: str):
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
if AGENT_ENGINE_ID or not tracer_provider:
from google.cloud.aiplatform import base

_LOGGER = base.Logger(__name__)
Expand All @@ -137,13 +150,17 @@ def _default_instrumentor_builder(project_id: str):
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
Expand Down
23 changes: 20 additions & 3 deletions vertexai/agent_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,23 @@ def _default_instrumentor_builder(project_id: str):
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
opentelemetry_sdk_resources,
)
):
import google.auth
import os

SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
AGENT_ENGINE_ID = os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID", "")

credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
Expand All @@ -198,6 +205,12 @@ def _default_instrumentor_builder(project_id: str):
span_exporter=span_exporter,
)
)
resource = opentelemetry_sdk_resources.Resource.create(
attributes={
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
SERVICE_INSTANCE_ID: AGENT_ENGINE_ID,
}
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
Expand All @@ -206,7 +219,7 @@ def _default_instrumentor_builder(project_id: str):
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
if AGENT_ENGINE_ID or not tracer_provider:
from google.cloud.aiplatform import base

_LOGGER = base.Logger(__name__)
Expand All @@ -216,13 +229,17 @@ def _default_instrumentor_builder(project_id: str):
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
Expand Down
23 changes: 20 additions & 3 deletions vertexai/agent_engines/templates/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,23 @@ def _default_instrumentor_builder(project_id: str):
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
opentelemetry_sdk_resources,
)
):
import google.auth
import os

SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
AGENT_ENGINE_ID = os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID", "")

credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
Expand All @@ -189,6 +196,12 @@ def _default_instrumentor_builder(project_id: str):
span_exporter=span_exporter,
)
)
resource = opentelemetry_sdk_resources.Resource.create(
attributes={
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
SERVICE_INSTANCE_ID: AGENT_ENGINE_ID,
}
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
Expand All @@ -197,7 +210,7 @@ def _default_instrumentor_builder(project_id: str):
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
if AGENT_ENGINE_ID or not tracer_provider:
from google.cloud.aiplatform import base

base.Logger(__name__).warning(
Expand All @@ -206,13 +219,17 @@ def _default_instrumentor_builder(project_id: str):
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
resource=resource,
)
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
Expand Down
Loading