diff --git a/programmer-ui/ui.py b/programmer-ui/ui.py index e09d7ea..4dcb55b 100644 --- a/programmer-ui/ui.py +++ b/programmer-ui/ui.py @@ -6,10 +6,17 @@ import streamlit as st import weave import os +import openai +import copy from weave.trace.weave_client import WeaveClient from programmer.weave_next.api import init_local_client -from programmer.weave_next.weave_query import calls, expand_refs +from programmer.weave_next.weave_query import ( + calls, + expand_refs, + get_call, + expand_json_refs, +) from programmer.settings_manager import SettingsManager st.set_page_config(layout="wide") @@ -47,7 +54,46 @@ def init_from_settings() -> WeaveClient: raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}") -client = init_from_settings() +# Add sidebar for Weave project configuration +with st.sidebar: + st.header("Weave Project Configuration") + + # Initialize from settings + initial_weave_logging = SettingsManager.get_setting("weave_logging") + initial_project_type = "local" if initial_weave_logging == "local" else "cloud" + initial_project_path = ( + os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db") + if initial_weave_logging == "local" + else "" + ) + initial_project_name = ( + f"programmer-{os.path.basename(os.path.abspath(os.curdir))}" + if initial_weave_logging == "cloud" + else "" + ) + + project_type = st.radio( + "Project Type", + ["local", "cloud"], + index=0 if initial_project_type == "local" else 1, + ) + + if project_type == "local": + project_path = st.text_input("Local DB Path", value=initial_project_path) + # SettingsManager.set_setting("weave_logging", "local") + # SettingsManager.set_setting("weave_db_path", project_path) + client = init_local_weave(project_path) + print("C2", client._project_id()) + else: + # SettingsManager.set_setting("weave_logging", "cloud") + # SettingsManager.set_setting("weave_project_name", project_name) + project_name = st.text_input("Cloud Project Name", value=initial_project_name) + client = 
init_remote_weave(project_name) + print("C3", client._project_id()) + +# Initialize client based on current settings +# client = init_from_settings() +print("CLIENT", client._project_id()) def set_focus_step_id(call_id): @@ -76,6 +122,16 @@ def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]): return expand_refs(wc, refs).to_pandas() +@st.cache_data(hash_funcs=ST_HASH_FUNCS) +def cached_get_call(wc: WeaveClient, call_id: str): + return get_call(wc, call_id) + + +@st.cache_data(hash_funcs=ST_HASH_FUNCS) +def cached_expand_json_refs(wc: WeaveClient, json: dict): + return expand_json_refs(wc, json) + + def print_step_call(call): start_history = call["inputs.state.history"] end_history = call["output.history"] @@ -174,30 +230,234 @@ def print_session_call(session_id): ) -session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"]) -if len(session_calls_df) == 0: - st.error("No programmer sessions found.") - st.stop() -session_user_message_df = session_calls_df["inputs.agent_state.history"].apply( - lambda v: v[-1]["content"] -) - +def sessions_page(): + session_calls_df = cached_calls( + client, "session", expand_refs=["inputs.agent_state"] + ) + if len(session_calls_df) == 0: + st.error("No programmer sessions found.") + st.stop() + session_user_message_df = session_calls_df["inputs.agent_state.history"].apply( + lambda v: v[-1]["content"] + ) + with st.sidebar: + st.header("Session Selection") + if st.button("Refresh"): + st.cache_data.clear() + st.rerun() + message_ids = { + f"{cid[-5:]}: {m}": cid + for cid, m in reversed( + list(zip(session_calls_df["id"], session_user_message_df)) + ) + } + sel_message = st.radio("Session", options=message_ids.keys()) + sel_id = None + if sel_message: + sel_id = message_ids.get(sel_message) + if sel_id: + st.header(f"Session: {sel_id}") + print_session_call(sel_id) + + +sessions_pg = st.Page(sessions_page, title="Sessions") + + +# def write_chat_message(m, key): +# with 
def attach_tool_call_responses(messages):
    """Fold tool responses into the assistant tool calls that requested them.

    Builds a new list of deep-copied messages for display:

    * every returned message carries ``original_index``, its position in
      ``messages``;
    * every assistant message that has a ``tool_calls`` key gets a
      ``tool_call_responses`` key (always an empty list here) and, for each
      tool call whose ``id`` equals some message's ``tool_call_id``, a
      ``response`` entry holding a copy of the first such message plus that
      message's own ``original_index``;
    * messages that carry a ``tool_call_id`` key are left out of the result,
      because they are reachable through the ``response`` entries.

    Args:
        messages: Sequence of chat message dicts. Each message must have a
            ``role`` key and each tool call an ``id`` key.

    Returns:
        A new list of annotated message dicts. ``messages`` and the dicts it
        contains are never modified.

    Raises:
        KeyError: If a message has no ``role`` or a tool call has no ``id``.
    """
    # Index the first response for every tool call id once, instead of
    # rescanning the whole conversation for each tool call.
    first_response_index = {}
    for idx, message in enumerate(messages):
        call_id = message.get("tool_call_id")
        if call_id is not None and call_id not in first_response_index:
            first_response_index[call_id] = idx

    new_messages = []
    for idx, message in enumerate(messages):
        new_m = copy.deepcopy(message)
        new_m["original_index"] = idx
        if new_m["role"] == "assistant" and "tool_calls" in new_m:
            new_m["tool_call_responses"] = []
            # ``tool_calls`` may be present but None in dumped messages.
            for tool_call in new_m["tool_calls"] or []:
                response_idx = first_response_index.get(tool_call["id"])
                if response_idx is None:
                    continue
                # Copy before annotating: attaching the caller's dict directly
                # would write ``original_index`` into the input messages,
                # which are later reused as the chat completion payload.
                response = copy.deepcopy(messages[response_idx])
                response["original_index"] = response_idx
                tool_call["response"] = response
        if "tool_call_id" not in new_m:
            new_messages.append(new_m)
    return new_messages
