
Commit 0e7fda3

vertex-sdk-bot authored and copybara-github committed

feat: GenAI SDK client - Enabling Few-shot Prompt Optimization by passing either "OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS" or "OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE" to the optimize method, together with an example dataframe

PiperOrigin-RevId: 834333590

1 parent dd4775b commit 0e7fda3
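
For orientation, here is a minimal sketch of how the new few-shot targets are invoked, mirroring the replay tests added in this commit. The project and location values are placeholders, and the client is assumed to be already authenticated:

import pandas as pd
import vertexai
from vertexai._genai import types

# Placeholder project/location; any configured Vertex AI project works here.
client = vertexai.Client(project="my-project", location="us-central1")

# Few-shot optimization against target responses requires the columns
# 'prompt', 'model_response', and 'target_response' in the dataframe.
df = pd.DataFrame(
    {
        "prompt": ["prompt1", "prompt2"],
        "model_response": ["response1", "response2"],
        "target_response": ["target1", "target2"],
    }
)

response = client.prompt_optimizer.optimize_prompt(
    prompt="Generate system instructions for analyzing medical articles",
    config=types.OptimizeConfig(
        optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
        examples_dataframe=df,
    ),
)
print(response.raw_text_response)

For the rubrics target, swap in OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS and provide 'rubrics' and 'rubrics_evals' columns instead of 'target_response'; the validation added in _prompt_optimizer_utils.py below rejects mixed column sets.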

File tree

6 files changed: +383, -16 lines changed


tests/unit/vertexai/genai/replays/test_prompt_optimizer_async_optimize_prompt_return_type.py
Lines changed: 60 additions & 0 deletions

@@ -16,6 +16,7 @@

 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
+import pandas as pd
 import pytest


@@ -32,6 +33,65 @@ async def test_optimize_prompt(client):
     assert response.raw_text_response


+@pytest.mark.asyncio
+async def test_optimize_prompt_w_optimization_target(client):
+    """Tests the optimize request parameters method with optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    response = await client.aio.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response
+
+
+@pytest.mark.asyncio
+async def test_optimize_prompt_w_few_shot_optimization_target(client):
+    """Tests the optimize request parameters method with few shot optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    df = pd.DataFrame(
+        {
+            "prompt": ["prompt1", "prompt2"],
+            "model_response": ["response1", "response2"],
+            "target_response": ["target1", "target2"],
+        }
+    )
+    response = await client.aio.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
+            examples_dataframe=df,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response
+
+
+@pytest.mark.asyncio
+async def test_optimize_prompt_w_few_shot_optimization_rubrics(client):
+    """Tests the optimize request parameters method with few shot rubrics optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    df = pd.DataFrame(
+        {
+            "prompt": ["prompt1", "prompt2"],
+            "model_response": ["response1", "response2"],
+            "rubrics": ["rubric1", "rubric2"],
+            "rubrics_evals": ["[True, True]", "[True, False]"],
+        }
+    )
+    response = await client.aio.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
+            examples_dataframe=df,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

tests/unit/vertexai/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py
Lines changed: 55 additions & 12 deletions

@@ -16,6 +16,7 @@

 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
+import pandas as pd


 def test_optimize_prompt(client):
@@ -27,18 +28,60 @@ def test_optimize_prompt(client):
     assert response.raw_text_response


-# def test_optimize_prompt_w_optimization_target(client):
-#     """Tests the optimize request parameters method with optimization target."""
-#     from google.genai import types as genai_types
-#     test_prompt = "Generate system instructions for analyzing medical articles"
-#     response = client.prompt_optimizer.optimize_prompt(
-#         prompt=test_prompt,
-#         config=types.OptimizeConfig(
-#             optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
-#         ),
-#     )
-#     assert isinstance(response, types.OptimizeResponse)
-#     assert response.raw_text_response
+def test_optimize_prompt_w_optimization_target(client):
+    """Tests the optimize request parameters method with optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    response = client.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response
+
+
+def test_optimize_prompt_w_few_shot_optimization_target(client):
+    """Tests the optimize request parameters method with few shot optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    df = pd.DataFrame(
+        {
+            "prompt": ["prompt1", "prompt2"],
+            "model_response": ["response1", "response2"],
+            "target_response": ["target1", "target2"],
+        }
+    )
+    response = client.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
+            examples_dataframe=df,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response
+
+
+def test_optimize_prompt_w_few_shot_optimization_rubrics(client):
+    """Tests the optimize request parameters method with few shot rubrics optimization target."""
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    df = pd.DataFrame(
+        {
+            "prompt": ["prompt1", "prompt2"],
+            "model_response": ["response1", "response2"],
+            "rubrics": ["rubric1", "rubric2"],
+            "rubrics_evals": ["[True, True]", "[True, False]"],
+        }
+    )
+    response = client.prompt_optimizer.optimize_prompt(
+        prompt=test_prompt,
+        config=types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
+            examples_dataframe=df,
+        ),
+    )
+    assert isinstance(response, types.OptimizeResponse)
+    assert response.raw_text_response


 pytestmark = pytest_helper.setup(

tests/unit/vertexai/genai/test_prompt_optimizer.py
Lines changed: 85 additions & 0 deletions

@@ -21,6 +21,7 @@
 from vertexai._genai import prompt_optimizer
 from vertexai._genai import types
 from google.genai import client
+import pandas as pd
 import pytest


@@ -91,6 +92,35 @@ def test_prompt_optimizer_optimize_prompt(
         mock_client.assert_called_once()
         mock_custom_optimize_prompt.assert_called_once()

+    @mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
+    def test_prompt_optimizer_optimize_few_shot(self, mock_custom_optimize_prompt):
+        """Tests the prompt_optimizer.optimize_prompt method for the few shot optimizer."""
+        df = pd.DataFrame(
+            {
+                "prompt": ["prompt1", "prompt2"],
+                "model_response": ["response1", "response2"],
+                "target_response": ["target1", "target2"],
+            }
+        )
+        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_config = types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
+            examples_dataframe=df,
+        )
+        test_client.prompt_optimizer.optimize_prompt(
+            prompt="test_prompt",
+            config=test_config,
+        )
+        mock_custom_optimize_prompt.assert_called_once()
+        mock_kwargs = mock_custom_optimize_prompt.call_args.kwargs
+        assert (
+            mock_kwargs["config"].optimization_target
+            == test_config.optimization_target
+        )
+        pd.testing.assert_frame_equal(
+            mock_kwargs["config"].examples_dataframe, test_config.examples_dataframe
+        )
+
     @mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
     def test_prompt_optimizer_optimize_prompt_with_optimization_target(
         self, mock_custom_optimize_prompt
@@ -138,4 +168,59 @@ async def test_async_prompt_optimizer_optimize_prompt_with_optimization_target(
             config=config,
         )

+    @pytest.mark.asyncio
+    @mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
+    async def test_async_prompt_optimizer_optimize_prompt_few_shot_target_response(
+        self, mock_custom_optimize_prompt
+    ):
+        """Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot target response."""
+        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        df = pd.DataFrame(
+            {
+                "prompt": ["prompt1", "prompt2"],
+                "model_response": ["response1", "response2"],
+                "target_response": ["target1", "target2"],
+            }
+        )
+        config = types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
+            examples_dataframe=df,
+        )
+        await test_client.aio.prompt_optimizer.optimize_prompt(
+            prompt="test_prompt",
+            config=config,
+        )
+        mock_custom_optimize_prompt.assert_called_once_with(
+            content=mock.ANY,
+            config=config,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
+    async def test_async_prompt_optimizer_optimize_prompt_few_shot_rubrics(
+        self, mock_custom_optimize_prompt
+    ):
+        """Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot rubrics."""
+        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        df = pd.DataFrame(
+            {
+                "prompt": ["prompt1", "prompt2"],
+                "model_response": ["response1", "response2"],
+                "rubrics": ["rubric1", "rubric2"],
+                "rubrics_evals": ["[True, True]", "[True, False]"],
+            }
+        )
+        config = types.OptimizeConfig(
+            optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
+            examples_dataframe=df,
+        )
+        await test_client.aio.prompt_optimizer.optimize_prompt(
+            prompt="test_prompt",
+            config=config,
+        )
+        mock_custom_optimize_prompt.assert_called_once_with(
+            content=mock.ANY,
+            config=config,
+        )
+
     # # TODO(b/415060797): add more tests for prompt_optimizer.optimize
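
A note on the sync few-shot test above: it checks the forwarded dataframe via call_args.kwargs and pd.testing.assert_frame_equal rather than folding everything into one assert_called_once_with, presumably because DataFrame equality is element-wise, so a plain == comparison on a config holding a dataframe can raise the usual "truth value of a DataFrame is ambiguous" error. A small self-contained sketch of that pattern (the run/spy names are illustrative, not SDK API):

from unittest import mock

import pandas as pd


def run(callback, df):
    # Stand-in for code under test that forwards a dataframe to a collaborator.
    callback(config={"examples_dataframe": df})


df = pd.DataFrame({"prompt": ["p1"], "model_response": ["r1"]})
spy = mock.Mock()
run(spy, df.copy())  # pass a copy so object-identity shortcuts cannot mask the issue

spy.assert_called_once()
kwargs = spy.call_args.kwargs
# Compare dataframes through pandas' testing helper; `==` is element-wise,
# and `assert df1 == df2` would raise a ValueError on the ambiguous truth value.
pd.testing.assert_frame_equal(kwargs["config"]["examples_dataframe"], df)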

vertexai/_genai/_prompt_optimizer_utils.py
Lines changed: 116 additions & 0 deletions

@@ -15,9 +15,125 @@
 """Utility functions for prompt optimizer."""

 import json
+import logging
+from typing import Optional
+import pandas as pd
 from . import types


+def _construct_input_prompt(
+    example_df: pd.DataFrame,
+    prompt_col_name: str,
+    model_response_col_name: str,
+    rubrics_col_name: str,
+    rubric_evaluations_col_name: str,
+    target_response_col_name: str,
+    system_instruction: Optional[str] = None,
+) -> str:
+    """Constructs the input prompt for the few shot prompt optimizer."""
+
+    all_prompts = []
+    for _, row in example_df.iterrows():
+        example_data = {
+            "prompt": row[prompt_col_name],
+            "model_response": row[model_response_col_name],
+        }
+        if rubrics_col_name:
+            example_data["rubrics"] = row[rubrics_col_name]
+        if rubric_evaluations_col_name:
+            example_data["rubric_evaluations"] = row[rubric_evaluations_col_name]
+        if target_response_col_name:
+            example_data["target_response"] = row[target_response_col_name]
+
+        json_str = json.dumps(example_data, indent=2)
+        all_prompts.append(f"```JSON\n{json_str}\n```")
+
+    all_prompts_str = "\n\n".join(all_prompts)
+
+    if system_instruction is None:
+        system_instruction = ""
+
+    input_prompt = "\n".join(
+        [
+            "Original System Instructions:\n",
+            system_instruction,
+            "Examples:\n",
+            all_prompts_str,
+            "\nNew Output:\n",
+        ]
+    )
+
+    return input_prompt
+
+
+def _get_few_shot_prompt(
+    system_instruction: str,
+    config: types.OptimizeConfig,
+) -> str:
+    """Builds the few shot prompt."""
+
+    if "prompt" not in config.examples_dataframe.columns:
+        raise ValueError("'prompt' is required in the examples_dataframe.")
+    prompt_col_name = "prompt"
+
+    if "model_response" not in config.examples_dataframe.columns:
+        raise ValueError("'model_response' is required in the examples_dataframe.")
+    model_response_col_name = "model_response"
+
+    target_response_col_name = ""
+    rubrics_col_name = ""
+    rubric_evaluations_col_name = ""
+
+    if (
+        config.optimization_target
+        == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+    ):
+        if "target_response" not in config.examples_dataframe.columns:
+            raise ValueError("'target_response' is required in the examples_dataframe.")
+        target_response_col_name = "target_response"
+        if "rubrics" in config.examples_dataframe.columns:
+            raise ValueError(
+                "Only 'target_response' should be provided "
+                "for OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE "
+                "but 'rubrics' was provided."
+            )
+
+    elif (
+        config.optimization_target
+        == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS
+    ):
+        if ("rubrics" not in config.examples_dataframe.columns) or (
+            "rubrics_evals" not in config.examples_dataframe.columns
+        ):
+            raise ValueError(
+                "'rubrics' and 'rubrics_evals' are required in the "
+                "examples_dataframe when the rubrics target is set."
+            )

+        rubrics_col_name = "rubrics"
+        rubric_evaluations_col_name = "rubrics_evals"
+        if "target_response" in config.examples_dataframe.columns:
+            raise ValueError(
+                "Only 'rubrics' and 'rubrics_evals' should be provided "
+                "for OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS "
+                "but 'target_response' was provided."
+            )
+    else:
+        raise ValueError("One of 'target_response' or 'rubrics' must be provided.")
+
+    prompt = _construct_input_prompt(
+        config.examples_dataframe,
+        prompt_col_name,
+        model_response_col_name,
+        rubrics_col_name,
+        rubric_evaluations_col_name,
+        target_response_col_name,
+        system_instruction,
+    )
+
+    return prompt
+
+
 def _get_service_account(
     config: types.PromptOptimizerConfigOrDict,
 ) -> str:
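
To make the constructed few-shot prompt concrete, this sketch calls the new private helper directly on a two-row dataframe (for illustration only; normal callers go through optimize_prompt, and the system instruction here is made up):

import pandas as pd

from vertexai._genai import _prompt_optimizer_utils

df = pd.DataFrame(
    {
        "prompt": ["p1", "p2"],
        "model_response": ["r1", "r2"],
        "target_response": ["t1", "t2"],
    }
)
text = _prompt_optimizer_utils._construct_input_prompt(
    example_df=df,
    prompt_col_name="prompt",
    model_response_col_name="model_response",
    rubrics_col_name="",  # empty strings skip the rubrics fields entirely
    rubric_evaluations_col_name="",
    target_response_col_name="target_response",
    system_instruction="You are a careful medical reviewer.",
)
print(text)

The printed result is the original system instruction, then each row serialized as an indented JSON object inside a ```JSON fenced block (two blocks here, separated by blank lines), followed by a trailing "New Output:" marker that cues the model to produce the optimized instruction.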
