Agent improvements: Adopt system instructions and allow multiple command executions #717

Draft
wants to merge 29 commits into main
Changes from all commits (29 commits)
71bc0ef
New system instructions
DonggeLiu Nov 12, 2024
3b25840
Apply system instructions
DonggeLiu Nov 12, 2024
3498faf
Refine priming
DonggeLiu Nov 12, 2024
3c8d99a
Minor correction
DonggeLiu Nov 12, 2024
9477c02
Bug fix
DonggeLiu Nov 12, 2024
8cdbaca
Do not emphasize on simple/minimum fuzz target
DonggeLiu Nov 12, 2024
1759a67
Complete conclusion protocol
DonggeLiu Nov 12, 2024
8b83ef1
Minimize priming
DonggeLiu Nov 12, 2024
8b17486
Visually separate RESPONSE/PROMPT and their content by a line break
DonggeLiu Nov 13, 2024
5272385
Strip empty lines and spaces from bash output
DonggeLiu Nov 13, 2024
adb9e72
Allow executing multiple bash commands in one response
DonggeLiu Nov 13, 2024
4c41553
Allow passing system instructions to LLM
DonggeLiu Nov 13, 2024
e8e737e
More concise objective and instructions
DonggeLiu Nov 13, 2024
4adf624
Make code consistent
DonggeLiu Nov 13, 2024
ebf36c0
lower input token limit by system instruction token size
DonggeLiu Nov 13, 2024
b0d6d3c
Simplify system instruction
DonggeLiu Nov 13, 2024
8d1de5b
Consider previous text in the same prompt when truncating new text
DonggeLiu Nov 14, 2024
dc007b4
minor fix
DonggeLiu Nov 14, 2024
f474d16
ASK LLM do not compile
DonggeLiu Nov 14, 2024
6286ae5
Prioritize understanding over retrying
DonggeLiu Nov 14, 2024
122b0a5
Remove the compile command so that LLM cannot learn
DonggeLiu Nov 14, 2024
1ca97d3
Debug truncating prompt
DonggeLiu Nov 14, 2024
611b4b7
Reduce unnecessary logs
DonggeLiu Nov 14, 2024
588afc9
Fix bug to remove compile command from build result
DonggeLiu Nov 14, 2024
76d7975
Simpler debugging
DonggeLiu Nov 14, 2024
2b24d05
Fix bug in truncation
DonggeLiu Nov 14, 2024
610efda
Set log level
DonggeLiu Nov 14, 2024
b2be748
Fix truncation and be more strict on individual output size limit
DonggeLiu Nov 14, 2024
f189d18
Retry on VertexAI's InternalServerError
DonggeLiu Nov 14, 2024
38 changes: 28 additions & 10 deletions agent/base_agent.py
@@ -44,10 +44,10 @@ def get_tool(self, tool_name: str) -> Optional[BaseTool]:

def chat_llm(self, cur_round: int, client: Any, prompt: Prompt) -> str:
"""Chat with LLM."""
logger.info('<CHAT PROMPT:ROUND %02d>%s</CHAT PROMPT:ROUND %02d>',
logger.info('<CHAT PROMPT:ROUND %02d>\n%s\n</CHAT PROMPT:ROUND %02d>',
cur_round, prompt.get(), cur_round)
response = self.llm.chat_llm(client=client, prompt=prompt)
logger.info('<CHAT RESPONSE:ROUND %02d>%s</CHAT RESPONSE:ROUND %02d>',
logger.info('<CHAT RESPONSE:ROUND %02d>\n%s\n</CHAT RESPONSE:ROUND %02d>',
cur_round, response, cur_round)
return response

@@ -56,6 +56,11 @@ def _parse_tag(self, response: str, tag: str) -> str:
match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
return match.group(1).strip() if match else ''

def _parse_tags(self, response: str, tag: str) -> list[str]:
"""Parses the XML-style tags from LLM response."""
matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
return [content.strip() for content in matches]

def _filter_code(self, raw_code_block: str) -> str:
"""Filters out irrelevant lines from |raw_code_block|."""
# TODO(dongge): Move this function to a separate module.
@@ -67,11 +72,20 @@ def _filter_code(self, raw_code_block: str) -> str:
filtered_code_block = '\n'.join(filtered_lines)
return filtered_code_block

def _format_bash_execution_result(self, process: sp.CompletedProcess) -> str:
def _format_bash_execution_result(
self,
process: sp.CompletedProcess,
previous_prompt: Optional[Prompt] = None) -> str:
"""Formats a prompt based on bash execution result."""
stdout = self.llm.truncate_prompt(process.stdout)
if previous_prompt:
previous_prompt_text = previous_prompt.get()
else:
previous_prompt_text = ''
stdout = self.llm.truncate_prompt(process.stdout,
previous_prompt_text).strip()
# TODO(dongge) Share input limit evenly if both stdout and stderr overlong.
stderr = self.llm.truncate_prompt(process.stderr, stdout)
stderr = self.llm.truncate_prompt(process.stderr,
stdout + previous_prompt_text).strip()
return (f'<bash>\n{process.args}\n</bash>\n'
f'<return code>\n{process.returncode}\n</return code>\n'
f'<stdout>\n{stdout}\n</stdout>\n'
@@ -83,11 +97,15 @@ def _container_handle_bash_command(self, command: str,
prompt_text = self._format_bash_execution_result(tool.execute(command))
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])

def _container_handle_invalid_tool_usage(self, tool: BaseTool) -> Prompt:
"""Formats a prompt to re-teach LLM how to use the |tool|."""
prompt_text = (f'No valid instruction received, Please follow the '
f'interaction protocols:\n{tool.tutorial()}')
return DefaultTemplateBuilder(self.llm, None, initial=prompt_text).build([])
def _container_handle_bash_commands(self, response: str, tool: BaseTool,
prompt: Prompt) -> Prompt:
"""Handles the command from LLM with container |tool|."""
prompt_text = ''
for command in self._parse_tags(response, 'bash'):
prompt_text += self._format_bash_execution_result(
tool.execute(command), previous_prompt=prompt) + '\n'
prompt.append(prompt_text)
return prompt

def _sleep_random_duration(self, min_sec: int = 1, max_sec: int = 60) -> None:
"""Sleeps for a random duration between min_sec and max_sec. Agents uses
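
For context on the new multi-command protocol: _parse_tags collects every <bash>...</bash> block from a single LLM response, and _container_handle_bash_commands executes them in order, appending each formatted result to the next prompt. Below is a minimal, self-contained sketch of that parsing step; the sample response is made up for illustration and is not part of this PR.

import re

def parse_tags(response: str, tag: str) -> list[str]:
    """Collects the content of every <tag>...</tag> block, in order."""
    matches = re.findall(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
    return [content.strip() for content in matches]

# Hypothetical LLM response issuing two commands in one round.
sample_response = '''
<bash>
grep -rn "png_read_info" /src/libpng
</bash>
<bash>
cat /src/build.sh
</bash>
'''

for command in parse_tags(sample_response, 'bash'):
    print('would execute:', command)
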
96 changes: 61 additions & 35 deletions agent/prototyper.py
@@ -1,6 +1,7 @@
"""An LLM agent to generate a simple fuzz target prototype that can build.
Use it as a usual module locally, or as script in cloud builds.
"""
import re
import subprocess as sp
import time
from datetime import timedelta
@@ -10,7 +11,6 @@
from agent.base_agent import BaseAgent
from data_prep.project_context.context_introspector import ContextRetriever
from experiment.benchmark import Benchmark
from llm_toolkit.prompt_builder import EXAMPLES as EXAMPLE_FUZZ_TARGETS
from llm_toolkit.prompt_builder import (DefaultTemplateBuilder,
PrototyperTemplateBuilder)
from llm_toolkit.prompts import Prompt
@@ -33,38 +33,37 @@ def _initial_prompt(self, results: list[Result]) -> Prompt:
benchmark=benchmark,
)
prompt = prompt_builder.build(example_pair=[],
project_context_content=context_info,
tool_guides=self.inspect_tool.tutorial())
# prompt = prompt_builder.build(example_pair=EXAMPLE_FUZZ_TARGETS.get(
# benchmark.language, []),
# tool_guides=self.inspect_tool.tutorial())
project_context_content=context_info)
self.llm.system_instruction = prompt_builder.system_instructions(
benchmark, [
'prototyper-system-instruction-objective.txt',
'prototyper-system-instruction-protocols.txt'
])
self.protocol = self.llm.system_instruction[-1]
return prompt

def _update_fuzz_target_and_build_script(self, cur_round: int, response: str,
def _update_fuzz_target_and_build_script(self, cur_round: int,
conclusion: str,
build_result: BuildResult) -> None:
"""Updates fuzz target and build script in build_result with LLM response.
"""
fuzz_target_source = self._filter_code(
self._parse_tag(response, 'fuzz target'))
self._parse_tag(conclusion, 'fuzz target'))
build_result.fuzz_target_source = fuzz_target_source
if fuzz_target_source:
logger.debug('ROUND %02d Parsed fuzz target from LLM: %s', cur_round,
logger.debug('ROUND %02d Parsed fuzz target from LLM:\n%s', cur_round,
fuzz_target_source)
else:
logger.error('ROUND %02d No fuzz target source code in conclusion: %s',
cur_round, response)
logger.error('ROUND %02d No fuzz target source code in conclusion.',
cur_round)

build_script_source = self._filter_code(
self._parse_tag(response, 'build script'))
# Sometimes LLM adds chronos, which makes no sense for new build scripts.
build_result.build_script_source = build_script_source.replace(
'source /src/chronos.sh', '')
self._parse_tag(conclusion, 'build script'))
if build_script_source:
logger.debug('ROUND %02d Parsed build script from LLM: %s', cur_round,
logger.debug('ROUND %02d Parsed build script from LLM:\n%s', cur_round,
build_script_source)
else:
logger.debug('ROUND %02d No build script in conclusion: %s', cur_round,
response)
logger.debug('ROUND %02d No build script in conclusion.', cur_round)

def _update_build_result(self, build_result: BuildResult,
compile_process: sp.CompletedProcess, status: bool,
@@ -74,6 +73,11 @@ def _update_build_result(self, build_result: BuildResult,
build_result.compile_error = compile_process.stderr
build_result.compile_log = self._format_bash_execution_result(
compile_process)
# Remove the compile command, e.g., <bash>compile</bash>
build_result.compile_log = re.sub(r'<bash>.*?</bash>',
'',
build_result.compile_log,
flags=re.DOTALL)
build_result.is_function_referenced = referenced

def _validate_fuzz_target_and_build_script(self, cur_round: int,
@@ -159,22 +163,26 @@ def _validate_fuzz_target_and_build_script_via_compile(
status=compile_succeed and binary_exists,
referenced=function_referenced)

def _container_handle_conclusion(
self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
def _container_handle_conclusion(self, cur_round: int, response: str,
build_result: BuildResult,
prompt: Prompt) -> Optional[Prompt]:
"""Runs a compilation tool to validate the new fuzz target and build script
from LLM."""
conclusion = self._parse_tag(response, 'conclusion')
if not conclusion:
return prompt
logger.info('----- ROUND %02d Received conclusion -----', cur_round)

self._update_fuzz_target_and_build_script(cur_round, response, build_result)

self._validate_fuzz_target_and_build_script(cur_round, build_result)
if build_result.success:
logger.info('***** Prototyper succeded in %02d rounds *****', cur_round)
logger.info('***** Prototyper succeeded in %02d rounds *****', cur_round)
return None

if not build_result.compiles:
compile_log = self.llm.truncate_prompt(build_result.compile_log)
compile_log = self.llm.truncate_prompt(build_result.compile_log,
extra_text=prompt.get()).strip()
logger.info('***** Failed to recompile in %02d rounds *****', cur_round)
prompt_text = (
'Failed to build fuzz target. Here is the fuzz target, build script, '
@@ -198,7 +206,7 @@ def _container_handle_conclusion(
'target function. We can increase its complexity later, but first try'
'to make it compile successfully.'
'If an error happens repeatedly and cannot be fixed, try to '
'mitigate it. For example, replace or remove the line.'
'mitigate it. For example, replace or remove the line.\n'
f'<fuzz target>\n{build_result.fuzz_target_source}\n</fuzz target>\n'
f'<build script>\n{build_result.build_script_source}\n</build script>'
f'\n<compilation log>\n{compile_log}\n</compilation log>\n')
@@ -215,23 +223,41 @@ def _container_handle_conclusion(
else:
prompt_text = ''

prompt = DefaultTemplateBuilder(self.llm, initial=prompt_text).build([])
prompt.append(prompt_text)
return prompt

def _container_handle_invalid_tool_usage(self, cur_round: int, response: str,
prompt: Prompt) -> Prompt:
"""Formats a prompt to re-teach LLM how to use the |tool|."""
logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
response)
prompt_text = (f'No valid instruction received, Please follow the system '
f'instructions:\n{self.protocol}')
prompt.append(prompt_text)
return prompt

def _container_tool_reaction(self, cur_round: int, response: str,
build_result: BuildResult) -> Optional[Prompt]:
"""Validates LLM conclusion or executes its command."""
# Prioritize Bash instructions.
if command := self._parse_tag(response, 'bash'):
return self._container_handle_bash_command(command, self.inspect_tool)
prompt = DefaultTemplateBuilder(self.llm, None).build([])

if self._parse_tag(response, 'conclusion'):
return self._container_handle_conclusion(cur_round, response,
build_result)
# Other responses are invalid.
logger.warning('ROUND %02d Invalid response from LLM: %s', cur_round,
response)
return self._container_handle_invalid_tool_usage(self.inspect_tool)
# First execute bash commands.
prompt = self._container_handle_bash_commands(response, self.inspect_tool,
prompt)

# Then build fuzz target.
prompt = self._container_handle_conclusion(cur_round, response,
build_result, prompt)
if prompt is None:
# Succeeded.
return None

# Finally check invalid responses.
if not prompt.get():
prompt = self._container_handle_invalid_tool_usage(
cur_round, response, prompt)

return prompt

def execute(self, result_history: list[Result]) -> BuildResult:
"""Executes the agent based on previous result."""
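
To make the conclusion protocol concrete: a round that produces code ends with a <conclusion> block wrapping the fuzz target and, optionally, a build script, which _parse_tag and _filter_code then extract. The sketch below is self-contained and uses a made-up conclusion; the real content comes from the LLM and depends on the benchmark.

import re

def parse_tag(response: str, tag: str) -> str:
    """Returns the content of the first <tag>...</tag> block, or ''."""
    match = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
    return match.group(1).strip() if match else ''

# Hypothetical conclusion for illustration only.
conclusion = '''
<conclusion>
<fuzz target>
#include <stdint.h>
#include <stddef.h>
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  return 0;
}
</fuzz target>
<build script>
$CXX $CXXFLAGS $SRC/fuzz_target.cc -o $OUT/fuzz_target $LIB_FUZZING_ENGINE
</build script>
</conclusion>
'''

fuzz_target_source = parse_tag(conclusion, 'fuzz target')
build_script_source = parse_tag(conclusion, 'build script')
print(fuzz_target_source)
print(build_script_source)
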
29 changes: 24 additions & 5 deletions llm_toolkit/models.py
@@ -31,7 +31,8 @@
import openai
import tiktoken
import vertexai
from google.api_core.exceptions import (GoogleAPICallError, InvalidArgument,
from google.api_core.exceptions import (GoogleAPICallError,
InternalServerError, InvalidArgument,
ResourceExhausted)
from vertexai import generative_models
from vertexai.preview.generative_models import ChatSession, GenerativeModel
@@ -59,6 +60,7 @@ class LLM:
MAX_INPUT_TOKEN: int = sys.maxsize

_max_attempts = 5 # Maximum number of attempts to get prediction response
system_instruction: Optional[list] = None

def __init__(
self,
@@ -564,7 +566,8 @@ class GeminiModel(VertexAIModel):
]

def get_model(self) -> Any:
return GenerativeModel(self._vertex_ai_model)
return GenerativeModel(self._vertex_ai_model,
system_instruction=self.system_instruction)

def do_generate(self, model: Any, prompt: str, config: dict[str, Any]) -> Any:
# Loosen inapplicable restrictions just in case.
@@ -648,6 +651,7 @@ def get_chat_client(self, model: GenerativeModel) -> Any:
InvalidArgument,
ValueError, # TODO(dongge): Handle RECITATION specifically.
IndexError, # A known error from vertexai.
InternalServerError, # A known error from vertexai.
],
other_exceptions={ResourceExhausted: 100})
def _do_generate(self, client: ChatSession, prompt: str,
@@ -664,26 +668,41 @@ def truncate_prompt(self,
raw_prompt_text: Any,
extra_text: Any = None) -> Any:
"""Truncates the prompt text to fit in MAX_INPUT_TOKEN."""
original_token_count = self.estimate_token_num(raw_prompt_text)
if self.system_instruction:
system_instructions = ''.join(self.system_instruction)
else:
system_instructions = ''
original_token_count = self.estimate_token_num(raw_prompt_text +
system_instructions)

token_count = original_token_count
logger.warning('original_token_count: %s', original_token_count)
logger.warning('self.MAX_INPUT_TOKEN: %s', self.MAX_INPUT_TOKEN)
if token_count > self.MAX_INPUT_TOKEN:
raw_prompt_text = raw_prompt_text[-3 * self.MAX_INPUT_TOKEN:]

logger.warning('raw_prompt_text: %s', raw_prompt_text)
extra_text_token_count = self.estimate_token_num(extra_text)
# Reserve 10000 tokens for raw prompt wrappers.
max_raw_prompt_token_size = (self.MAX_INPUT_TOKEN - extra_text_token_count -
10000)
logger.warning('max_raw_prompt_token_size: %s', max_raw_prompt_token_size)

while token_count > max_raw_prompt_token_size:
while token_count > max_raw_prompt_token_size // 4:
estimate_truncate_size = int(
(1 - max_raw_prompt_token_size / token_count) * len(raw_prompt_text))
raw_prompt_text = raw_prompt_text[estimate_truncate_size + 1:]

logger.warning('estimate_truncate_size: %s', estimate_truncate_size)
raw_prompt_init = raw_prompt_text[:100] + (
'\n...(truncated due to exceeding input token limit)...\n')
raw_prompt_text = raw_prompt_init + raw_prompt_text[
100 + estimate_truncate_size + 1:]

token_count = self.estimate_token_num(raw_prompt_text)
logger.warning('Truncated raw prompt from %d to %d tokens:',
original_token_count, token_count)

logger.warning('Final token_count: %s', token_count)
return raw_prompt_text

def chat_llm(self, client: ChatSession, prompt: prompts.Prompt) -> str:
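
The revised truncate_prompt keeps the first 100 characters of the raw text, inserts a visible marker, and repeatedly drops text after that prefix until the token estimate fits the budget left after the system instructions and any earlier prompt text. Below is a standalone simplification of that idea, with a crude characters-per-token heuristic standing in for the model's real token counter; it is a sketch, not the implementation in this PR.

TRUNCATION_MARKER = '\n...(truncated due to exceeding input token limit)...\n'

def estimate_tokens(text: str) -> int:
    # Crude stand-in for estimate_token_num: roughly 4 characters per token.
    return max(1, len(text) // 4)

def truncate_middle(text: str, max_tokens: int) -> str:
    """Keeps a 100-character head and drops text after it until the budget fits."""
    while estimate_tokens(text) > max_tokens:
        excess_ratio = 1 - max_tokens / estimate_tokens(text)
        # Drop at least a bit more than the marker we re-insert, so the text
        # strictly shrinks on every iteration.
        drop = max(int(excess_ratio * len(text)), len(TRUNCATION_MARKER) + 256)
        text = text[:100] + TRUNCATION_MARKER + text[100 + drop:]
    return text

sample = 'stdout line\n' * 20_000
print(estimate_tokens(sample), '->', estimate_tokens(truncate_middle(sample, 2_000)))
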
24 changes: 24 additions & 0 deletions llm_toolkit/prompt_builder.py
@@ -581,6 +581,30 @@ def build(self,
self._prompt.append(tool_guides)
return self._prompt

def _get_instruction(self, benchmark: Benchmark,
instruction_name: str) -> str:
"""Gets the agent objective."""
if not instruction_name:
return ''
template_file = self._find_template(self.agent_templare_dir,
instruction_name)
instruction = self._get_template(template_file).replace(
'{LANGUAGE}', benchmark.file_type.value).replace(
'{FUNCTION-UNDER-TEST}',
benchmark.function_signature).replace('{FUZZ_TARGET_PATH}',
benchmark.target_path)
return instruction

def system_instructions(self, benchmark: Benchmark,
instruction_names: list[str]) -> list[str]:
"""Constructs a list of system instructions in plain text."""
instructions = []
for instruction_name in instruction_names:
instruction = self._get_instruction(benchmark, instruction_name)
if instruction:
instructions.append(instruction)
return instructions


class DefaultJvmTemplateBuilder(PromptBuilder):
"""Default builder for JVM projects."""
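
The new system-instruction templates are plain-text files whose {LANGUAGE}, {FUNCTION-UNDER-TEST}, and {FUZZ_TARGET_PATH} placeholders are filled from the benchmark, as _get_instruction does above. The sketch below shows that substitution with a made-up template string and made-up benchmark values; the real templates are the prototyper-system-instruction-*.txt files in the agent template directory.

# Hypothetical template text for illustration only.
template = ('You are writing a {LANGUAGE} fuzz target for {FUNCTION-UNDER-TEST}. '
            'Save it to {FUZZ_TARGET_PATH}.')

def render_instruction(template: str, language: str, function_signature: str,
                       target_path: str) -> str:
    """Fills the placeholders the same way _get_instruction does."""
    return (template.replace('{LANGUAGE}', language)
            .replace('{FUNCTION-UNDER-TEST}', function_signature)
            .replace('{FUZZ_TARGET_PATH}', target_path))

print(render_instruction(template, 'C++',
                         'int png_read_info(png_structp, png_infop)',
                         '/src/libpng/png_read_fuzzer.cc'))
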
1 change: 1 addition & 0 deletions llm_toolkit/prompts.py
@@ -23,6 +23,7 @@
class Prompt:
"""Base prompt."""

@abstractmethod
def __init__(self, initial=None):
"""Constructor."""

1 change: 1 addition & 0 deletions logger.py
@@ -132,6 +132,7 @@ def get_trial_logger(name: str = __name__,
return _trial_logger

logger = logging.getLogger(name)
logger.setLevel(level)
if not logger.handlers:
formatter = logging.Formatter(
fmt=('%(asctime)s [Trial ID: %(trial)02d] %(levelname)s '