Merge pull request #169 from dbpunk-labs/feat/refactor
fix: fix codellama agent bug
imotai authored Oct 23, 2023
2 parents 95e9454 + 956ab83 commit 0e9ce20
Showing 8 changed files with 363 additions and 66 deletions.
4 changes: 0 additions & 4 deletions README.md
@@ -15,13 +15,11 @@
https://github.com/dbpunk-labs/octogen/assets/8623385/7445cc4d-567e-4d1a-bedc-b5b566329c41


|Supported OSs|Supported Interpreters|Supported Dev Enviroment|
|----|-----|-----|
|<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/31b907e9-3a6f-4e9e-b0c0-f01d1e758a21"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/565d5f93-baac-4a77-ab1c-7d845e2fdb6d"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/acb7f919-ef09-446e-b1bc-0b50bc28de5a"/>|<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/6e286d3d-55f8-43df-ade6-38065b78eda1"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/958d23a0-777c-4bb9-8480-c7350c128c3f"/>|<img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/ec8d5bff-f4cf-4870-baf9-3b0c53f39273"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/70602050-6a04-4c63-bb1a-7b35e44a8c79"/><img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/fb543a9b-5235-45d4-b102-d57d21b2e237"/> <img width="40px" src="https://github.com/dbpunk-labs/octogen/assets/8623385/8c1c5048-6c4a-40c9-b234-c5c5e0d53dc1"/>|



## Getting Started

Requirement
@@ -100,7 +98,6 @@ Use /help for help

## Supported API Service


|name|type|status| installation|
|----|-----|----------------|---|
|[Openai GPT 3.5/4](https://openai.com/product#made-for-developers) |LLM| ✅ fully supported|use `og_up` then choose the `OpenAI`|
@@ -133,4 +130,3 @@ if you have any feature suggestion. please create a discuession to talk about it

* [roadmap for v0.5.0](https://github.com/dbpunk-labs/octogen/issues/64)


22 changes: 12 additions & 10 deletions agent/src/og_agent/codellama_agent.py
@@ -67,16 +67,18 @@ async def handle_show_sample_code(
"language": json_response.get("language", "text"),
})
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentActionType,
on_agent_action=OnAgentAction(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_agent_start=OnStepActionStart(
input=tool_input, tool="show_sample_code"
),
)
)

async def handle_bash_code(self, json_response, queue, context, task_context):
async def handle_bash_code(
self, json_response, queue, context, task_context, task_opt
):
commands = json_response["action_input"]
code = f"%%bash\n {commands}"
explanation = json_response["explanation"]
@@ -88,10 +90,10 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
"language": json_response.get("language"),
})
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentActionType,
on_agent_action=OnAgentAction(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_action_start=OnStepActionStart(
input=tool_input, tool="execute_bash_code"
),
)
Expand All @@ -102,7 +104,7 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
logger.debug("the client has cancelled the request")
break
function_result = result
if respond:
if respond and task_opt.streaming:
await queue.put(respond)
return function_result

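The change above threads a new task_opt argument into handle_bash_code and only forwards intermediate kernel output to the queue when task_opt.streaming is enabled. Below is a minimal, self-contained sketch of that gating behavior; TaskOpt and run_tool are illustrative stand-ins, not classes from this repository:

import asyncio
from dataclasses import dataclass


@dataclass
class TaskOpt:
    # stand-in for ProcessOptions; only the field used by the gate
    streaming: bool


async def run_tool(queue, task_opt, chunks):
    # mimics the loop in handle_bash_code: always keep the last result,
    # but only stream intermediate responses when streaming is on
    result = None
    for respond in chunks:
        result = respond
        if respond and task_opt.streaming:
            await queue.put(respond)
    return result


async def main():
    queue = asyncio.Queue()
    final = await run_tool(queue, TaskOpt(streaming=False), ["stdout: hello", "exit: 0"])
    print(final, queue.qsize())  # "exit: 0" 0 -> nothing was streamed


asyncio.run(main())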
10 changes: 7 additions & 3 deletions agent/src/og_agent/mock_agent.py
@@ -8,7 +8,7 @@
import time
import logging
from .base_agent import BaseAgent, TypingState, TaskContext
from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer,TypingContent
from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer, TypingContent
from .tokenizer import tokenize

logger = logging.getLogger(__name__)
@@ -33,15 +33,19 @@ async def call_ai(self, prompt, queue, iteration, task_context):
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=TypingContent(content = message["explanation"], language="text"),
typing_content=TypingContent(
content=message["explanation"], language="text"
),
)
)
if message.get("code", None):
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeCode,
typing_content=TypingContent(content = message["code"], language="python"),
typing_content=TypingContent(
content=message["code"], language="python"
),
)
)
return message
11 changes: 7 additions & 4 deletions agent/src/og_agent/openai_agent.py
@@ -139,6 +139,7 @@ async def call_openai(self, messages, queue, context, task_context, task_opt):
"""
call the openai api
"""
logger.debug(f"call openai with messages {messages}")
input_token_count = 0
for message in messages:
if not message["content"]:
@@ -212,10 +213,10 @@ async def call_openai(self, messages, queue, context, task_context, task_opt):
code_content = code_str
if task_opt.streaming and len(typed_chars) > 0:
typing_language = "text"
if (
delta["function_call"].get("name", "")
== "execute_python_code"
):
if delta["function_call"].get("name", "") in [
"execute_python_code",
"python",
]:
typing_language = "python"
elif (
delta["function_call"].get("name", "")
@@ -357,6 +358,8 @@ async def arun(self, task, queue, context, task_opt):
if "function_call" in chat_message:
if "content" not in chat_message:
chat_message["content"] = None
if "role" not in chat_message:
chat_message["role"] = "assistant"
messages.append(chat_message)
function_name = chat_message["function_call"]["name"]
if function_name not in [
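The arun change above backfills missing content and role keys on a function_call chat message before it is appended to the conversation history. A small, self-contained sketch of that normalization (the helper name and sample values are illustrative, not from the repository):

def normalize_function_call_message(chat_message: dict) -> dict:
    # the OpenAI chat format expects both keys to be present on
    # assistant messages that carry a function_call
    if "function_call" in chat_message:
        if "content" not in chat_message:
            chat_message["content"] = None
        if "role" not in chat_message:
            chat_message["role"] = "assistant"
    return chat_message


msg = {"function_call": {"name": "execute_python_code", "arguments": "{}"}}
print(normalize_function_call_message(msg))
# {'function_call': {...}, 'content': None, 'role': 'assistant'}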
168 changes: 155 additions & 13 deletions agent/tests/codellama_agent_tests.py
@@ -10,7 +10,8 @@
import json
import logging
import pytest
from og_sdk.agent_sdk import AgentSDK

from og_sdk.kernel_sdk import KernelSDK
from og_agent.codellama_agent import CodellamaAgent
from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse
import asyncio
@@ -21,6 +22,14 @@
logger = logging.getLogger(__name__)


@pytest.fixture
def kernel_sdk():
endpoint = (
"localhost:9527" # Replace with the actual endpoint of your test gRPC server
)
return KernelSDK(endpoint, "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH")


class PayloadStream:

def __init__(self, payload):
@@ -49,24 +58,157 @@ def done(self):

class CodellamaMockClient:

def __init__(self, payload):
self.payload = payload
def __init__(self, payloads):
self.payloads = payloads
self.index = 0

async def prompt(self, question, chat_history=[]):
async for line in PayloadStream(self.payload):
if self.index >= len(self.payloads):
raise StopAsyncIteration
self.index += 1
payload = self.payloads[self.index - 1]
async for line in PayloadStream(payload):
yield line


@pytest_asyncio.fixture
async def agent_sdk():
sdk = AgentSDK(api_base, api_key)
sdk.connect()
yield sdk
await sdk.close()
@pytest.mark.asyncio
async def test_codellama_agent_execute_bash_code(kernel_sdk):
kernel_sdk.connect()
sentence1 = {
"explanation": "print a hello world using python",
"action": "execute_bash_code",
"action_input": "echo 'hello world'",
"saved_filenames": [],
"language": "python",
"is_final_answer": False,
}
sentence2 = {
"explanation": "the output matchs the goal",
"action": "no_action",
"action_input": "",
"saved_filenames": [],
"language": "en",
"is_final_answer": False,
}
client = CodellamaMockClient([json.dumps(sentence1), json.dumps(sentence2)])
agent = CodellamaAgent(client, kernel_sdk)
task_opt = ProcessOptions(
streaming=True,
llm_name="codellama",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
queue = asyncio.Queue()
await agent.arun("write a hello world in bash", queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
console_output = list(
filter(
lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
responses,
)
)
assert len(console_output) == 1, "bad console output count"
assert console_output[0].console_stdout == "hello world\n", "bad console output"


@pytest.mark.asyncio
async def test_codellama_agent_execute_python_code(kernel_sdk):
kernel_sdk.connect()
sentence1 = {
"explanation": "print a hello world using python",
"action": "execute_python_code",
"action_input": "print('hello world')",
"saved_filenames": [],
"language": "python",
"is_final_answer": False,
}
sentence2 = {
"explanation": "the output matchs the goal",
"action": "no_action",
"action_input": "",
"saved_filenames": [],
"language": "en",
"is_final_answer": False,
}
client = CodellamaMockClient([json.dumps(sentence1), json.dumps(sentence2)])
agent = CodellamaAgent(client, kernel_sdk)
task_opt = ProcessOptions(
streaming=True,
llm_name="codellama",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
queue = asyncio.Queue()
await agent.arun("write a hello world in python", queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
console_output = list(
filter(
lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
responses,
)
)
assert len(console_output) == 1, "bad console output count"
assert console_output[0].console_stdout == "hello world\n", "bad console output"


@pytest.mark.asyncio
async def test_codellama_agent_show_demo_code(kernel_sdk):
sentence = {
"explanation": "Hello, how can I help you?",
"action": "show_demo_code",
"action_input": "echo 'hello world'",
"saved_filenames": [],
"language": "shell",
"is_final_answer": True,
}
client = CodellamaMockClient([json.dumps(sentence)])
agent = CodellamaAgent(client, kernel_sdk)
task_opt = ProcessOptions(
streaming=True,
llm_name="codellama",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
queue = asyncio.Queue()
await agent.arun("hello", queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
assert (
responses[-1].response_type == TaskResponse.OnFinalAnswer
), "bad response type"


@pytest.mark.asyncio
async def test_codellama_agent_smoke_test(agent_sdk):
async def test_codellama_agent_smoke_test(kernel_sdk):
sentence = {
"explanation": "Hello, how can I help you?",
"action": "no_action",
@@ -75,8 +217,8 @@ async def test_codellama_agent_smoke_test(agent_sdk):
"language": "en",
"is_final_answer": True,
}
client = CodellamaMockClient(json.dumps(sentence))
agent = CodellamaAgent(client, agent_sdk)
client = CodellamaMockClient([json.dumps(sentence)])
agent = CodellamaAgent(client, kernel_sdk)
task_opt = ProcessOptions(
streaming=True,
llm_name="codellama",
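The tests above rework CodellamaMockClient to take a list of payloads and replay one payload per prompt() call, so a test can script a multi-step exchange (an execute step followed by a no_action step). A simplified, self-contained sketch of that pattern; the class below is an illustration, not the repository's client, and it yields single characters rather than PayloadStream chunks:

import asyncio
import json


class MultiPayloadMock:
    def __init__(self, payloads):
        self.payloads = payloads
        self.index = 0

    async def prompt(self, question, chat_history=None):
        # replay the next queued payload, one per call
        payload = self.payloads[self.index]
        self.index += 1
        for ch in payload:
            yield ch


async def main():
    client = MultiPayloadMock([
        json.dumps({"action": "execute_bash_code", "action_input": "echo 'hello world'"}),
        json.dumps({"action": "no_action", "is_final_answer": True}),
    ])
    first = "".join([ch async for ch in client.prompt("run a command")])
    second = "".join([ch async for ch in client.prompt("check the output")])
    print(first)
    print(second)


asyncio.run(main())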