st.session_state.playground_state["editable_call"] = copy.deepcopy( + playground_state["expanded_call"] + ) + st.rerun() + + if call_id != st.session_state.playground_state["call_id"]: + st.spinner("Loading call...") + call = cached_get_call(client, call_id) + editable_call = cached_expand_json_refs(client, call) + st.session_state.playground_state = { + "call_id": call_id, + "call": call, + "expanded_call": editable_call, + "editable_call": copy.deepcopy(editable_call), + } + st.rerun() + + call = st.session_state.playground_state["call"] + editable_call = st.session_state.playground_state["editable_call"] + if call is None or editable_call is None: + st.warning("call not yet loaded") + st.stop() + + st.write(call["op_name"]) + # st.json(call["inputs"]) + # st.json(call["inputs"]["tools"]) + + def on_change_temperature(): + st.session_state.playground_state["editable_call"]["inputs"][ + "temperature" + ] = st.session_state["temperature"] + + st.slider( + "Temperature", + min_value=0.0, + max_value=1.0, + value=editable_call["inputs"]["temperature"], + key="temperature", + on_change=on_change_temperature, + ) -with st.sidebar: - if st.button("Refresh"): - st.cache_data.clear() - st.rerun() - message_ids = { - f"{cid[-5:]}: {m}": cid - for cid, m in reversed( - list(zip(session_calls_df["id"], session_user_message_df)) + tools = call["inputs"].get("tools", []) + if tools: + st.write("Tools") + for tool_idx, t in enumerate(tools): + with st.expander(t["function"]["name"]): + + def on_change_tool(): + st.session_state.playground_state["editable_call"]["inputs"][ + "tools" + ][tool_idx] = json.loads(st.session_state[f"tool-{tool_idx}"]) + st.rerun() + + st.text_area( + "json", + value=json.dumps(t, indent=2), + height=300, + key=f"tool-{tool_idx}", + on_change=on_change_tool, + ) + + def on_change_parallel_tool_calls(): + st.session_state.playground_state["editable_call"]["inputs"][ + "parallel_tool_calls" + ] = st.session_state["parallel_tool_calls"] + + st.checkbox( 
+ "Parallel tool calls", + value=editable_call["inputs"].get("parallel_tool_calls", True), + key="parallel_tool_calls", + on_change=on_change_parallel_tool_calls, ) + + inputs = editable_call["inputs"] + all_input_messages = inputs["messages"] + other_inputs = { + k: v + for k, v in inputs.items() + if (k != "messages" and k != "self" and k != "stream") } - sel_message = st.radio("Session", options=message_ids.keys()) - sel_id = None - if sel_message: - sel_id = message_ids.get(sel_message) - -if sel_id: - st.header(f"Session: {sel_id}") - print_session_call(sel_id) + + tool_call_attached_messages = attach_tool_call_responses(all_input_messages) + for i, m in enumerate(tool_call_attached_messages): + write_chat_message(m, f"message-{i}") + # output = editable_call["output"]["choices"][0]["message"] + n_choices = st.number_input( + "Number of choices", value=1, min_value=1, max_value=100 + ) + if st.button("Generate"): + chat_inputs = {**editable_call["inputs"]} + # st.json(chat_inputs, expanded=False) + del chat_inputs["stream"] + del chat_inputs["self"] + chat_inputs["n"] = n_choices + call_resp = openai.chat.completions.create(**chat_inputs).model_dump() + + editable_call["output"] = call_resp + st.rerun() + # st.json(response, expanded=False) + # output = response["choices"][0]["message"] + # st.json(output) + response = editable_call["output"] + st.write("full response") + st.json(response, expanded=False) + st.write("**system fingerprint**", response["system_fingerprint"]) + st.write("**usage**", response["usage"]) + for i, choice in enumerate(response["choices"]): + output = choice["message"] + st.write(f"Choice {i+1}") + write_chat_message(output, f"output_message-{i}", readonly=True) + + # all_messages = [*all_input_messages, output] + # st.json(st.session_state.playground_state, expanded=False) + # st.json(all_messages, expanded=False) + + # st.write(expanded_call) + + +playground_pg = st.Page(playground_page, title="Playground") + + +pg = 
st.navigation([sessions_pg, playground_pg]) +pg.run() diff --git a/programmer/agent.py b/programmer/agent.py index cd180ac..7c026c9 100644 --- a/programmer/agent.py +++ b/programmer/agent.py @@ -39,6 +39,12 @@ class AgentState(weave.Object): history: list[Any] = Field(default_factory=list) env_snapshot_key: Optional[EnvironmentSnapshotKey] = None + def with_history(self, history: list[Any]) -> "AgentState": + environment = get_current_environment() + msg = get_commit_message(history) + snapshot_key = environment.make_snapshot(msg) + return self.__class__(history=history, env_snapshot_key=snapshot_key) + def unweavify(v: Any) -> Any: if isinstance(v, list): @@ -55,6 +61,9 @@ class Agent(weave.Object): system_message: str tools: list[Any] = Field(default_factory=list) + def initial_state(self, history: list[Any]) -> AgentState: + return AgentState().with_history(history) + @weave.op() def step(self, state: AgentState) -> AgentState: """Run a step of the agent. @@ -118,12 +127,8 @@ def step(self, state: AgentState) -> AgentState: # new_history = state.history + new_messages new_history = weavelist_add(state.history, new_messages) - msg = get_commit_message(new_history) - - environment = get_current_environment() - snapshot_key = environment.make_snapshot(msg) - return AgentState(history=new_history, env_snapshot_key=snapshot_key) + return state.with_history(new_history) @weave.op() def run(self, state: AgentState, max_runtime_seconds: int = -1): diff --git a/programmer/agent_texteditor.py b/programmer/agent_texteditor.py new file mode 100644 index 0000000..599ca43 --- /dev/null +++ b/programmer/agent_texteditor.py @@ -0,0 +1,165 @@ +from typing import Any, Union +from pydantic import Field +import litellm +from openai.types.chat import ( + ChatCompletionMessageParam, +) + +import weave +from weave.trace.vals import WeaveList +from weave.flow.chat_util import OpenAIStream + +from .console import Console +from .tool_calling import chat_call_tool_params, 
def unweavify(v: Any) -> Any:
    """Recursively rebuild ``v`` out of plain ``list`` and ``dict`` objects.

    Any ``list`` instance (including subclasses) is rebuilt as a builtin
    ``list`` and any ``dict`` instance (including subclasses) as a builtin
    ``dict``, with their elements and values converted the same way. Dict
    keys are kept as they are. Every other value is returned unchanged,
    without being copied.

    Args:
        v: Arbitrarily nested value to convert.

    Returns:
        A structure with the same contents whose containers are builtin
        lists and dicts.
    """
    if isinstance(v, list):
        return [unweavify(item) for item in v]
    if isinstance(v, dict):
        # Distinct loop names: the comprehension must not shadow ``v`` itself.
        return {key: unweavify(value) for key, value in v.items()}
    return v
+ """ + Console.step_start("agent", "green") + # Printing this is ugly + # ref = weave.obj_ref(state) + # if ref: + # print("state ref:", ref.uri()) + + messages: list[ChatCompletionMessageParam] = [ + {"role": "system", "content": self.system_message}, + ] + open_file_info = state.text_editor_state.get_open_file_info() + + messages.append( + { + "role": "system", + "content": open_file_info.format_for_messages(), + } + ) + + messages += state.history + + messages.append( + { + "role": "system", + "content": open_file_info.format_for_messages(), + } + ) + + self_tools = [*self.tools] or [] + + text_editor_stateful = TextEditorStateful( + self.text_editor, state.text_editor_state + ) + + # self_tools += [open_file, close_file_range, replace_file_lines] + self_tools += [open_file, replace_file_lines] + + # make type checkers happy by passing NotGiven instead of None + tools = None + if self_tools: + tools = chat_call_tool_params(self_tools) + + Console.chat_response_start() + + # Workaround a weave bug, litellm tries to deepcopy messages which has + # a TraceDict. TraceDict is not pickable, because it has a reference to + # a weave server, which has a lock. + messages = unweavify(messages) + + stream = litellm.completion( + model=self.model_name, + temperature=self.temperature, + messages=messages, + tools=tools, + stream=True, + timeout=60, + parallel_tool_calls=self.parallel_tool_calls, + ) + wrapped_stream = OpenAIStream(stream) # type: ignore + for chunk in wrapped_stream: + if chunk.choices[0].delta.content: + Console.chat_message_content_delta(chunk.choices[0].delta.content) + + response = wrapped_stream.final_response() + response_message = response.choices[0].message + if response_message.content: + Console.chat_response_complete(response_message.content) + + new_messages = [] + # we always store the dict representations of messages in agent state + # instead of mixing in some pydantic objects. 
+ new_messages.append(response_message.model_dump(exclude_none=True)) + if response_message.tool_calls: + with text_editor(text_editor_stateful): + new_messages.extend( + perform_tool_calls(self_tools, response_message.tool_calls) + ) + new_history = weavelist_add(state.history, new_messages) + + next_state = state.with_history(new_history) + next_state = next_state.with_texteditor_state(text_editor_stateful.state) + return next_state diff --git a/programmer/config.py b/programmer/config.py index 9b7cd56..fb81ea8 100644 --- a/programmer/config.py +++ b/programmer/config.py @@ -15,6 +15,8 @@ splice_lines_in_file, ) from .agent import Agent +from .agent_texteditor import AgentTextEditor +from .text_editor import TextEditor agent_4o_basic = Agent( name="gpt-4o-2024-08-06_basic", @@ -96,3 +98,32 @@ splice_lines_in_file, ], ) + +text_editor = TextEditor(max_open_size=15000, open_chunk_size=2000) +agent_texteditor_4o_basic = AgentTextEditor( + name="gpt-4o-2024-08-06_texteditor_basic", + model_name="gpt-4o-2024-08-06", + temperature=0.7, + system_message=SYSTEM_MESSAGE, + text_editor=text_editor, + tools=[list_files, run_command, view_image], +) + +agent_texteditor_4o_basic_temp0 = AgentTextEditor( + name="gpt-4o-2024-08-06_texteditor_basic_temp0", + model_name="gpt-4o-2024-08-06", + temperature=0.0, + system_message=SYSTEM_MESSAGE, + text_editor=text_editor, + tools=[list_files, run_command, view_image], +) + +agent_texteditor_4o_basic_noparalleltc = AgentTextEditor( + name="gpt-4o-2024-08-06_texteditor_basic_noparalleltc", + model_name="gpt-4o-2024-08-06", + temperature=0.7, + system_message=SYSTEM_MESSAGE, + text_editor=text_editor, + tools=[list_files, run_command, view_image], + parallel_tool_calls=False, +) diff --git a/programmer/evals/eval_repeated_edits.py b/programmer/evals/eval_repeated_edits.py index 926712c..0797713 100644 --- a/programmer/evals/eval_repeated_edits.py +++ b/programmer/evals/eval_repeated_edits.py @@ -10,7 +10,7 @@ from ..agent import 
AgentState, Agent from ..config import * -from ..tools import tool_context, LocalToolContext, get_current_context +from ..io_context import LocalIOContext, io_context, get_io_context # NOTES # - Try with other LLM and tool configs now that I have this test @@ -20,7 +20,7 @@ @contextmanager def tempdir(): with tempfile.TemporaryDirectory() as dir_: - with tool_context(LocalToolContext(dir_)) as tc: + with io_context(LocalIOContext(dir_)) as tc: yield tc @@ -56,7 +56,7 @@ def eval_edit_memory( f.write(prev_file_contents) task_correct = False - state = AgentState() + state = agent.initial_state(history=[]) def step6_insert_ampersands(lines): new_lines = [] @@ -156,8 +156,8 @@ def run_task( if call: call.set_display_name(f"Task{task_idx}: {task_name}") print(f"*** TASK: {task_idx}, {prompt}") - state = AgentState( - history=state.history + state = state.with_history( + state.history + [ { "role": "user", @@ -168,7 +168,7 @@ def run_task( task_info = {"task_idx": task_idx} task_correct = False attempts = [] - for attempt_idx in range(5): + for attempt_idx in range(2): attempt_result = run_attempt(config, agent, state, expected_lines, attempt_idx) attempt_info = attempt_result["attempt_info"] state = attempt_result["state"] @@ -181,8 +181,8 @@ def run_task( print() print(f"*** FAILED ATTEMPT Task: {task_idx} Attempt: {attempt_idx}") print() - state = AgentState( - history=state.history + state = state.with_history( + state.history + [ { "role": "user", @@ -214,7 +214,7 @@ def run_attempt( call = weave.get_current_call() if call: call.set_display_name(f"Attempt{attempt_idx}") - ctx = get_current_context() + ctx = get_io_context() attempt_info: dict = { "attempt_idx": attempt_idx, "correct": False, @@ -266,7 +266,7 @@ def mismatch_details(lines, file_lines): error_details.append("Incorrect edit") error_details.append("file.txt\texpected") error_details.append(f"len={len(file_lines)}\tlen={len(lines)}") - for i in range(len(lines)): + for i in range(len(max(lines, 
file_lines))): try: file_lines_i = file_lines[i] except IndexError: @@ -327,18 +327,21 @@ def run_single_trial(trial_idx: int): if __name__ == "__main__": weave.init("programmerdev-eval-edits1") agents = [ - agent_4omini_basic, - agent_4o_basic, - agent_claude_basic, - agent_4o_replace, - agent_claude_replace, - agent_4o_splice, - agent_claude_splice, + # agent_4omini_basic, + # agent_4o_basic, + # agent_claude_basic, + # agent_4o_replace, + # agent_claude_replace, + # agent_4o_splice, + # agent_claude_splice, + # agent_texteditor_4o_basic, + # agent_texteditor_4o_basic_temp0, + agent_texteditor_4o_basic_noparalleltc, ] - config = EvalEditMemoryConfig(n_lines=100, run_timeout_seconds=60) - n_trials = 5 - config_s = f'bugfix_promptfix_{config["n_lines"]}lines_{config["run_timeout_seconds"]}timeout' + config = EvalEditMemoryConfig(n_lines=1000, run_timeout_seconds=60) + n_trials = 10 + config_s = f'{config["n_lines"]}lines_{config["run_timeout_seconds"]}timeout' results = {} for agent in agents: run_name = f"{agent.name}_{config_s}" diff --git a/programmer/file_protocol.py b/programmer/file_protocol.py new file mode 100644 index 0000000..6bb17e5 --- /dev/null +++ b/programmer/file_protocol.py @@ -0,0 +1,6 @@ +from typing import Protocol + + +class FileSystem(Protocol): + def write_file(self, path: str, content: str) -> None: ... + def read_file(self, path: str) -> str: ... diff --git a/programmer/io_context.py b/programmer/io_context.py new file mode 100644 index 0000000..213267b --- /dev/null +++ b/programmer/io_context.py @@ -0,0 +1,163 @@ +from typing import Protocol, TypedDict +import os +import subprocess +import requests +import shlex +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Optional, Union + + +class RunCommandResult(TypedDict): + exit_code: int + output: str + + +class IOContext(Protocol): + def write_file(self, path: str, content: str) -> None: ... + + def read_file(self, path: str) -> str: ... 
class LocalIOContext(IOContext):
    """IOContext that works on the local filesystem under one directory.

    Relative paths are resolved against ``directory`` and commands are run
    with ``directory`` as their working directory.
    """

    def __init__(self, directory):
        # Stored as an absolute path so later changes of the process working
        # directory do not move the context.
        self.directory = os.path.abspath(directory)

    def write_file(self, path: str, content: str) -> None:
        """Write ``content`` to ``path``, replacing any existing file.

        Text is encoded as UTF-8 explicitly, so the result does not depend
        on the platform's locale encoding.
        """
        full_path = self.resolve_path(path)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(content)

    def read_file(self, path: str) -> str:
        """Return the full text of ``path``, decoded as UTF-8."""
        full_path = self.resolve_path(path)
        with open(full_path, "r", encoding="utf-8") as f:
            return f.read()

    def run_command(self, command: str) -> RunCommandResult:
        """Run ``command`` through the shell inside the context directory.

        stderr is merged into stdout. The returned ``output`` is that
        combined text with surrounding whitespace stripped; ``exit_code`` is
        the process return code.
        """
        # NOTE(review): shell=True runs the command string as given; callers
        # must only pass commands they are willing to execute on this machine.
        completed_process = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Decode explicitly and replace undecodable bytes, so a command
            # that prints binary data cannot raise UnicodeDecodeError here.
            encoding="utf-8",
            errors="replace",
            shell=True,
            cwd=self.directory,
        )
        return {
            "exit_code": completed_process.returncode,
            "output": completed_process.stdout.strip(),
        }

    def resolve_path(self, path: str) -> str:
        """Join ``path`` onto the context directory.

        Follows ``os.path.join`` semantics: an absolute ``path`` is returned
        as is, and the context directory is ignored.
        """
        return os.path.join(self.directory, path)
