Skip to content

Commit 411ab70

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add Resource to tracer provider in Agent Engine templates
PiperOrigin-RevId: 763962855
1 parent 47ab05a commit 411ab70

File tree

8 files changed

+367
-12
lines changed

8 files changed

+367
-12
lines changed

tests/unit/vertexai/test_generative_models.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
from typing import Dict, Iterable, List, MutableSequence, Optional
2222
from unittest import mock
2323

24+
from google.api_core import operation as ga_operation
2425
import vertexai
2526
from google.cloud.aiplatform import initializer
2627
from google.cloud.aiplatform_v1 import types as types_v1
2728
from google.cloud.aiplatform_v1.services import (
2829
prediction_service as prediction_service_v1,
2930
)
3031
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
32+
from google.cloud.aiplatform_v1beta1.services import endpoint_service
3133
from vertexai import generative_models
3234
from vertexai.preview import (
3335
generative_models as preview_generative_models,
@@ -48,6 +50,7 @@
4850
)
4951
from vertexai.generative_models import _function_calling_utils
5052
from vertexai.caching import CachedContent
53+
from google.protobuf import field_mask_pb2
5154

5255

5356
_TEST_PROJECT = "test-project"
@@ -1710,6 +1713,115 @@ def test_defs_ref_renaming(self):
17101713
_fix_schema_dict_for_gapic_in_place(actual)
17111714
assert actual == expected
17121715

1716+
@pytest.mark.parametrize(
1717+
"generative_models",
1718+
[preview_generative_models], # Only preview supports set_logging_config
1719+
)
1720+
@mock.patch.object(endpoint_service.EndpointServiceClient, "update_endpoint")
1721+
def test_set_logging_config_for_endpoint(
1722+
self, mock_update_endpoint, generative_models: generative_models
1723+
):
1724+
endpoint_name = (
1725+
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/12345"
1726+
)
1727+
model = generative_models.GenerativeModel(endpoint_name)
1728+
1729+
mock_update_endpoint.return_value = types_v1beta1.Endpoint(name=endpoint_name)
1730+
1731+
enabled = True
1732+
sampling_rate = 0.5
1733+
bigquery_destination = f"bq://{_TEST_PROJECT}.my_dataset.my_table"
1734+
enable_otel_logging = True
1735+
1736+
model.set_request_response_logging_config(
1737+
enabled=enabled,
1738+
sampling_rate=sampling_rate,
1739+
bigquery_destination=bigquery_destination,
1740+
enable_otel_logging=enable_otel_logging,
1741+
)
1742+
1743+
expected_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig(
1744+
enabled=enabled,
1745+
sampling_rate=sampling_rate,
1746+
bigquery_destination=types_v1beta1.BigQueryDestination(
1747+
output_uri=bigquery_destination
1748+
),
1749+
enable_otel_logging=enable_otel_logging,
1750+
)
1751+
expected_endpoint = types_v1beta1.Endpoint(
1752+
name=endpoint_name,
1753+
predict_request_response_logging_config=expected_logging_config,
1754+
)
1755+
expected_update_mask = field_mask_pb2.FieldMask(
1756+
paths=["predict_request_response_logging_config"]
1757+
)
1758+
1759+
mock_update_endpoint.assert_called_once_with(
1760+
types_v1beta1.UpdateEndpointRequest(
1761+
endpoint=expected_endpoint,
1762+
update_mask=expected_update_mask,
1763+
)
1764+
)
1765+
1766+
@pytest.mark.parametrize(
1767+
"generative_models",
1768+
[preview_generative_models], # Only preview supports set_logging_config
1769+
)
1770+
@mock.patch.object(
1771+
endpoint_service.EndpointServiceClient, "set_publisher_model_config"
1772+
)
1773+
def test_set_logging_config_for_publisher_model(
1774+
self, mock_set_publisher_model_config, generative_models: generative_models
1775+
):
1776+
model_name = "gemini-pro"
1777+
model = generative_models.GenerativeModel(model_name)
1778+
full_model_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/{model_name}"
1779+
1780+
enabled = False
1781+
sampling_rate = 1.0
1782+
bigquery_destination = f"bq://{_TEST_PROJECT}.another_dataset"
1783+
enable_otel_logging = False
1784+
1785+
mock_operation = mock.Mock(spec=ga_operation.Operation)
1786+
mock_set_publisher_model_config.return_value = mock_operation
1787+
mock_operation.result.return_value = types_v1beta1.PublisherModelConfig(
1788+
logging_config=types_v1beta1.PredictRequestResponseLoggingConfig(
1789+
enabled=enabled,
1790+
sampling_rate=sampling_rate,
1791+
bigquery_destination=types_v1beta1.BigQueryDestination(
1792+
output_uri=bigquery_destination
1793+
),
1794+
enable_otel_logging=enable_otel_logging,
1795+
)
1796+
)
1797+
1798+
model.set_request_response_logging_config(
1799+
enabled=enabled,
1800+
sampling_rate=sampling_rate,
1801+
bigquery_destination=bigquery_destination,
1802+
enable_otel_logging=enable_otel_logging,
1803+
)
1804+
1805+
expected_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig(
1806+
enabled=enabled,
1807+
sampling_rate=sampling_rate,
1808+
bigquery_destination=types_v1beta1.BigQueryDestination(
1809+
output_uri=bigquery_destination
1810+
),
1811+
enable_otel_logging=enable_otel_logging,
1812+
)
1813+
expected_publisher_model_config = types_v1beta1.PublisherModelConfig(
1814+
logging_config=expected_logging_config
1815+
)
1816+
1817+
mock_set_publisher_model_config.assert_called_once_with(
1818+
types_v1beta1.SetPublisherModelConfigRequest(
1819+
name=full_model_name,
1820+
publisher_model_config=expected_publisher_model_config,
1821+
)
1822+
)
1823+
mock_operation.result.assert_called_once()
1824+
17131825

