Commit 0e9ce20

Merge pull request #169 from dbpunk-labs/feat/refactor
fix: fix codellama agent bug
2 parents: 95e9454 + 956ab83

8 files changed: +363 -66 lines changed

README.md

Lines changed: 0 additions & 4 deletions

@@ -15,13 +15,11 @@
 
 https://github.com/dbpunk-labs/octogen/assets/8623385/7445cc4d-567e-4d1a-bedc-b5b566329c41
 
-
 |Supported OSs|Supported Interpreters|Supported Dev Environment|
 |----|-----|-----|
 |<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/31b907e9-3a6f-4e9e-b0c0-f01d1e758a21"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/565d5f93-baac-4a77-ab1c-7d845e2fdb6d"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/acb7f919-ef09-446e-b1bc-0b50bc28de5a"/>|<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/6e286d3d-55f8-43df-ade6-38065b78eda1"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/958d23a0-777c-4bb9-8480-c7350c128c3f"/>|<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/ec8d5bff-f4cf-4870-baf9-3b0c53f39273"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/70602050-6a04-4c63-bb1a-7b35e44a8c79"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/fb543a9b-5235-45d4-b102-d57d21b2e237"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/8c1c5048-6c4a-40c9-b234-c5c5e0d53dc1"/>|
 
 
-
 ## Getting Started
 
 Requirement
@@ -100,7 +98,6 @@ Use /help for help
 
 ## Supported API Service
 
-
 |name|type|status| installation|
 |----|-----|----------------|---|
 |[Openai GPT 3.5/4](https://openai.com/product#made-for-developers) |LLM| ✅ fully supported|use `og_up` then choose the `OpenAI`|
@@ -133,4 +130,3 @@ if you have any feature suggestion, please create a discussion to talk about it
 
 * [roadmap for v0.5.0](https://github.com/dbpunk-labs/octogen/issues/64)
 
-

agent/src/og_agent/codellama_agent.py

Lines changed: 12 additions & 10 deletions

@@ -67,16 +67,18 @@ async def handle_show_sample_code(
             "language": json_response.get("language", "text"),
         })
         await queue.put(
-            TaskRespond(
-                state=task_context.to_task_state_proto(),
-                respond_type=TaskRespond.OnAgentActionType,
-                on_agent_action=OnAgentAction(
+            TaskResponse(
+                state=task_context.to_context_state_proto(),
+                response_type=TaskResponse.OnStepActionStart,
+                on_step_action_start=OnStepActionStart(
                     input=tool_input, tool="show_sample_code"
                 ),
             )
         )
 
-    async def handle_bash_code(self, json_response, queue, context, task_context):
+    async def handle_bash_code(
+        self, json_response, queue, context, task_context, task_opt
+    ):
         commands = json_response["action_input"]
         code = f"%%bash\n {commands}"
         explanation = json_response["explanation"]
@@ -88,10 +90,10 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
             "language": json_response.get("language"),
         })
         await queue.put(
-            TaskRespond(
-                state=task_context.to_task_state_proto(),
-                respond_type=TaskRespond.OnAgentActionType,
-                on_agent_action=OnAgentAction(
+            TaskResponse(
+                state=task_context.to_context_state_proto(),
+                response_type=TaskResponse.OnStepActionStart,
+                on_step_action_start=OnStepActionStart(
                     input=tool_input, tool="execute_bash_code"
                 ),
             )
@@ -102,7 +104,7 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
                 logger.debug("the client has cancelled the request")
                 break
             function_result = result
-            if respond:
+            if respond and task_opt.streaming:
                 await queue.put(respond)
         return function_result
agent/src/og_agent/mock_agent.py

Lines changed: 7 additions & 3 deletions

@@ -8,7 +8,7 @@
 import time
 import logging
 from .base_agent import BaseAgent, TypingState, TaskContext
-from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer,TypingContent
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer, TypingContent
 from .tokenizer import tokenize
 
 logger = logging.getLogger(__name__)
@@ -33,15 +33,19 @@ async def call_ai(self, prompt, queue, iteration, task_context):
                 TaskResponse(
                     state=task_context.to_context_state_proto(),
                     response_type=TaskResponse.OnModelTypeText,
-                    typing_content=TypingContent(content = message["explanation"], language="text"),
+                    typing_content=TypingContent(
+                        content=message["explanation"], language="text"
+                    ),
                 )
             )
         if message.get("code", None):
             await queue.put(
                 TaskResponse(
                     state=task_context.to_context_state_proto(),
                     response_type=TaskResponse.OnModelTypeCode,
-                    typing_content=TypingContent(content = message["code"], language="python"),
+                    typing_content=TypingContent(
+                        content=message["code"], language="python"
+                    ),
                 )
             )
         return message
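The call_ai hunks above show the mock agent's dispatch: an "explanation" field becomes an OnModelTypeText response and a "code" field becomes an OnModelTypeCode response. A minimal sketch of that dispatch, assuming plain dicts in place of the og_proto TaskResponse/TypingContent messages (typed_responses is an illustrative name):

def typed_responses(message):
    # Yield one typed-content response per populated field, mirroring
    # the text-then-code order used in call_ai.
    if message.get("explanation", None):
        yield {"response_type": "OnModelTypeText",
               "content": message["explanation"], "language": "text"}
    if message.get("code", None):
        yield {"response_type": "OnModelTypeCode",
               "content": message["code"], "language": "python"}


for response in typed_responses({"explanation": "print hi", "code": "print('hi')"}):
    print(response)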

agent/src/og_agent/openai_agent.py

Lines changed: 7 additions & 4 deletions

@@ -139,6 +139,7 @@ async def call_openai(self, messages, queue, context, task_context, task_opt):
         """
         call the openai api
         """
+        logger.debug(f"call openai with messages {messages}")
         input_token_count = 0
         for message in messages:
             if not message["content"]:
@@ -212,10 +213,10 @@ async def call_openai(self, messages, queue, context, task_context, task_opt):
                         code_content = code_str
                     if task_opt.streaming and len(typed_chars) > 0:
                         typing_language = "text"
-                        if (
-                            delta["function_call"].get("name", "")
-                            == "execute_python_code"
-                        ):
+                        if delta["function_call"].get("name", "") in [
+                            "execute_python_code",
+                            "python",
+                        ]:
                             typing_language = "python"
                         elif (
                             delta["function_call"].get("name", "")
@@ -357,6 +358,8 @@ async def arun(self, task, queue, context, task_opt):
             if "function_call" in chat_message:
                 if "content" not in chat_message:
                     chat_message["content"] = None
+                if "role" not in chat_message:
+                    chat_message["role"] = "assistant"
                 messages.append(chat_message)
                 function_name = chat_message["function_call"]["name"]
                 if function_name not in [
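The last hunk guards a subtle streaming detail: when an assistant message with a function_call is reassembled from streamed deltas, typically only the first delta carries the "role" key, so the merged dict can reach arun without one, and sending it back to the chat API would fail validation. A small runnable sketch of the patch applied in this commit (patch_function_call_message is an illustrative wrapper, not the repo's API):

def patch_function_call_message(chat_message):
    # Fill the keys the chat completion API requires before the message
    # is appended back into the conversation history.
    if "function_call" in chat_message:
        if "content" not in chat_message:
            chat_message["content"] = None
        if "role" not in chat_message:
            chat_message["role"] = "assistant"
    return chat_message


# A message assembled from streamed deltas may lack "role" and "content":
msg = {"function_call": {"name": "execute_python_code", "arguments": "{}"}}
print(patch_function_call_message(msg))
# -> {'function_call': {...}, 'content': None, 'role': 'assistant'}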

agent/tests/codellama_agent_tests.py

Lines changed: 155 additions & 13 deletions

@@ -10,7 +10,8 @@
 import json
 import logging
 import pytest
-from og_sdk.agent_sdk import AgentSDK
+
+from og_sdk.kernel_sdk import KernelSDK
 from og_agent.codellama_agent import CodellamaAgent
 from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse
 import asyncio
@@ -21,6 +22,14 @@
 logger = logging.getLogger(__name__)
 
 
+@pytest.fixture
+def kernel_sdk():
+    endpoint = (
+        "localhost:9527"  # Replace with the actual endpoint of your test gRPC server
+    )
+    return KernelSDK(endpoint, "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH")
+
+
 class PayloadStream:
 
     def __init__(self, payload):
@@ -49,24 +58,157 @@ def done(self):
 
 class CodellamaMockClient:
 
-    def __init__(self, payload):
-        self.payload = payload
+    def __init__(self, payloads):
+        self.payloads = payloads
+        self.index = 0
 
     async def prompt(self, question, chat_history=[]):
-        async for line in PayloadStream(self.payload):
+        if self.index >= len(self.payloads):
+            raise StopAsyncIteration
+        self.index += 1
+        payload = self.payloads[self.index - 1]
+        async for line in PayloadStream(payload):
             yield line
 
 
-@pytest_asyncio.fixture
-async def agent_sdk():
-    sdk = AgentSDK(api_base, api_key)
-    sdk.connect()
-    yield sdk
-    await sdk.close()
+@pytest.mark.asyncio
+async def test_codellama_agent_execute_bash_code(kernel_sdk):
+    kernel_sdk.connect()
+    sentence1 = {
+        "explanation": "print a hello world using python",
+        "action": "execute_bash_code",
+        "action_input": "echo 'hello world'",
+        "saved_filenames": [],
+        "language": "python",
+        "is_final_answer": False,
+    }
+    sentence2 = {
+        "explanation": "the output matches the goal",
+        "action": "no_action",
+        "action_input": "",
+        "saved_filenames": [],
+        "language": "en",
+        "is_final_answer": False,
+    }
+    client = CodellamaMockClient([json.dumps(sentence1), json.dumps(sentence2)])
+    agent = CodellamaAgent(client, kernel_sdk)
+    task_opt = ProcessOptions(
+        streaming=True,
+        llm_name="codellama",
+        input_token_limit=100000,
+        output_token_limit=100000,
+        timeout=5,
+    )
+    queue = asyncio.Queue()
+    await agent.arun("write a hello world in bash", queue, MockContext(), task_opt)
+    responses = []
+    while True:
+        try:
+            response = await queue.get()
+            if not response:
+                break
+            responses.append(response)
+        except asyncio.QueueEmpty:
+            break
+    logger.info(responses)
+    console_output = list(
+        filter(
+            lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
+            responses,
+        )
+    )
+    assert len(console_output) == 1, "bad console output count"
+    assert console_output[0].console_stdout == "hello world\n", "bad console output"
+
+
+@pytest.mark.asyncio
+async def test_codellama_agent_execute_python_code(kernel_sdk):
+    kernel_sdk.connect()
+    sentence1 = {
+        "explanation": "print a hello world using python",
+        "action": "execute_python_code",
+        "action_input": "print('hello world')",
+        "saved_filenames": [],
+        "language": "python",
+        "is_final_answer": False,
+    }
+    sentence2 = {
+        "explanation": "the output matches the goal",
+        "action": "no_action",
+        "action_input": "",
+        "saved_filenames": [],
+        "language": "en",
+        "is_final_answer": False,
+    }
+    client = CodellamaMockClient([json.dumps(sentence1), json.dumps(sentence2)])
+    agent = CodellamaAgent(client, kernel_sdk)
+    task_opt = ProcessOptions(
+        streaming=True,
+        llm_name="codellama",
+        input_token_limit=100000,
+        output_token_limit=100000,
+        timeout=5,
+    )
+    queue = asyncio.Queue()
+    await agent.arun("write a hello world in python", queue, MockContext(), task_opt)
+    responses = []
+    while True:
+        try:
+            response = await queue.get()
+            if not response:
+                break
+            responses.append(response)
+        except asyncio.QueueEmpty:
+            break
+    logger.info(responses)
+    console_output = list(
+        filter(
+            lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
+            responses,
+        )
+    )
+    assert len(console_output) == 1, "bad console output count"
+    assert console_output[0].console_stdout == "hello world\n", "bad console output"
+
+
+@pytest.mark.asyncio
+async def test_codellama_agent_show_demo_code(kernel_sdk):
+    sentence = {
+        "explanation": "Hello, how can I help you?",
+        "action": "show_demo_code",
+        "action_input": "echo 'hello world'",
+        "saved_filenames": [],
+        "language": "shell",
+        "is_final_answer": True,
+    }
+    client = CodellamaMockClient([json.dumps(sentence)])
+    agent = CodellamaAgent(client, kernel_sdk)
+    task_opt = ProcessOptions(
+        streaming=True,
+        llm_name="codellama",
+        input_token_limit=100000,
+        output_token_limit=100000,
+        timeout=5,
+    )
+    queue = asyncio.Queue()
+    await agent.arun("hello", queue, MockContext(), task_opt)
+    responses = []
+    while True:
+        try:
+            response = await queue.get()
+            if not response:
+                break
+            responses.append(response)
+        except asyncio.QueueEmpty:
+            break
+    logger.info(responses)
+    assert (
+        responses[-1].response_type == TaskResponse.OnFinalAnswer
+    ), "bad response type"
 
 
 @pytest.mark.asyncio
-async def test_codellama_agent_smoke_test(agent_sdk):
+async def test_codellama_agent_smoke_test(kernel_sdk):
     sentence = {
         "explanation": "Hello, how can I help you?",
         "action": "no_action",
@@ -75,8 +217,8 @@ async def test_codellama_agent_smoke_test(agent_sdk):
         "language": "en",
         "is_final_answer": True,
     }
-    client = CodellamaMockClient(json.dumps(sentence))
-    agent = CodellamaAgent(client, agent_sdk)
+    client = CodellamaMockClient([json.dumps(sentence)])
+    agent = CodellamaAgent(client, kernel_sdk)
     task_opt = ProcessOptions(
         streaming=True,
         llm_name="codellama",
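The reworked CodellamaMockClient turns the single-payload mock into a scripted multi-turn client: each prompt() call replays the next payload, so one test can drive an execute-then-confirm agent loop. A standalone sketch of the same pattern (ScriptedClient and the 10-character chunk size are illustrative, standing in for CodellamaMockClient/PayloadStream):

import asyncio


class ScriptedClient:

    def __init__(self, payloads):
        self.payloads = payloads
        self.index = 0

    async def prompt(self, question, chat_history=[]):
        # End the stream cleanly once every scripted payload is consumed.
        if self.index >= len(self.payloads):
            return
        payload = self.payloads[self.index]
        self.index += 1
        # Stream the payload in small chunks, as PayloadStream does.
        for start in range(0, len(payload), 10):
            yield payload[start:start + 10]


async def main():
    client = ScriptedClient(['{"action": "execute_bash_code"}', '{"action": "no_action"}'])
    for turn in range(2):
        chunks = [chunk async for chunk in client.prompt("hello")]
        print(turn, "".join(chunks))


asyncio.run(main())

One design note: the committed version raises StopAsyncIteration inside the async generator when the payloads run out; a plain return (as above) is the safer way to end an async generator, since PEP 479 semantics turn an explicit StopAsyncIteration raised in the body into a RuntimeError.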