content: str) -> None: + full_path = os.path.join(self.directory, path) + response = requests.post( + f"{self.base_url}/container/write_file", + json={ + "container_id": self.container_id, + "file_path": full_path, + "file_content": content, + }, + ) + if response.status_code != 200: + raise Exception(f"Failed to write file: {response.text}") + + def read_file(self, path: str) -> str: + full_path = os.path.join(self.directory, path) + response = requests.post( + f"{self.base_url}/container/read_file", + json={"container_id": self.container_id, "file_path": full_path}, + ) + if response.status_code == 200: + return response.json().get("file_content") + else: + raise Exception(f"Failed to read file: {response.text}") + + def run_command(self, command: str) -> RunCommandResult: + command = self.command_prefix + command + command = f"bash -c {shlex.quote(command)}" + response = requests.post( + f"{self.base_url}/container/run", + json={ + "container_id": self.container_id, + "workdir": self.directory, + "command": command, + }, + ) + if response.status_code == 200: + json = response.json() + return { + "exit_code": json["exit_code"], + "output": json["output"], + } + else: + raise Exception(f"Failed to run command: {response.text}") + + def resolve_path(self, path: str) -> str: + return path # For remote containers, we assume paths are already resolved + + +# Create a ContextVar to store the current ToolContext +_io_context: ContextVar[Optional[Union[LocalIOContext, RemoteContainerIOContext]]] = ( + ContextVar("_io_context", default=None) +) + + +@contextmanager +def io_context(context: Union[LocalIOContext, RemoteContainerIOContext]): + token = _io_context.set(context) + try: + yield context + finally: + _io_context.reset(token) + + +def get_io_context() -> Union[LocalIOContext, RemoteContainerIOContext]: + context = _io_context.get() + if context is None: + return LocalIOContext(".") + return context diff --git a/programmer/programmer.py b/programmer/programmer.py 
index 59352fe..403d1d6 100644 --- a/programmer/programmer.py +++ b/programmer/programmer.py @@ -14,6 +14,7 @@ from .console import Console from .config import ( agent_4o_replace, + agent_texteditor_4o_basic, ) from .environment import ( environment_session, @@ -27,6 +28,8 @@ from .git import GitRepo +agent = agent_texteditor_4o_basic + @weave.op def get_user_input(): @@ -41,18 +44,13 @@ def user_input_step(state: AgentState) -> AgentState: # if ref: # print("state ref:", ref.uri()) user_input = get_user_input() - environment = get_current_environment() history = state.history + [ { "role": "user", "content": user_input, } ] - msg = get_commit_message(history) - return AgentState( - history=history, - env_snapshot_key=environment.make_snapshot(msg), - ) + return state.with_history(history) def make_environment(): @@ -74,12 +72,9 @@ def session(agent: Agent, agent_state: AgentState): session_id = call.id env = make_environment() - msg = get_commit_message(agent_state.history) with environment_session(env, session_id): - agent_state = AgentState( - history=agent_state.history, env_snapshot_key=env.make_snapshot(msg) - ) + agent_state = agent_state.with_history(agent_state.history) while True: result = agent.run(agent_state) agent_state = result["state"] @@ -155,16 +150,16 @@ def programmer(): else: initial_prompt = input("Initial prompt: ") - state = AgentState( - history=[ + state = agent.initial_state( + [ { "role": "user", "content": initial_prompt, }, - ], + ] ) - session(agent_4o_replace, state) + session(agent_texteditor_4o_basic, state) def main(): diff --git a/programmer/swebench/score.py b/programmer/swebench/score.py index e9cd799..2753b7c 100644 --- a/programmer/swebench/score.py +++ b/programmer/swebench/score.py @@ -10,12 +10,12 @@ SWEbenchInstance, ) -from ..tools import RemoteContainerToolContext +from ..io_context import RemoteContainerIOContext def score_swebench(instance: SWEbenchInstance, model_output): patch = model_output["answer"] - tc = 
RemoteContainerToolContext( + tc = RemoteContainerIOContext( "http://localhost:8000", "/testbed", "source /opt/miniconda3/bin/activate && conda activate testbed && ", diff --git a/programmer/swebench/swebench_model.py b/programmer/swebench/swebench_model.py index 74dc0b8..30d7c54 100644 --- a/programmer/swebench/swebench_model.py +++ b/programmer/swebench/swebench_model.py @@ -1,7 +1,7 @@ import weave from ..agent import Agent, AgentState -from ..tools import RemoteContainerToolContext +from ..io_context import RemoteContainerIOContext class SWEBenchProgrammerModel(weave.Model): @@ -25,7 +25,7 @@ def predict(self, instance): ], ) - tc = RemoteContainerToolContext( + tc = RemoteContainerIOContext( "http://localhost:8000", "/testbed", "source /opt/miniconda3/bin/activate && conda activate testbed && ", diff --git a/programmer/tests/test_file_line_tools.py b/programmer/tests/test_file_line_tools.py index 74f1a29..84a1c65 100644 --- a/programmer/tests/test_file_line_tools.py +++ b/programmer/tests/test_file_line_tools.py @@ -4,16 +4,15 @@ from programmer.tools import ( read_lines_from_file, splice_lines_in_file, - LocalToolContext, - tool_context, - get_current_context, + get_io_context, ) +from programmer.io_context import LocalIOContext, io_context @pytest.fixture() def tempdir_tool_context(): with TemporaryDirectory() as tmpdir: - with tool_context(LocalToolContext(tmpdir)) as tc: + with io_context(LocalIOContext(tmpdir)) as tc: yield tc diff --git a/programmer/tests/test_text_editor.py b/programmer/tests/test_text_editor.py new file mode 100644 index 0000000..f46bced --- /dev/null +++ b/programmer/tests/test_text_editor.py @@ -0,0 +1,167 @@ +import pytest +from tempfile import TemporaryDirectory +from programmer.text_editor import ( + TextEditor, + TextEditorState, + OpenFileState, + LineRange, + OpenFileResult, + WriteFileResult, + TextEditorMutationResult, +) +from programmer.io_context import LocalIOContext, io_context + + +@pytest.fixture() +def 
tempdir_tool_context(): + with TemporaryDirectory() as tmpdir: + with io_context(LocalIOContext(tmpdir)) as tc: + yield tc + + +@pytest.fixture() +def sample_file(tempdir_tool_context): + file_path = "sample.txt" + content = "\n".join(f"Line {i}" for i in range(1, 201)) # 200 lines + tempdir_tool_context.write_file(file_path, content) + return file_path + + +@pytest.fixture() +def text_editor(tempdir_tool_context): + return TextEditor(max_open_size=150, open_chunk_size=50) + + +@pytest.fixture() +def initial_state(): + return TextEditorState() + + +def test_open_file(text_editor, sample_file, initial_state): + result = text_editor.open_file(initial_state, sample_file, 1) + assert isinstance(result, TextEditorMutationResult) + assert isinstance(result.action_result, OpenFileResult) + assert result.action_result.success + assert sample_file in result.new_state.open_files + assert ( + result.new_state.open_files[sample_file].total_lines() == 50 + ) # OPEN_CHUNK_SIZE + + +def test_open_file_exceed_max_size(tempdir_tool_context, sample_file): + text_editor = TextEditor(max_open_size=75, open_chunk_size=50) + initial_state = TextEditorState() + + # Open the file once (50 lines) + result1 = text_editor.open_file(initial_state, sample_file, 1) + assert result1.action_result.success + + # Try to open another chunk, which would exceed the max_open_size + result2 = text_editor.open_file(result1.new_state, sample_file, 50) + assert isinstance(result2.action_result, OpenFileResult) + assert not result2.action_result.success + assert "exceeding the maximum" in result2.action_result.error + + +def test_open_file_at_boundary(tempdir_tool_context, sample_file): + text_editor = TextEditor(max_open_size=100, open_chunk_size=50) + initial_state = TextEditorState() + + # Open exactly MAX_OPEN_SIZE lines + result1 = text_editor.open_file(initial_state, sample_file, 1) + result2 = text_editor.open_file(result1.new_state, sample_file, 51) + assert result1.action_result.success and 
result2.action_result.success + assert result2.new_state.total_lines() == 100 # MAX_OPEN_SIZE + + # Try to open one more line, which should fail + result3 = text_editor.open_file(result2.new_state, sample_file, 99) + assert not result3.action_result.success + assert "exceeding the maximum" in result3.action_result.error + + +def test_replace_file_lines_at_boundary(text_editor, sample_file, initial_state): + state1 = text_editor.open_file(initial_state, sample_file, 1).new_state + state2 = text_editor.open_file(state1, sample_file, 51).new_state + state3 = text_editor.open_file(state2, sample_file, 101).new_state + + # Replace 5 lines with 5 new lines (no net change) + result = text_editor.replace_file_lines( + state3, + sample_file, + [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 5}], + ) + assert result.action_result.success + + # Try to replace 5 lines with 6 new lines (net increase of 1, should fail) + result = text_editor.replace_file_lines( + state3, + sample_file, + [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 6}], + ) + assert not result.action_result.success + assert "exceeding the maximum" in result.action_result.error + + +def test_replace_file_lines_middle(text_editor, sample_file, initial_state): + state1 = text_editor.open_file(initial_state, sample_file, 1).new_state + + # Replace 5 lines with 5 new lines (no net change) + result = text_editor.replace_file_lines( + state1, + sample_file, + [{"start_line": 5, "n_lines": 5, "lines": "A\nB\n"}], + ) + assert result.action_result.success + + # Try to replace 5 lines with 6 new lines (net increase of 1, should fail) + file_info = result.new_state.get_open_file_info() + assert file_info.open_file_buffers[sample_file].total_lines == 197 + assert len(file_info.open_file_buffers[sample_file].buffers) == 1 + buffer0 = file_info.open_file_buffers[sample_file].buffers[0] + assert buffer0.line_range.start_line == 1 + assert buffer0.line_range.n_lines == 50 + + +def 
test_close_file_range(text_editor, sample_file, initial_state): + state1 = text_editor.open_file(initial_state, sample_file, 1).new_state + result = text_editor.close_file_range(state1, sample_file, 1, 25) + assert result.new_state.open_files[sample_file].total_lines() == 25 + + +def test_get_open_file_info(text_editor, sample_file, initial_state): + state1 = text_editor.open_file(initial_state, sample_file, 1).new_state + info = state1.get_open_file_info() + assert sample_file in info.open_file_buffers + assert info.open_file_buffers[sample_file].total_lines == 200 + assert len(info.open_file_buffers[sample_file].buffers) == 1 + assert info.open_file_buffers[sample_file].buffers[0].line_range.start_line == 1 + assert info.open_file_buffers[sample_file].buffers[0].line_range.n_lines == 50 + + +def test_open_file_multiple_ranges(text_editor, sample_file, initial_state): + state1 = text_editor.open_file(initial_state, sample_file, 1).new_state + state2 = text_editor.open_file(state1, sample_file, 51).new_state + assert len(state2.open_files[sample_file].ranges) == 1 + assert state2.open_files[sample_file].ranges[0].start_line == 1 + assert state2.open_files[sample_file].ranges[0].n_lines == 100 + + +def test_open_file_beyond_end(text_editor, sample_file, initial_state): + result = text_editor.open_file(initial_state, sample_file, 201) + assert isinstance(result.action_result, OpenFileResult) + assert not result.action_result.success + assert "beyond the end of the file" in result.action_result.error + + +def test_open_file_at_end(text_editor, sample_file, initial_state): + result = text_editor.open_file(initial_state, sample_file, 199) + assert isinstance(result.action_result, OpenFileResult) + assert result.action_result.success + assert result.new_state.open_files[sample_file].total_lines() == 1 + + +def test_open_file_near_end(text_editor, sample_file, initial_state): + result = text_editor.open_file(initial_state, sample_file, 190) + assert 
isinstance(result.action_result, OpenFileResult) + assert result.action_result.success + assert result.new_state.open_files[sample_file].total_lines() == 10 diff --git a/programmer/tests/test_tool_calling.py b/programmer/tests/test_tool_calling.py new file mode 100644 index 0000000..600297a --- /dev/null +++ b/programmer/tests/test_tool_calling.py @@ -0,0 +1,94 @@ +from enum import Enum +from typing import TypedDict + +import weave + +from programmer.tool_calling import generate_json_schema + + +class Range(TypedDict): + start: int + end: int + + +@weave.op +def merge_ranges(ranges: list[Range]) -> list[Range]: + """Merge a list of ranges into a single range. + + Args: + ranges: A list of ranges to merge. + + Returns: + A list of merged ranges. + """ + return ranges + + +def test_list_of_typeddict_schema(): + schema = generate_json_schema(merge_ranges) + assert schema == { + "function": { + "description": "Merge a list of ranges into a single range.", + "name": "merge_ranges", + "parameters": { + "properties": { + "ranges": { + "description": "A list of ranges to merge.", + "type": "array", + "items": { + "type": "object", + "properties": { + "start": {"type": "integer"}, + "end": {"type": "integer"}, + }, + "required": ["start", "end"], + }, + } + }, + "required": ["ranges"], + "type": "object", + }, + }, + "type": "function", + } + + +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +@weave.op +def color_name(color: Color) -> str: + """Get the name of a color. + + Args: + color: The color to get the name of. + + Returns: + The name of the color. 
+ """ + return color.name + + +def test_enum_schema(): + schema = generate_json_schema(color_name) + assert schema == { + "function": { + "description": "Get the name of a color.", + "name": "color_name", + "parameters": { + "properties": { + "color": { + "description": "The color to get the name of.", + "enum": [1, 2, 3], + "type": "integer", + } + }, + "required": ["color"], + "type": "object", + }, + }, + "type": "function", + } diff --git a/programmer/text_editor.py b/programmer/text_editor.py new file mode 100644 index 0000000..171bbc2 --- /dev/null +++ b/programmer/text_editor.py @@ -0,0 +1,507 @@ +from typing import Optional, Generic, TypeVar +from dataclasses import dataclass, field +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Optional, TypedDict + +import weave + +from .io_context import get_io_context + + +@dataclass(frozen=True) +class LineRange: + start_line: int + n_lines: int + + +@dataclass(frozen=True) +class OpenFileState: + # Invariant: ranges must be non-overlapping and non-adjacent + # and must be in sorted order + ranges: tuple[LineRange, ...] 
= field(default_factory=tuple) + + def add_range(self, range: LineRange) -> "OpenFileState": + # Create a new list of ranges + new_ranges = list(self.ranges) + + # Find the correct position to insert the new range + insert_index = 0 + for i, existing_range in enumerate(new_ranges): + if range.start_line < existing_range.start_line: + insert_index = i + break + insert_index = i + 1 + + # Insert the new range + new_ranges.insert(insert_index, range) + + # Merge overlapping or adjacent ranges + i = 0 + while i < len(new_ranges) - 1: + current_range = new_ranges[i] + next_range = new_ranges[i + 1] + + if ( + current_range.start_line + current_range.n_lines + >= next_range.start_line + ): + # Merge the ranges + merged_end = max( + current_range.start_line + current_range.n_lines, + next_range.start_line + next_range.n_lines, + ) + new_ranges[i] = LineRange( + current_range.start_line, merged_end - current_range.start_line + ) + new_ranges.pop(i + 1) + else: + i += 1 + + # Return a new OpenFileState with the updated ranges + return OpenFileState(ranges=tuple(new_ranges)) + + def subtract_range(self, range: LineRange) -> "OpenFileState": + new_ranges = [] + for existing_range in self.ranges: + if range.start_line >= existing_range.start_line + existing_range.n_lines: + # The subtracted range is after this range, keep it as is + new_ranges.append(existing_range) + elif range.start_line + range.n_lines <= existing_range.start_line: + # The subtracted range is before this range, keep it as is + new_ranges.append(existing_range) + else: + # The ranges overlap, we need to split or adjust + if range.start_line > existing_range.start_line: + # Keep the part before the subtracted range + new_ranges.append( + LineRange( + existing_range.start_line, + range.start_line - existing_range.start_line, + ) + ) + if ( + range.start_line + range.n_lines + < existing_range.start_line + existing_range.n_lines + ): + # Keep the part after the subtracted range + new_ranges.append( + LineRange( 
+ range.start_line + range.n_lines, + (existing_range.start_line + existing_range.n_lines) + - (range.start_line + range.n_lines), + ) + ) + + return OpenFileState(ranges=tuple(new_ranges)) + + def total_lines(self) -> int: + return sum(r.n_lines for r in self.ranges) + + def is_range_open(self, start_line: int, n_lines: int) -> bool: + end_line = start_line + n_lines + for range in self.ranges: + if ( + range.start_line <= start_line + and range.start_line + range.n_lines >= end_line + ): + return True + return False + + +@dataclass(frozen=True) +class TextEditorState: + open_files: dict[str, OpenFileState] = field(default_factory=dict) + + def total_lines(self) -> int: + return sum(file.total_lines() for file in self.open_files.values()) + + def get_open_file_info(self) -> "OpenFileInfoResult": + file_io_context = get_io_context() + open_file_buffers = {} + for path, open_file in self.open_files.items(): + contents = file_io_context.read_file(path) + lines = contents.split("\n") + buffers = [] + for range in open_file.ranges: + buffer = Buffer( + line_range=range, + lines=lines[ + range.start_line - 1 : range.start_line - 1 + range.n_lines + ], + ) + buffers.append(buffer) + open_file_info = OpenFileInfo( + buffers=tuple(buffers), total_lines=len(lines) + ) + open_file_buffers[path] = open_file_info + return OpenFileInfoResult(open_file_buffers=open_file_buffers) + + +@dataclass(frozen=True) +class Buffer: + line_range: LineRange + lines: list[str] + + +@dataclass(frozen=True) +class OpenFileInfo: + buffers: tuple[Buffer, ...] = field(default_factory=tuple) + total_lines: int = 0 + + def n_lines(self) -> int: + return sum(buffer.line_range.n_lines for buffer in self.buffers) + + +@dataclass(frozen=True) +class OpenFileInfoResult: + open_file_buffers: dict[str, OpenFileInfo] = field(default_factory=dict) + + def format_for_messages(self) -> str: + lines = [ + "Visible file buffers. 
These are the latest states of any previously opened file ranges, and reflect the results of all prior edits." + ] + for path, open_file_info in self.open_file_buffers.items(): + lines.append(f"") + # lines.append(f"") + for buffer in open_file_info.buffers: + lines.append("") + for i, line in enumerate(buffer.lines): + lines.append(f"{buffer.line_range.start_line + i}: {line}") + lines.append("") + lines.append("") + return "\n".join(lines) + + +@dataclass(frozen=True) +class ClosedFileRange: + path: str + start_line: int + n_lines: int + + +@dataclass(frozen=True) +class OpenFileResult: + success: bool + error: str + + +@dataclass(frozen=True) +class WriteFileResult: + success: bool + error: str + + +T = TypeVar("T") + + +@dataclass(frozen=True) +class TextEditorMutationResult(Generic[T]): + new_state: TextEditorState + action_result: T + + +class LineRangeReplacement(TypedDict): + start_line: int + n_lines: int + lines: str + + +class TextEditor: + def __init__( + self, + max_open_size: int = 1500, + open_chunk_size: int = 500, + ): + self.MAX_OPEN_SIZE = max_open_size + self.OPEN_CHUNK_SIZE = open_chunk_size + + def open_file( + self, state: TextEditorState, path: str, start_line: int + ) -> TextEditorMutationResult[OpenFileResult]: + file_io_context = get_io_context() + try: + file_contents = file_io_context.read_file(path) + except FileNotFoundError: + return TextEditorMutationResult( + new_state=state, + action_result=OpenFileResult(success=False, error="File not found"), + ) + + file_lines = file_contents.split("\n") + file_lines_count = len(file_lines) + + if start_line < 1: + return TextEditorMutationResult( + new_state=state, + action_result=OpenFileResult( + success=False, + error=f"Start line {start_line} is before the start of the file.", + ), + ) + + if start_line - 1 >= file_lines_count: + return TextEditorMutationResult( + new_state=state, + action_result=OpenFileResult( + success=False, + error=f"Start line {start_line} is beyond the end of the 
file (which has {file_lines_count} lines).", + ), + ) + + orig_open_file_state = state.open_files.get(path, OpenFileState()) + new_buffer = LineRange( + start_line, min(self.OPEN_CHUNK_SIZE, file_lines_count - start_line) + ) + new_open_file_state = orig_open_file_state.add_range(new_buffer) + added_lines = ( + new_open_file_state.total_lines() - orig_open_file_state.total_lines() + ) + + if state.total_lines() + added_lines > self.MAX_OPEN_SIZE: + return TextEditorMutationResult( + new_state=state, + action_result=OpenFileResult( + success=False, + error=f"This request would result in {state.total_lines() + added_lines} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.", + ), + ) + + new_open_files = dict(state.open_files) + new_open_files[path] = new_open_file_state + new_state = TextEditorState(open_files=new_open_files) + + return TextEditorMutationResult( + new_state=new_state, + action_result=OpenFileResult(success=True, error=""), + ) + + def close_file_range( + self, state: TextEditorState, path: str, start_line: int, n_lines: int + ) -> TextEditorMutationResult[None]: + open_file_state = state.open_files[path] + new_open_file_state = open_file_state.subtract_range( + LineRange(start_line, n_lines) + ) + + new_open_files = dict(state.open_files) + if new_open_file_state.total_lines() == 0: + del new_open_files[path] + else: + new_open_files[path] = new_open_file_state + + new_state = TextEditorState(open_files=new_open_files) + return TextEditorMutationResult(new_state=new_state, action_result=None) + + def replace_file_lines( + self, + state: TextEditorState, + path: str, + replacements: list[LineRangeReplacement], + ) -> TextEditorMutationResult[WriteFileResult]: + file_io_context = get_io_context() + + # Check if the file is open + open_file_state = state.open_files.get(path) + if not open_file_state: + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult( + success=False, + error=f"The file {path} is not 
open.", + ), + ) + + # Check if all ranges are open + missing_ranges = [] + for replacement in replacements: + if not open_file_state.is_range_open( + replacement["start_line"], replacement["n_lines"] + ): + missing_ranges.append(replacement) + if missing_ranges: + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult( + success=False, + error=f"The following ranges are not open: {missing_ranges}", + ), + ) + + # Sort replacements by start line + replacements.sort(key=lambda x: x["start_line"]) + + # Ensure replacements are non-overlapping + for i in range(len(replacements) - 1): + if ( + replacements[i]["start_line"] + replacements[i]["n_lines"] + > replacements[i + 1]["start_line"] + ): + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult( + success=False, + error=f"The following replacements are overlapping: {replacements[i]}, {replacements[i+1]}", + ), + ) + + all_new_lines = [l["lines"].rstrip("\n").split("\n") for l in replacements] + + net_change = sum(len(l) for l in all_new_lines) - sum( + l["n_lines"] for l in replacements + ) + if state.total_lines() + net_change > self.MAX_OPEN_SIZE: + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult( + success=False, + error=f"This edit would result in {state.total_lines() + net_change} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.", + ), + ) + + file_io_context = get_io_context() + try: + file_contents = file_io_context.read_file(path) + file_lines = file_contents.split("\n") + except Exception as e: + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult( + success=False, + error=f"Failed to write to file: {str(e)}", + ), + ) + + # Apply replacements in reverse order to indexes don't change while iterating + for i, replacement in reversed(list(enumerate(replacements))): + start_line = replacement["start_line"] + n_lines = replacement["n_lines"] + file_lines[start_line - 1 : 
start_line - 1 + n_lines] = all_new_lines[i] + + new_contents = "\n".join(file_lines) + + file_io_context.write_file(path, new_contents) + return TextEditorMutationResult( + new_state=state, + action_result=WriteFileResult(success=True, error=""), + ) + + +class TextEditorStateful: + def __init__(self, text_editor: TextEditor, initial_state: TextEditorState): + self.text_editor = text_editor + self.state = initial_state + + def open_file(self, path: str, start_line: int) -> OpenFileResult: + result = self.text_editor.open_file(self.state, path, start_line) + self.state = result.new_state + return result.action_result + + def close_file_range(self, path: str, start_line: int, n_lines: int) -> None: + result = self.text_editor.close_file_range( + self.state, path, start_line, n_lines + ) + self.state = result.new_state + return result.action_result + + def replace_file_lines( + self, + path: str, + replacements: list[LineRangeReplacement], + ) -> WriteFileResult: + result = self.text_editor.replace_file_lines(self.state, path, replacements) + self.state = result.new_state + return result.action_result + + +_text_editor_context: ContextVar[Optional[TextEditorStateful]] = ContextVar( + "_text_editor_context", default=None +) + + +@contextmanager +def text_editor(context: TextEditorStateful): + token = _text_editor_context.set(context) + try: + yield context + finally: + _text_editor_context.reset(token) + + +def require_text_editor() -> TextEditorStateful: + context = _text_editor_context.get() + assert context is not None + return context + + +@weave.op +def open_file(path: str, start_line: int) -> str: + """Open a buffer of lines from the given file. + + Args: + path: The path to the file. + start_line: The line number to start reading from (1-indexed). + + Returns: + "success" if the file was opened successfully, + "error: " if the file was not opened successfully. 
+ """ + text_editor = require_text_editor() + response = text_editor.open_file(path, start_line) + if response.success: + return "success" + else: + return f"error: {response.error}" + + +@weave.op +def close_file_range(path: str, start_line: int, n_lines: int) -> str: + """Close a buffer of lines from the given file. + + Args: + path: The path to the file. + start_line: The line number to start reading from (1-indexed). + n_lines: The number of lines to close. + + Returns: + "success" if the file was closed successfully. + """ + text_editor = require_text_editor() + response = text_editor.close_file_range(path, start_line, n_lines) + return "success" + + +class LineRangeReplacementStartEnd(TypedDict): + start_line: int + remove_up_to_line: int + lines: str + + +@weave.op +def replace_file_lines( + path: str, replacements: list[LineRangeReplacementStartEnd] +) -> str: + """Replace ranges of lines within a file. Changes must be made to open ranges, and will be reflected immediately on the filesystem. First, existing lines are removed starting at start line, up to but not including replace_up_to_line. Then the new lines are added in that position. + + Args: + path: The path to the file. + replacements: A list of replacements to make. Each replacement is a dictionary with keys: start_line (1-indexed, inclusive), remove_up_to_line (1-indexed, exclusive), lines (a string of newline separated lines to insert) + + Returns: + "success" if the file was replaced successfully, + "error: " if the file was not replaced successfully. 
+ """ + text_editor = require_text_editor() + replacements_list = [ + LineRangeReplacement( + start_line=r["start_line"], + n_lines=r["remove_up_to_line"] - r["start_line"], + lines=r["lines"], + ) + for r in replacements + ] + response = text_editor.replace_file_lines(path, replacements_list) + if response.success: + return "success" + else: + return f"error: {response.error}" diff --git a/programmer/tool_calling.py b/programmer/tool_calling.py index c9a9f50..3207ebc 100644 --- a/programmer/tool_calling.py +++ b/programmer/tool_calling.py @@ -1,12 +1,53 @@ import inspect import json -from typing import Callable, get_type_hints +import traceback +import typing_extensions + +from typing import Any, Callable, get_type_hints, TypedDict from openai.types.chat import ChatCompletionMessageToolCall, ChatCompletionToolParam from .console import Console +class TypedDictLike: + __required_keys__: frozenset[str] + + +def is_typed_dict_like(t: type) -> typing_extensions.TypeGuard[TypedDictLike]: + return hasattr(t, "__required_keys__") + + +def pytype_to_jsonschema(pytype: Any) -> dict: + if pytype.__name__ == "str": + return {"type": "string"} + elif pytype.__name__ == "int": + return {"type": "integer"} + elif is_typed_dict_like(pytype): + return { + "type": "object", + "properties": { + k: pytype_to_jsonschema(v) for k, v in pytype.__annotations__.items() + }, + "required": list(pytype.__annotations__.keys()), + } + elif pytype.__name__ == "list": + return {"type": "array", "items": pytype_to_jsonschema(pytype.__args__[0])} + elif hasattr(pytype, "__members__"): + member_types = [ + pytype_to_jsonschema(type(v.value)) for v in pytype.__members__.values() + ] + t0 = member_types[0] + for t in member_types[1:]: + if t != t0: + raise ValueError("All member types must be the same") + mem_type = t0["type"] + if mem_type != "string" and mem_type != "integer": + raise ValueError(f"Enum member type {mem_type} is not supported") + return {"type": mem_type, "enum": [e.value for e in 
pytype]} + raise ValueError(f"Unsupported type: {pytype.__name__}") + + def generate_json_schema(func: Callable) -> dict: """Given a function, generate an OpenAI tool compatible JSON schema. @@ -40,27 +81,21 @@ def generate_json_schema(func: Callable) -> dict: is_required = param.default == inspect.Parameter.empty # Extract parameter type and description - param_type = type_hints[name].__name__ if name in type_hints else "string" - if param_type == "str": - param_type = "string" - elif param_type == "int": - param_type = "integer" - param_desc = "" + param_schema = pytype_to_jsonschema(type_hints[name]) # Attempt to extract description from docstring + param_desc = "" if func.__doc__: doc_lines = func.__doc__.split("\n")[1:] for line in doc_lines: if name in line: param_desc = line.strip().split(":")[-1].strip() break - - # Populate schema for this parameter - param_schema = {"type": param_type, "description": param_desc} - - # Handle special case for enums - if hasattr(type_hints[name], "__members__"): # Check if it's an Enum - param_schema["enum"] = [e.value for e in type_hints[name]] + if not param_desc: + raise ValueError( + f"Function {func.__name__} description for parameter {name} is missing" + ) + param_schema["description"] = param_desc schema["function"]["parameters"]["properties"][name] = param_schema # type: ignore @@ -96,12 +131,15 @@ def perform_tool_calls( try: function_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: - function_response = str(e) + print(f"Tool call {tool_call_s} failed to parse arguments: {e}") + function_response = f"Argument parse error: {str(e)}" if not function_response: try: function_response = tool(**function_args) except Exception as e: - function_response = str(e) + print(f"Error occurred in tool {function_name}:") + traceback.print_exc() + function_response = f"Error: {str(e)}" additional_message = None if isinstance(function_response, tuple): diff --git a/programmer/tools.py 
b/programmer/tools.py index 44e155b..fb49272 100644 --- a/programmer/tools.py +++ b/programmer/tools.py @@ -1,14 +1,8 @@ import base64 -import json import os -import subprocess import weave -import contextlib -import shlex -from contextvars import ContextVar -from contextlib import contextmanager -from typing import Protocol, Union, TypedDict, Optional -import requests + +from .io_context import get_io_context LENGTH_LIMIT = 30000 @@ -17,161 +11,6 @@ # - must return FileNotFoundError in read_file in Remote -class RunCommandResult(TypedDict): - exit_code: int - output: str - - -class ToolContext(Protocol): - def write_file(self, path: str, content: str) -> None: ... - - def read_file(self, path: str) -> str: ... - - def run_command(self, command: str) -> RunCommandResult: ... - - def resolve_path(self, path: str) -> str: ... - - -class LocalToolContext(ToolContext): - def __init__(self, directory): - self.directory = os.path.abspath(directory) - - def write_file(self, path: str, content: str) -> None: - full_path = self.resolve_path(path) - with open(full_path, "w") as f: - f.write(content) - - def read_file(self, path: str) -> str: - full_path = self.resolve_path(path) - with open(full_path, "r") as f: - return f.read() - - def run_command(self, command: str) -> RunCommandResult: - completed_process = subprocess.run( - command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - cwd=self.directory, - ) - exit_code = completed_process.returncode - output = completed_process.stdout.strip() - - return { - "exit_code": exit_code, - "output": output, - } - - def resolve_path(self, path: str) -> str: - return os.path.join(self.directory, path) - - -class RemoteContainerToolContext(ToolContext): - def __init__(self, base_url: str, directory: str, command_prefix: str): - self.base_url = base_url - self.container_id = None - self.directory = directory - self.command_prefix = command_prefix - - @contextmanager - def context(self, image_id: 
str): - self.start_container(image_id) - try: - with tool_context(self): - yield - finally: - self.stop_container() - - def start_container(self, image_id): - response = requests.post( - f"{self.base_url}/container/start", json={"image_id": image_id} - ) - if response.status_code == 200: - self.container_id = response.json().get("container_id") - else: - print(f"Failed to start container: {response.text}") - - def stop_container(self): - response = requests.post( - f"{self.base_url}/container/stop", - json={"container_id": self.container_id, "delete": True}, - ) - if response.status_code == 200: - self.container_id = None - else: - print(f"Failed to stop container: {response.text}") - - def write_file(self, path: str, content: str) -> None: - full_path = os.path.join(self.directory, path) - response = requests.post( - f"{self.base_url}/container/write_file", - json={ - "container_id": self.container_id, - "file_path": full_path, - "file_content": content, - }, - ) - if response.status_code != 200: - raise Exception(f"Failed to write file: {response.text}") - - def read_file(self, path: str) -> str: - full_path = os.path.join(self.directory, path) - response = requests.post( - f"{self.base_url}/container/read_file", - json={"container_id": self.container_id, "file_path": full_path}, - ) - if response.status_code == 200: - return response.json().get("file_content") - else: - raise Exception(f"Failed to read file: {response.text}") - - def run_command(self, command: str) -> RunCommandResult: - command = self.command_prefix + command - command = f"bash -c {shlex.quote(command)}" - response = requests.post( - f"{self.base_url}/container/run", - json={ - "container_id": self.container_id, - "workdir": self.directory, - "command": command, - }, - ) - if response.status_code == 200: - json = response.json() - return { - "exit_code": json["exit_code"], - "output": json["output"], - } - else: - raise Exception(f"Failed to run command: {response.text}") - - def 
resolve_path(self, path: str) -> str: - return path # For remote containers, we assume paths are already resolved - - -# Create a ContextVar to store the current ToolContext -current_context: ContextVar[ - Optional[Union[LocalToolContext, RemoteContainerToolContext]] -] = ContextVar("current_context", default=None) - - -@contextlib.contextmanager -def tool_context(context: Union[LocalToolContext, RemoteContainerToolContext]): - token = current_context.set(context) - try: - yield context - finally: - current_context.reset(token) - - -def get_current_context() -> Union[LocalToolContext, RemoteContainerToolContext]: - context = current_context.get() - if context is None: - return LocalToolContext(".") - return context - - def read_image_as_base64(path: str): ext = os.path.splitext(path)[1] if ext not in [".jpg", ".jpeg", ".png"]: @@ -201,7 +40,7 @@ def view_image(path: str): Returns: A message indicating that the image was displayed successfully. """ - context = get_current_context() + context = get_io_context() full_path = context.resolve_path(path) base64_image = read_image_as_base64(full_path) @@ -226,7 +65,7 @@ def list_files(directory: str) -> str: Returns: The list of files in the directory. """ - context = get_current_context() + context = get_io_context() # full_path = context.resolve_path(directory) result = context.run_command(f"ls {directory}") exit_code = result["exit_code"] @@ -252,7 +91,7 @@ def write_to_file(path: str, content: str) -> str: Returns: A message indicating whether the file was written successfully. """ - context = get_current_context() + context = get_io_context() if len(content) > LENGTH_LIMIT: content = content[:LENGTH_LIMIT] content += "\n... (truncated)" @@ -270,7 +109,7 @@ def read_from_file(path: str) -> str: Returns: The content of the file. 
""" - context = get_current_context() + context = get_io_context() result = context.read_file(path) if len(result) > LENGTH_LIMIT: result = result[:LENGTH_LIMIT] @@ -288,7 +127,7 @@ def run_command(command: str) -> str: Returns: The output of the command. """ - context = get_current_context() + context = get_io_context() result = context.run_command(command) exit_code = result["exit_code"] @@ -318,7 +157,7 @@ def read_lines_from_file(file_path: str, start_line: int) -> str: Raises: Exception: If the file does not exist or start_line is invalid. """ - context = get_current_context() + context = get_io_context() full_path = context.resolve_path(file_path) content = context.read_file(full_path) lines = content.splitlines() @@ -358,7 +197,7 @@ def replace_lines_in_file( Raises: Exception: If the line range is invalid or file cannot be accessed. """ - context = get_current_context() + context = get_io_context() full_path = context.resolve_path(file_path) try: content = context.read_file(full_path) @@ -423,7 +262,7 @@ def splice_lines_in_file( Raises: Exception: If the line range is invalid or file cannot be accessed. """ - context = get_current_context() + context = get_io_context() full_path = context.resolve_path(file_path) try: content = context.read_file(full_path) diff --git a/programmer/weave_next/weave_query.py b/programmer/weave_next/weave_query.py index 54c0adb..28b90b9 100644 --- a/programmer/weave_next/weave_query.py +++ b/programmer/weave_next/weave_query.py @@ -167,7 +167,7 @@ def expand_refs(wc: WeaveClient, refs: Sequence[str]): return Objs(wc, refs) -def call(wc: WeaveClient, call_id: str): +def get_call(wc: WeaveClient, call_id: str): """Return a raw Weave call.""" response = wc.server.calls_query( CallsQueryReq(