17141826
EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
17151827
"title": "get_current_weather",

vertexai/agent_engines/_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,20 @@ def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
685685
return None
686686

687687

688+
def _import_opentelemetry_sdk_resources_or_warn() -> Optional[types.ModuleType]:
689+
"""Tries to import the opentelemetry.sdk.trace module."""
690+
try:
691+
import opentelemetry.sdk.resources # noqa:F401
692+
693+
return opentelemetry.sdk.resources
694+
except ImportError:
695+
LOGGER.warning(
696+
"Failed to import opentelemetry.sdk.resources. Please call "
697+
"'pip install google-cloud-aiplatform[agent_engines]'."
698+
)
699+
return None
700+
701+
688702
def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
689703
"""Tries to import the opentelemetry.sdk.trace module."""
690704
try:

vertexai/agent_engines/templates/ag2.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,15 @@ def _default_instrumentor_builder(project_id: str):
9696
openinference_autogen = _utils._import_openinference_autogen_or_warn()
9797
opentelemetry = _utils._import_opentelemetry_or_warn()
9898
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
99+
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
99100
if all(
100101
(
101102
cloud_trace_exporter,
102103
cloud_trace_v2,
103104
openinference_autogen,
104105
opentelemetry,
105106
opentelemetry_sdk_trace,
107+
opentelemetry_sdk_resources,
106108
)
107109
):
108110
import google.auth
@@ -129,6 +131,7 @@ def _default_instrumentor_builder(project_id: str):
129131
# create a tracer provider.
130132
if not tracer_provider:
131133
from google.cloud.aiplatform import base
134+
import os
132135

