Skip to content

Commit 569bc19

Browse files
authored
feat: add gemini model families, enhance group chat selection for Gemini model and add tests (microsoft#5334)
Resolves microsoft#5322
1 parent 9af6883 commit 569bc19

File tree

7 files changed

+82
-9
lines changed

7 files changed

+82
-9
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Dict, List, Mapping, Sequence
44

55
from autogen_core import Component, ComponentModel
6-
from autogen_core.models import ChatCompletionClient, SystemMessage
6+
from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage
77
from pydantic import BaseModel
88
from typing_extensions import Self
99

@@ -135,7 +135,11 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
135135
select_speaker_prompt = self._selector_prompt.format(
136136
roles=roles, participants=str(participants), history=history
137137
)
138-
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
138+
select_speaker_messages: List[SystemMessage | UserMessage]
139+
if self._model_client.model_info["family"].startswith("gemini"):
140+
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
141+
else:
142+
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
139143
response = await self._model_client.create(messages=select_speaker_messages)
140144
assert isinstance(response.content, str)
141145
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import os
2+
3+
import pytest
4+
from autogen_agentchat.agents import AssistantAgent
5+
from autogen_agentchat.teams import SelectorGroupChat
6+
from autogen_agentchat.ui import Console
7+
from autogen_core.models import ModelFamily
8+
from autogen_ext.models.openai import OpenAIChatCompletionClient
9+
10+
11+
@pytest.mark.asyncio
12+
async def test_selector_group_chat_gemini() -> None:
13+
try:
14+
api_key = os.environ["GEMINI_API_KEY"]
15+
except KeyError:
16+
pytest.skip("GEMINI_API_KEY not set in environment variables.")
17+
18+
model_client = OpenAIChatCompletionClient(
19+
model="gemini-1.5-flash",
20+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
21+
api_key=api_key,
22+
model_info={
23+
"vision": True,
24+
"function_calling": True,
25+
"json_output": True,
26+
"family": ModelFamily.GEMINI_1_5_FLASH,
27+
},
28+
)
29+
30+
assistant = AssistantAgent(
31+
"assistant",
32+
description="A helpful assistant agent.",
33+
model_client=model_client,
34+
system_message="You are a helpful assistant.",
35+
)
36+
37+
critic = AssistantAgent(
38+
"critic",
39+
description="A critic agent to provide feedback.",
40+
model_client=model_client,
41+
system_message="Provide feedback.",
42+
)
43+
44+
team = SelectorGroupChat([assistant, critic], model_client=model_client, max_turns=2)
45+
await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year."))

python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@
309309
"source": [
310310
"import os\n",
311311
"\n",
312-
"from autogen_core.models import UserMessage\n",
312+
"from autogen_core.models import ModelFamily, UserMessage\n",
313313
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
314314
"\n",
315315
"model_client = OpenAIChatCompletionClient(\n",
@@ -320,7 +320,7 @@
320320
" \"vision\": True,\n",
321321
" \"function_calling\": True,\n",
322322
" \"json_output\": True,\n",
323-
" \"family\": \"unknown\",\n",
323+
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
324324
" },\n",
325325
")\n",
326326
"\n",

python/packages/autogen-core/docs/src/user-guide/core-user-guide/components/model-clients.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@
317317
"source": [
318318
"import os\n",
319319
"\n",
320-
"from autogen_core.models import UserMessage\n",
320+
"from autogen_core.models import ModelFamily, UserMessage\n",
321321
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
322322
"\n",
323323
"model_client = OpenAIChatCompletionClient(\n",
@@ -328,7 +328,7 @@
328328
" \"vision\": True,\n",
329329
" \"function_calling\": True,\n",
330330
" \"json_output\": True,\n",
331-
" \"family\": \"unknown\",\n",
331+
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
332332
" },\n",
333333
")\n",
334334
"\n",

python/packages/autogen-core/src/autogen_core/models/_model_client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,23 @@ class ModelFamily:
2424
GPT_4 = "gpt-4"
2525
GPT_35 = "gpt-35"
2626
R1 = "r1"
27+
GEMINI_1_5_FLASH = "gemini-1.5-flash"
28+
GEMINI_1_5_PRO = "gemini-1.5-pro"
29+
GEMINI_2_0_FLASH = "gemini-2.0-flash"
2730
UNKNOWN = "unknown"
2831

29-
ANY: TypeAlias = Literal["gpt-4o", "o1", "o3", "gpt-4", "gpt-35", "r1", "unknown"]
32+
ANY: TypeAlias = Literal[
33+
"gpt-4o",
34+
"o1",
35+
"o3",
36+
"gpt-4",
37+
"gpt-35",
38+
"r1",
39+
"gemini-1.5-flash",
40+
"gemini-1.5-pro",
41+
"gemini-2.0-flash",
42+
"unknown",
43+
]
3044

3145
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
3246
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")

python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,17 @@ async def main() -> None:
144144
temperature=0.2,
145145
)
146146
147-
model_client = SKChatCompletionAdapter(sk_client, kernel=Kernel(memory=NullMemory()), prompt_settings=settings)
147+
model_client = SKChatCompletionAdapter(
148+
sk_client,
149+
kernel=Kernel(memory=NullMemory()),
150+
prompt_settings=settings,
151+
model_info={
152+
"family": "gemini-1.5-flash",
153+
"function_calling": True,
154+
"json_output": True,
155+
"vision": False,
156+
},
157+
)
148158
149159
# Call the model directly.
150160
model_result = await model_client.create(

python/packages/autogen-ext/tests/models/test_openai_model_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ async def test_gemini() -> None:
958958
"function_calling": True,
959959
"json_output": True,
960960
"vision": True,
961-
"family": ModelFamily.UNKNOWN,
961+
"family": ModelFamily.GEMINI_1_5_FLASH,
962962
},
963963
)
964964
await _test_model_client_basic_completion(model_client)

0 commit comments

Comments
 (0)