133136
_LOGGER = base.Logger(__name__)
134137
_LOGGER.warning(
@@ -137,13 +140,39 @@ def _default_instrumentor_builder(project_id: str):
137140
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
138141
"or _PROXY_TRACER_PROVIDER."
139142
)
140-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
143+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
144+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
145+
resource = opentelemetry_sdk_resources.Resource.create(
146+
attributes={
147+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
148+
SERVICE_INSTANCE_ID: os.environ.get(
149+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
150+
),
151+
}
152+
)
153+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
154+
resource=resource,
155+
)
141156
opentelemetry.trace.set_tracer_provider(tracer_provider)
142157
# Avoids AttributeError:
143158
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
144159
# attribute 'add_span_processor'.
145160
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
146-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
161+
import os
162+
163+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
164+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
165+
resource = opentelemetry_sdk_resources.Resource.create(
166+
attributes={
167+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
168+
SERVICE_INSTANCE_ID: os.environ.get(
169+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
170+
),
171+
}
172+
)
173+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
174+
resource=resource,
175+
)
147176
opentelemetry.trace.set_tracer_provider(tracer_provider)
148177
# Avoids OpenTelemetry client already exists error.
149178
_override_active_span_processor(

vertexai/agent_engines/templates/langchain.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,15 @@ def _default_instrumentor_builder(project_id: str):
175175
openinference_langchain = _utils._import_openinference_langchain_or_warn()
176176
opentelemetry = _utils._import_opentelemetry_or_warn()
177177
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
178+
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
178179
if all(
179180
(
180181
cloud_trace_exporter,
181182
cloud_trace_v2,
182183
openinference_langchain,
183184
opentelemetry,
184185
opentelemetry_sdk_trace,
186+
opentelemetry_sdk_resources,
185187
)
186188
):
187189
import google.auth
@@ -208,6 +210,7 @@ def _default_instrumentor_builder(project_id: str):
208210
# create a tracer provider.
209211
if not tracer_provider:
210212
from google.cloud.aiplatform import base
213+
import os
211214

212215
_LOGGER = base.Logger(__name__)
213216
_LOGGER.warning(
@@ -216,13 +219,39 @@ def _default_instrumentor_builder(project_id: str):
216219
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
217220
"or _PROXY_TRACER_PROVIDER."
218221
)
219-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
222+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
223+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
224+
resource = opentelemetry_sdk_resources.Resource.create(
225+
attributes={
226+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
227+
SERVICE_INSTANCE_ID: os.environ.get(
228+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
229+
),
230+
}
231+
)
232+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
233+
resource=resource,
234+
)
220235
opentelemetry.trace.set_tracer_provider(tracer_provider)
221236
# Avoids AttributeError:
222237
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
223238
# attribute 'add_span_processor'.
224239
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
225-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
240+
import os
241+
242+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
243+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
244+
resource = opentelemetry_sdk_resources.Resource.create(
245+
attributes={
246+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
247+
SERVICE_INSTANCE_ID: os.environ.get(
248+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
249+
),
250+
}
251+
)
252+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
253+
resource=resource,
254+
)
226255
opentelemetry.trace.set_tracer_provider(tracer_provider)
227256
# Avoids OpenTelemetry client already exists error.
228257
_override_active_span_processor(

vertexai/agent_engines/templates/langgraph.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,15 @@ def _default_instrumentor_builder(project_id: str):
166166
openinference_langchain = _utils._import_openinference_langchain_or_warn()
167167
opentelemetry = _utils._import_opentelemetry_or_warn()
168168
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
169+
opentelemetry_sdk_resources = _utils._import_opentelemetry_sdk_resources_or_warn()
169170
if all(
170171
(
171172
cloud_trace_exporter,
172173
cloud_trace_v2,
173174
openinference_langchain,
174175
opentelemetry,
175176
opentelemetry_sdk_trace,
177+
opentelemetry_sdk_resources,
176178
)
177179
):
178180
import google.auth
@@ -199,20 +201,47 @@ def _default_instrumentor_builder(project_id: str):
199201
# create a tracer provider.
200202
if not tracer_provider:
201203
from google.cloud.aiplatform import base
204+
import os
202205

203206
base.Logger(__name__).warning(
204207
"No tracer provider. By default, "
205208
"we should get one of the following providers: "
206209
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
207210
"or _PROXY_TRACER_PROVIDER."
208211
)
209-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
212+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
213+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
214+
resource = opentelemetry_sdk_resources.Resource.create(
215+
attributes={
216+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
217+
SERVICE_INSTANCE_ID: os.environ.get(
218+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
219+
),
220+
}
221+
)
222+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
223+
resource=resource,
224+
)
210225
opentelemetry.trace.set_tracer_provider(tracer_provider)
211226
# Avoids AttributeError:
212227
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no
213228
# attribute 'add_span_processor'.
214229
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
215-
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
230+
import os
231+
232+
SERVICE_INSTANCE_ID = opentelemetry_sdk_resources.SERVICE_INSTANCE_ID
233+
SERVICE_NAME = opentelemetry_sdk_resources.SERVICE_NAME
234+
resource = opentelemetry_sdk_resources.Resource.create(
235+
attributes={
236+
SERVICE_NAME: "aiplatform.googleapis.com/ReasoningEngine",
237+
SERVICE_INSTANCE_ID: os.environ.get(
238+
"GOOGLE_CLOUD_AGENT_ENGINE_ID", ""
239+
),
240+
}
241+
)
242+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
243+
resource=resource,
244+
)
216245
opentelemetry.trace.set_tracer_provider(tracer_provider)
217246
# Avoids OpenTelemetry client already exists error.
218247
_override_active_span_processor(

0 commit comments

Comments
 (0)