diff --git a/programmer-ui/ui.py b/programmer-ui/ui.py
index e09d7ea..4dcb55b 100644
--- a/programmer-ui/ui.py
+++ b/programmer-ui/ui.py
@@ -6,10 +6,17 @@
import streamlit as st
import weave
import os
+import openai
+import copy
from weave.trace.weave_client import WeaveClient
from programmer.weave_next.api import init_local_client
-from programmer.weave_next.weave_query import calls, expand_refs
+from programmer.weave_next.weave_query import (
+ calls,
+ expand_refs,
+ get_call,
+ expand_json_refs,
+)
from programmer.settings_manager import SettingsManager
st.set_page_config(layout="wide")
@@ -47,7 +54,46 @@ def init_from_settings() -> WeaveClient:
raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}")
-client = init_from_settings()
+# Add sidebar for Weave project configuration
+with st.sidebar:
+ st.header("Weave Project Configuration")
+
+ # Initialize from settings
+ initial_weave_logging = SettingsManager.get_setting("weave_logging")
+ initial_project_type = "local" if initial_weave_logging == "local" else "cloud"
+ initial_project_path = (
+ os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")
+ if initial_weave_logging == "local"
+ else ""
+ )
+ initial_project_name = (
+ f"programmer-{os.path.basename(os.path.abspath(os.curdir))}"
+ if initial_weave_logging == "cloud"
+ else ""
+ )
+
+ project_type = st.radio(
+ "Project Type",
+ ["local", "cloud"],
+ index=0 if initial_project_type == "local" else 1,
+ )
+
+ if project_type == "local":
+ project_path = st.text_input("Local DB Path", value=initial_project_path)
+ # SettingsManager.set_setting("weave_logging", "local")
+ # SettingsManager.set_setting("weave_db_path", project_path)
+ client = init_local_weave(project_path)
+ print("C2", client._project_id())
+ else:
+ # SettingsManager.set_setting("weave_logging", "cloud")
+ # SettingsManager.set_setting("weave_project_name", project_name)
+ project_name = st.text_input("Cloud Project Name", value=initial_project_name)
+ client = init_remote_weave(project_name)
+ print("C3", client._project_id())
+
+# Initialize client based on current settings
+# client = init_from_settings()
+print("CLIENT", client._project_id())
def set_focus_step_id(call_id):
@@ -76,6 +122,16 @@ def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]):
return expand_refs(wc, refs).to_pandas()
+@st.cache_data(hash_funcs=ST_HASH_FUNCS)
+def cached_get_call(wc: WeaveClient, call_id: str):
+ return get_call(wc, call_id)
+
+
+@st.cache_data(hash_funcs=ST_HASH_FUNCS)
+def cached_expand_json_refs(wc: WeaveClient, json: dict):
+ return expand_json_refs(wc, json)
+
+
def print_step_call(call):
start_history = call["inputs.state.history"]
end_history = call["output.history"]
@@ -174,30 +230,234 @@ def print_session_call(session_id):
)
-session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
-if len(session_calls_df) == 0:
- st.error("No programmer sessions found.")
- st.stop()
-session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
- lambda v: v[-1]["content"]
-)
-
+def sessions_page():
+ session_calls_df = cached_calls(
+ client, "session", expand_refs=["inputs.agent_state"]
+ )
+ if len(session_calls_df) == 0:
+ st.error("No programmer sessions found.")
+ st.stop()
+ session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
+ lambda v: v[-1]["content"]
+ )
+ with st.sidebar:
+ st.header("Session Selection")
+ if st.button("Refresh"):
+ st.cache_data.clear()
+ st.rerun()
+ message_ids = {
+ f"{cid[-5:]}: {m}": cid
+ for cid, m in reversed(
+ list(zip(session_calls_df["id"], session_user_message_df))
+ )
+ }
+ sel_message = st.radio("Session", options=message_ids.keys())
+ sel_id = None
+ if sel_message:
+ sel_id = message_ids.get(sel_message)
+ if sel_id:
+ st.header(f"Session: {sel_id}")
+ print_session_call(sel_id)
+
+
+sessions_pg = st.Page(sessions_page, title="Sessions")
+
+
+# def write_chat_message(m, key):
+# with st.chat_message(m["role"]):
+# if "content" in m:
+# st.text_area(
+# "", value=str(m["content"]), label_visibility="collapsed", key=key
+# )
+def write_chat_message(m, key, readonly=False):
+ def on_change_content():
+ new_value = st.session_state[key]
+ st.session_state.playground_state["editable_call"]["inputs"]["messages"][
+ m["original_index"]
+ ]["content"] = new_value
+
+ with st.chat_message(m["role"]):
+ if m.get("content"):
+ if readonly:
+ st.code(m["content"])
+ else:
+ st.text_area(
+ "",
+ value=m["content"],
+ label_visibility="collapsed",
+ key=key,
+ on_change=on_change_content,
+ )
+ if m.get("tool_calls"):
+ for t in m["tool_calls"]:
+ st.write(t["function"]["name"])
+ st.json(
+ {
+ "arguments": t["function"]["arguments"],
+ "response": t.get("response", {}).get("content"),
+ },
+ expanded=True,
+ )
+
+
+def attach_tool_call_responses(messages):
+ new_messages = []
+ for i, m in enumerate(messages):
+ new_m = copy.deepcopy(m)
+ new_m["original_index"] = i
+ if new_m["role"] == "assistant" and "tool_calls" in new_m:
+ new_m["tool_call_responses"] = []
+ for t in new_m["tool_calls"]:
+ t_id = t["id"]
+ for j, t_response in enumerate(messages):
+ if t_response.get("tool_call_id") == t_id:
+ t["response"] = t_response
+ t["response"]["original_index"] = j
+ break
+ if "tool_call_id" not in new_m:
+ new_messages.append(new_m)
+ return new_messages
+
+
+def playground_page():
+ with st.sidebar:
+ if not st.session_state.get("playground_state"):
+ st.session_state.playground_state = {
+ "call_id": None,
+ "call": None,
+ "expanded_call": None,
+ "editable_call": None,
+ }
+ playground_state = st.session_state.playground_state
+ call_id = st.text_input("Call ID")
+ if not call_id:
+ st.error("Please set call ID")
+ st.stop()
+
+ # st.write(playground_state)
+ if playground_state["expanded_call"] != playground_state["editable_call"]:
+ st.warning("Call has been modified")
+ if st.button("Restore original call"):
+ st.session_state.playground_state["editable_call"] = copy.deepcopy(
+ playground_state["expanded_call"]
+ )
+ st.rerun()
+
+ if call_id != st.session_state.playground_state["call_id"]:
+ st.spinner("Loading call...")
+ call = cached_get_call(client, call_id)
+ editable_call = cached_expand_json_refs(client, call)
+ st.session_state.playground_state = {
+ "call_id": call_id,
+ "call": call,
+ "expanded_call": editable_call,
+ "editable_call": copy.deepcopy(editable_call),
+ }
+ st.rerun()
+
+ call = st.session_state.playground_state["call"]
+ editable_call = st.session_state.playground_state["editable_call"]
+ if call is None or editable_call is None:
+ st.warning("call not yet loaded")
+ st.stop()
+
+ st.write(call["op_name"])
+ # st.json(call["inputs"])
+ # st.json(call["inputs"]["tools"])
+
+ def on_change_temperature():
+ st.session_state.playground_state["editable_call"]["inputs"][
+ "temperature"
+ ] = st.session_state["temperature"]
+
+ st.slider(
+ "Temperature",
+ min_value=0.0,
+ max_value=1.0,
+ value=editable_call["inputs"]["temperature"],
+ key="temperature",
+ on_change=on_change_temperature,
+ )
-with st.sidebar:
- if st.button("Refresh"):
- st.cache_data.clear()
- st.rerun()
- message_ids = {
- f"{cid[-5:]}: {m}": cid
- for cid, m in reversed(
- list(zip(session_calls_df["id"], session_user_message_df))
+ tools = call["inputs"].get("tools", [])
+ if tools:
+ st.write("Tools")
+ for tool_idx, t in enumerate(tools):
+ with st.expander(t["function"]["name"]):
+
+ def on_change_tool():
+ st.session_state.playground_state["editable_call"]["inputs"][
+ "tools"
+ ][tool_idx] = json.loads(st.session_state[f"tool-{tool_idx}"])
+ st.rerun()
+
+ st.text_area(
+ "json",
+ value=json.dumps(t, indent=2),
+ height=300,
+ key=f"tool-{tool_idx}",
+ on_change=on_change_tool,
+ )
+
+ def on_change_parallel_tool_calls():
+ st.session_state.playground_state["editable_call"]["inputs"][
+ "parallel_tool_calls"
+ ] = st.session_state["parallel_tool_calls"]
+
+ st.checkbox(
+ "Parallel tool calls",
+ value=editable_call["inputs"].get("parallel_tool_calls", True),
+ key="parallel_tool_calls",
+ on_change=on_change_parallel_tool_calls,
)
+
+ inputs = editable_call["inputs"]
+ all_input_messages = inputs["messages"]
+ other_inputs = {
+ k: v
+ for k, v in inputs.items()
+ if (k != "messages" and k != "self" and k != "stream")
}
- sel_message = st.radio("Session", options=message_ids.keys())
- sel_id = None
- if sel_message:
- sel_id = message_ids.get(sel_message)
-
-if sel_id:
- st.header(f"Session: {sel_id}")
- print_session_call(sel_id)
+
+ tool_call_attached_messages = attach_tool_call_responses(all_input_messages)
+ for i, m in enumerate(tool_call_attached_messages):
+ write_chat_message(m, f"message-{i}")
+ # output = editable_call["output"]["choices"][0]["message"]
+ n_choices = st.number_input(
+ "Number of choices", value=1, min_value=1, max_value=100
+ )
+ if st.button("Generate"):
+ chat_inputs = {**editable_call["inputs"]}
+ # st.json(chat_inputs, expanded=False)
+ del chat_inputs["stream"]
+ del chat_inputs["self"]
+ chat_inputs["n"] = n_choices
+ call_resp = openai.chat.completions.create(**chat_inputs).model_dump()
+
+ editable_call["output"] = call_resp
+ st.rerun()
+ # st.json(response, expanded=False)
+ # output = response["choices"][0]["message"]
+ # st.json(output)
+ response = editable_call["output"]
+ st.write("full response")
+ st.json(response, expanded=False)
+ st.write("**system fingerprint**", response["system_fingerprint"])
+ st.write("**usage**", response["usage"])
+ for i, choice in enumerate(response["choices"]):
+ output = choice["message"]
+ st.write(f"Choice {i+1}")
+ write_chat_message(output, f"output_message-{i}", readonly=True)
+
+ # all_messages = [*all_input_messages, output]
+ # st.json(st.session_state.playground_state, expanded=False)
+ # st.json(all_messages, expanded=False)
+
+ # st.write(expanded_call)
+
+
+playground_pg = st.Page(playground_page, title="Playground")
+
+
+pg = st.navigation([sessions_pg, playground_pg])
+pg.run()
diff --git a/programmer/agent.py b/programmer/agent.py
index cd180ac..7c026c9 100644
--- a/programmer/agent.py
+++ b/programmer/agent.py
@@ -39,6 +39,12 @@ class AgentState(weave.Object):
history: list[Any] = Field(default_factory=list)
env_snapshot_key: Optional[EnvironmentSnapshotKey] = None
+ def with_history(self, history: list[Any]) -> "AgentState":
+ environment = get_current_environment()
+ msg = get_commit_message(history)
+ snapshot_key = environment.make_snapshot(msg)
+ return self.__class__(history=history, env_snapshot_key=snapshot_key)
+
def unweavify(v: Any) -> Any:
if isinstance(v, list):
@@ -55,6 +61,9 @@ class Agent(weave.Object):
system_message: str
tools: list[Any] = Field(default_factory=list)
+ def initial_state(self, history: list[Any]) -> AgentState:
+ return AgentState().with_history(history)
+
@weave.op()
def step(self, state: AgentState) -> AgentState:
"""Run a step of the agent.
@@ -118,12 +127,8 @@ def step(self, state: AgentState) -> AgentState:
# new_history = state.history + new_messages
new_history = weavelist_add(state.history, new_messages)
- msg = get_commit_message(new_history)
-
- environment = get_current_environment()
- snapshot_key = environment.make_snapshot(msg)
- return AgentState(history=new_history, env_snapshot_key=snapshot_key)
+ return state.with_history(new_history)
@weave.op()
def run(self, state: AgentState, max_runtime_seconds: int = -1):
diff --git a/programmer/agent_texteditor.py b/programmer/agent_texteditor.py
new file mode 100644
index 0000000..599ca43
--- /dev/null
+++ b/programmer/agent_texteditor.py
@@ -0,0 +1,165 @@
+from typing import Any, Union
+from pydantic import Field
+import litellm
+from openai.types.chat import (
+ ChatCompletionMessageParam,
+)
+
+import weave
+from weave.trace.vals import WeaveList
+from weave.flow.chat_util import OpenAIStream
+
+from .console import Console
+from .tool_calling import chat_call_tool_params, perform_tool_calls
+from .text_editor import (
+ TextEditor,
+ TextEditorState,
+ TextEditorStateful,
+ open_file,
+ close_file_range,
+ replace_file_lines,
+ text_editor,
+)
+from .agent import AgentState, Agent
+
+
+# Weave bug workaround: adding two WeaveLists can create that cause
+# downstream crashes.
+# Can be removed after https://github.com/wandb/weave/pull/2165 is merged.
+def weavelist_add(self: Union[list, WeaveList], other: list) -> Union[list, WeaveList]:
+ if isinstance(self, list):
+ return self + other
+ if not isinstance(other, list):
+ return NotImplemented
+ return WeaveList(list(self) + other, server=self.server)
+
+
+class AgentStateTextEditor(AgentState):
+ text_editor_state: TextEditorState = Field(default_factory=TextEditorState)
+
+ def with_history(self, history: list[Any]) -> "AgentStateTextEditor":
+ next_state = super().with_history(history)
+ return AgentStateTextEditor(
+ history=next_state.history,
+ env_snapshot_key=next_state.env_snapshot_key,
+ text_editor_state=self.text_editor_state,
+ )
+
+ def with_texteditor_state(
+ self, text_editor_state: TextEditorState
+ ) -> "AgentStateTextEditor":
+ return AgentStateTextEditor(
+ history=self.history,
+ env_snapshot_key=self.env_snapshot_key,
+ text_editor_state=text_editor_state,
+ )
+
+
+def unweavify(v: Any) -> Any:
+ if isinstance(v, list):
+ return [unweavify(m) for m in v]
+ elif isinstance(v, dict):
+ return {k: unweavify(v) for k, v in v.items()}
+ else:
+ return v
+
+
+class AgentTextEditor(Agent):
+ parallel_tool_calls: bool = True
+ text_editor: TextEditor
+
+ def initial_state(self, history: list[Any]) -> AgentStateTextEditor:
+ return AgentStateTextEditor(history=history)
+
+ @weave.op()
+ def step(self, state: AgentStateTextEditor) -> AgentStateTextEditor:
+ """Run a step of the agent.
+
+ Args:
+ state: The current state of the environment.
+ action: The action to take.
+
+ Returns:
+ The new state of the environment.
+ """
+ Console.step_start("agent", "green")
+ # Printing this is ugly
+ # ref = weave.obj_ref(state)
+ # if ref:
+ # print("state ref:", ref.uri())
+
+ messages: list[ChatCompletionMessageParam] = [
+ {"role": "system", "content": self.system_message},
+ ]
+ open_file_info = state.text_editor_state.get_open_file_info()
+
+ messages.append(
+ {
+ "role": "system",
+ "content": open_file_info.format_for_messages(),
+ }
+ )
+
+ messages += state.history
+
+ messages.append(
+ {
+ "role": "system",
+ "content": open_file_info.format_for_messages(),
+ }
+ )
+
+ self_tools = [*self.tools] or []
+
+ text_editor_stateful = TextEditorStateful(
+ self.text_editor, state.text_editor_state
+ )
+
+ # self_tools += [open_file, close_file_range, replace_file_lines]
+ self_tools += [open_file, replace_file_lines]
+
+ # make type checkers happy by passing NotGiven instead of None
+ tools = None
+ if self_tools:
+ tools = chat_call_tool_params(self_tools)
+
+ Console.chat_response_start()
+
+ # Workaround a weave bug, litellm tries to deepcopy messages which has
+ # a TraceDict. TraceDict is not pickable, because it has a reference to
+ # a weave server, which has a lock.
+ messages = unweavify(messages)
+
+ stream = litellm.completion(
+ model=self.model_name,
+ temperature=self.temperature,
+ messages=messages,
+ tools=tools,
+ stream=True,
+ timeout=60,
+ parallel_tool_calls=self.parallel_tool_calls,
+ )
+ wrapped_stream = OpenAIStream(stream) # type: ignore
+ for chunk in wrapped_stream:
+ if chunk.choices[0].delta.content:
+ Console.chat_message_content_delta(chunk.choices[0].delta.content)
+
+ response = wrapped_stream.final_response()
+ response_message = response.choices[0].message
+ if response_message.content:
+ Console.chat_response_complete(response_message.content)
+
+ new_messages = []
+ # we always store the dict representations of messages in agent state
+ # instead of mixing in some pydantic objects.
+ new_messages.append(response_message.model_dump(exclude_none=True))
+ if response_message.tool_calls:
+ with text_editor(text_editor_stateful):
+ new_messages.extend(
+ perform_tool_calls(self_tools, response_message.tool_calls)
+ )
+ new_history = weavelist_add(state.history, new_messages)
+
+ next_state = state.with_history(new_history)
+ next_state = next_state.with_texteditor_state(text_editor_stateful.state)
+ return next_state
diff --git a/programmer/config.py b/programmer/config.py
index 9b7cd56..fb81ea8 100644
--- a/programmer/config.py
+++ b/programmer/config.py
@@ -15,6 +15,8 @@
splice_lines_in_file,
)
from .agent import Agent
+from .agent_texteditor import AgentTextEditor
+from .text_editor import TextEditor
agent_4o_basic = Agent(
name="gpt-4o-2024-08-06_basic",
@@ -96,3 +98,32 @@
splice_lines_in_file,
],
)
+
+text_editor = TextEditor(max_open_size=15000, open_chunk_size=2000)
+agent_texteditor_4o_basic = AgentTextEditor(
+ name="gpt-4o-2024-08-06_texteditor_basic",
+ model_name="gpt-4o-2024-08-06",
+ temperature=0.7,
+ system_message=SYSTEM_MESSAGE,
+ text_editor=text_editor,
+ tools=[list_files, run_command, view_image],
+)
+
+agent_texteditor_4o_basic_temp0 = AgentTextEditor(
+ name="gpt-4o-2024-08-06_texteditor_basic_temp0",
+ model_name="gpt-4o-2024-08-06",
+ temperature=0.0,
+ system_message=SYSTEM_MESSAGE,
+ text_editor=text_editor,
+ tools=[list_files, run_command, view_image],
+)
+
+agent_texteditor_4o_basic_noparalleltc = AgentTextEditor(
+ name="gpt-4o-2024-08-06_texteditor_basic_noparalleltc",
+ model_name="gpt-4o-2024-08-06",
+ temperature=0.7,
+ system_message=SYSTEM_MESSAGE,
+ text_editor=text_editor,
+ tools=[list_files, run_command, view_image],
+ parallel_tool_calls=False,
+)
diff --git a/programmer/evals/eval_repeated_edits.py b/programmer/evals/eval_repeated_edits.py
index 926712c..0797713 100644
--- a/programmer/evals/eval_repeated_edits.py
+++ b/programmer/evals/eval_repeated_edits.py
@@ -10,7 +10,7 @@
from ..agent import AgentState, Agent
from ..config import *
-from ..tools import tool_context, LocalToolContext, get_current_context
+from ..io_context import LocalIOContext, io_context, get_io_context
# NOTES
# - Try with other LLM and tool configs now that I have this test
@@ -20,7 +20,7 @@
@contextmanager
def tempdir():
with tempfile.TemporaryDirectory() as dir_:
- with tool_context(LocalToolContext(dir_)) as tc:
+ with io_context(LocalIOContext(dir_)) as tc:
yield tc
@@ -56,7 +56,7 @@ def eval_edit_memory(
f.write(prev_file_contents)
task_correct = False
- state = AgentState()
+ state = agent.initial_state(history=[])
def step6_insert_ampersands(lines):
new_lines = []
@@ -156,8 +156,8 @@ def run_task(
if call:
call.set_display_name(f"Task{task_idx}: {task_name}")
print(f"*** TASK: {task_idx}, {prompt}")
- state = AgentState(
- history=state.history
+ state = state.with_history(
+ state.history
+ [
{
"role": "user",
@@ -168,7 +168,7 @@ def run_task(
task_info = {"task_idx": task_idx}
task_correct = False
attempts = []
- for attempt_idx in range(5):
+ for attempt_idx in range(2):
attempt_result = run_attempt(config, agent, state, expected_lines, attempt_idx)
attempt_info = attempt_result["attempt_info"]
state = attempt_result["state"]
@@ -181,8 +181,8 @@ def run_task(
print()
print(f"*** FAILED ATTEMPT Task: {task_idx} Attempt: {attempt_idx}")
print()
- state = AgentState(
- history=state.history
+ state = state.with_history(
+ state.history
+ [
{
"role": "user",
@@ -214,7 +214,7 @@ def run_attempt(
call = weave.get_current_call()
if call:
call.set_display_name(f"Attempt{attempt_idx}")
- ctx = get_current_context()
+ ctx = get_io_context()
attempt_info: dict = {
"attempt_idx": attempt_idx,
"correct": False,
@@ -266,7 +266,7 @@ def mismatch_details(lines, file_lines):
error_details.append("Incorrect edit")
error_details.append("file.txt\texpected")
error_details.append(f"len={len(file_lines)}\tlen={len(lines)}")
- for i in range(len(lines)):
+ for i in range(len(max(lines, file_lines))):
try:
file_lines_i = file_lines[i]
except IndexError:
@@ -327,18 +327,21 @@ def run_single_trial(trial_idx: int):
if __name__ == "__main__":
weave.init("programmerdev-eval-edits1")
agents = [
- agent_4omini_basic,
- agent_4o_basic,
- agent_claude_basic,
- agent_4o_replace,
- agent_claude_replace,
- agent_4o_splice,
- agent_claude_splice,
+ # agent_4omini_basic,
+ # agent_4o_basic,
+ # agent_claude_basic,
+ # agent_4o_replace,
+ # agent_claude_replace,
+ # agent_4o_splice,
+ # agent_claude_splice,
+ # agent_texteditor_4o_basic,
+ # agent_texteditor_4o_basic_temp0,
+ agent_texteditor_4o_basic_noparalleltc,
]
- config = EvalEditMemoryConfig(n_lines=100, run_timeout_seconds=60)
- n_trials = 5
- config_s = f'bugfix_promptfix_{config["n_lines"]}lines_{config["run_timeout_seconds"]}timeout'
+ config = EvalEditMemoryConfig(n_lines=1000, run_timeout_seconds=60)
+ n_trials = 10
+ config_s = f'{config["n_lines"]}lines_{config["run_timeout_seconds"]}timeout'
results = {}
for agent in agents:
run_name = f"{agent.name}_{config_s}"
diff --git a/programmer/file_protocol.py b/programmer/file_protocol.py
new file mode 100644
index 0000000..6bb17e5
--- /dev/null
+++ b/programmer/file_protocol.py
@@ -0,0 +1,6 @@
+from typing import Protocol
+
+
+class FileSystem(Protocol):
+ def write_file(self, path: str, content: str) -> None: ...
+ def read_file(self, path: str) -> str: ...
diff --git a/programmer/io_context.py b/programmer/io_context.py
new file mode 100644
index 0000000..213267b
--- /dev/null
+++ b/programmer/io_context.py
@@ -0,0 +1,163 @@
+from typing import Protocol, TypedDict
+import os
+import subprocess
+import requests
+import shlex
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Optional, Union
+
+
+class RunCommandResult(TypedDict):
+ exit_code: int
+ output: str
+
+
+class IOContext(Protocol):
+ def write_file(self, path: str, content: str) -> None: ...
+
+ def read_file(self, path: str) -> str: ...
+
+ def run_command(self, command: str) -> RunCommandResult: ...
+
+ def resolve_path(self, path: str) -> str: ...
+
+
+class LocalIOContext(IOContext):
+ def __init__(self, directory):
+ self.directory = os.path.abspath(directory)
+
+ def write_file(self, path: str, content: str) -> None:
+ full_path = self.resolve_path(path)
+ with open(full_path, "w") as f:
+ f.write(content)
+
+ def read_file(self, path: str) -> str:
+ full_path = self.resolve_path(path)
+ with open(full_path, "r") as f:
+ return f.read()
+
+ def run_command(self, command: str) -> RunCommandResult:
+ completed_process = subprocess.run(
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ shell=True,
+ cwd=self.directory,
+ )
+ exit_code = completed_process.returncode
+ output = completed_process.stdout.strip()
+
+ return {
+ "exit_code": exit_code,
+ "output": output,
+ }
+
+ def resolve_path(self, path: str) -> str:
+ return os.path.join(self.directory, path)
+
+
+class RemoteContainerIOContext(IOContext):
+ def __init__(self, base_url: str, directory: str, command_prefix: str):
+ self.base_url = base_url
+ self.container_id = None
+ self.directory = directory
+ self.command_prefix = command_prefix
+
+ @contextmanager
+ def context(self, image_id: str):
+ self.start_container(image_id)
+ try:
+ with io_context(self):
+ yield
+ finally:
+ self.stop_container()
+
+ def start_container(self, image_id):
+ response = requests.post(
+ f"{self.base_url}/container/start", json={"image_id": image_id}
+ )
+ if response.status_code == 200:
+ self.container_id = response.json().get("container_id")
+ else:
+ print(f"Failed to start container: {response.text}")
+
+ def stop_container(self):
+ response = requests.post(
+ f"{self.base_url}/container/stop",
+ json={"container_id": self.container_id, "delete": True},
+ )
+ if response.status_code == 200:
+ self.container_id = None
+ else:
+ print(f"Failed to stop container: {response.text}")
+
+ def write_file(self, path: str, content: str) -> None:
+ full_path = os.path.join(self.directory, path)
+ response = requests.post(
+ f"{self.base_url}/container/write_file",
+ json={
+ "container_id": self.container_id,
+ "file_path": full_path,
+ "file_content": content,
+ },
+ )
+ if response.status_code != 200:
+ raise Exception(f"Failed to write file: {response.text}")
+
+ def read_file(self, path: str) -> str:
+ full_path = os.path.join(self.directory, path)
+ response = requests.post(
+ f"{self.base_url}/container/read_file",
+ json={"container_id": self.container_id, "file_path": full_path},
+ )
+ if response.status_code == 200:
+ return response.json().get("file_content")
+ else:
+ raise Exception(f"Failed to read file: {response.text}")
+
+ def run_command(self, command: str) -> RunCommandResult:
+ command = self.command_prefix + command
+ command = f"bash -c {shlex.quote(command)}"
+ response = requests.post(
+ f"{self.base_url}/container/run",
+ json={
+ "container_id": self.container_id,
+ "workdir": self.directory,
+ "command": command,
+ },
+ )
+ if response.status_code == 200:
+ json = response.json()
+ return {
+ "exit_code": json["exit_code"],
+ "output": json["output"],
+ }
+ else:
+ raise Exception(f"Failed to run command: {response.text}")
+
+ def resolve_path(self, path: str) -> str:
+ return path # For remote containers, we assume paths are already resolved
+
+
+# Create a ContextVar to store the current ToolContext
+_io_context: ContextVar[Optional[Union[LocalIOContext, RemoteContainerIOContext]]] = (
+ ContextVar("_io_context", default=None)
+)
+
+
+@contextmanager
+def io_context(context: Union[LocalIOContext, RemoteContainerIOContext]):
+ token = _io_context.set(context)
+ try:
+ yield context
+ finally:
+ _io_context.reset(token)
+
+
+def get_io_context() -> Union[LocalIOContext, RemoteContainerIOContext]:
+ context = _io_context.get()
+ if context is None:
+ return LocalIOContext(".")
+ return context
diff --git a/programmer/programmer.py b/programmer/programmer.py
index 59352fe..403d1d6 100644
--- a/programmer/programmer.py
+++ b/programmer/programmer.py
@@ -14,6 +14,7 @@
from .console import Console
from .config import (
agent_4o_replace,
+ agent_texteditor_4o_basic,
)
from .environment import (
environment_session,
@@ -27,6 +28,8 @@
from .git import GitRepo
+agent = agent_texteditor_4o_basic
+
@weave.op
def get_user_input():
@@ -41,18 +44,13 @@ def user_input_step(state: AgentState) -> AgentState:
# if ref:
# print("state ref:", ref.uri())
user_input = get_user_input()
- environment = get_current_environment()
history = state.history + [
{
"role": "user",
"content": user_input,
}
]
- msg = get_commit_message(history)
- return AgentState(
- history=history,
- env_snapshot_key=environment.make_snapshot(msg),
- )
+ return state.with_history(history)
def make_environment():
@@ -74,12 +72,9 @@ def session(agent: Agent, agent_state: AgentState):
session_id = call.id
env = make_environment()
- msg = get_commit_message(agent_state.history)
with environment_session(env, session_id):
- agent_state = AgentState(
- history=agent_state.history, env_snapshot_key=env.make_snapshot(msg)
- )
+ agent_state = agent_state.with_history(agent_state.history)
while True:
result = agent.run(agent_state)
agent_state = result["state"]
@@ -155,16 +150,16 @@ def programmer():
else:
initial_prompt = input("Initial prompt: ")
- state = AgentState(
- history=[
+ state = agent.initial_state(
+ [
{
"role": "user",
"content": initial_prompt,
},
- ],
+ ]
)
- session(agent_4o_replace, state)
+ session(agent_texteditor_4o_basic, state)
def main():
diff --git a/programmer/swebench/score.py b/programmer/swebench/score.py
index e9cd799..2753b7c 100644
--- a/programmer/swebench/score.py
+++ b/programmer/swebench/score.py
@@ -10,12 +10,12 @@
SWEbenchInstance,
)
-from ..tools import RemoteContainerToolContext
+from ..io_context import RemoteContainerIOContext
def score_swebench(instance: SWEbenchInstance, model_output):
patch = model_output["answer"]
- tc = RemoteContainerToolContext(
+ tc = RemoteContainerIOContext(
"http://localhost:8000",
"/testbed",
"source /opt/miniconda3/bin/activate && conda activate testbed && ",
diff --git a/programmer/swebench/swebench_model.py b/programmer/swebench/swebench_model.py
index 74dc0b8..30d7c54 100644
--- a/programmer/swebench/swebench_model.py
+++ b/programmer/swebench/swebench_model.py
@@ -1,7 +1,7 @@
import weave
from ..agent import Agent, AgentState
-from ..tools import RemoteContainerToolContext
+from ..io_context import RemoteContainerIOContext
class SWEBenchProgrammerModel(weave.Model):
@@ -25,7 +25,7 @@ def predict(self, instance):
],
)
- tc = RemoteContainerToolContext(
+ tc = RemoteContainerIOContext(
"http://localhost:8000",
"/testbed",
"source /opt/miniconda3/bin/activate && conda activate testbed && ",
diff --git a/programmer/tests/test_file_line_tools.py b/programmer/tests/test_file_line_tools.py
index 74f1a29..84a1c65 100644
--- a/programmer/tests/test_file_line_tools.py
+++ b/programmer/tests/test_file_line_tools.py
@@ -4,16 +4,15 @@
from programmer.tools import (
read_lines_from_file,
splice_lines_in_file,
- LocalToolContext,
- tool_context,
- get_current_context,
+ get_io_context,
)
+from programmer.io_context import LocalIOContext, io_context
@pytest.fixture()
def tempdir_tool_context():
with TemporaryDirectory() as tmpdir:
- with tool_context(LocalToolContext(tmpdir)) as tc:
+ with io_context(LocalIOContext(tmpdir)) as tc:
yield tc
diff --git a/programmer/tests/test_text_editor.py b/programmer/tests/test_text_editor.py
new file mode 100644
index 0000000..f46bced
--- /dev/null
+++ b/programmer/tests/test_text_editor.py
@@ -0,0 +1,167 @@
+import pytest
+from tempfile import TemporaryDirectory
+from programmer.text_editor import (
+ TextEditor,
+ TextEditorState,
+ OpenFileState,
+ LineRange,
+ OpenFileResult,
+ WriteFileResult,
+ TextEditorMutationResult,
+)
+from programmer.io_context import LocalIOContext, io_context
+
+
+@pytest.fixture()
+def tempdir_tool_context():
+ with TemporaryDirectory() as tmpdir:
+ with io_context(LocalIOContext(tmpdir)) as tc:
+ yield tc
+
+
+@pytest.fixture()
+def sample_file(tempdir_tool_context):
+ file_path = "sample.txt"
+ content = "\n".join(f"Line {i}" for i in range(1, 201)) # 200 lines
+ tempdir_tool_context.write_file(file_path, content)
+ return file_path
+
+
+@pytest.fixture()
+def text_editor(tempdir_tool_context):
+ return TextEditor(max_open_size=150, open_chunk_size=50)
+
+
+@pytest.fixture()
+def initial_state():
+ return TextEditorState()
+
+
+def test_open_file(text_editor, sample_file, initial_state):
+ result = text_editor.open_file(initial_state, sample_file, 1)
+ assert isinstance(result, TextEditorMutationResult)
+ assert isinstance(result.action_result, OpenFileResult)
+ assert result.action_result.success
+ assert sample_file in result.new_state.open_files
+ assert (
+ result.new_state.open_files[sample_file].total_lines() == 50
+ ) # OPEN_CHUNK_SIZE
+
+
+def test_open_file_exceed_max_size(tempdir_tool_context, sample_file):
+ text_editor = TextEditor(max_open_size=75, open_chunk_size=50)
+ initial_state = TextEditorState()
+
+ # Open the file once (50 lines)
+ result1 = text_editor.open_file(initial_state, sample_file, 1)
+ assert result1.action_result.success
+
+ # Try to open another chunk, which would exceed the max_open_size
+ result2 = text_editor.open_file(result1.new_state, sample_file, 50)
+ assert isinstance(result2.action_result, OpenFileResult)
+ assert not result2.action_result.success
+ assert "exceeding the maximum" in result2.action_result.error
+
+
+def test_open_file_at_boundary(tempdir_tool_context, sample_file):
+ text_editor = TextEditor(max_open_size=100, open_chunk_size=50)
+ initial_state = TextEditorState()
+
+ # Open exactly MAX_OPEN_SIZE lines
+ result1 = text_editor.open_file(initial_state, sample_file, 1)
+ result2 = text_editor.open_file(result1.new_state, sample_file, 51)
+ assert result1.action_result.success and result2.action_result.success
+ assert result2.new_state.total_lines() == 100 # MAX_OPEN_SIZE
+
+ # Try to open one more line, which should fail
+ result3 = text_editor.open_file(result2.new_state, sample_file, 99)
+ assert not result3.action_result.success
+ assert "exceeding the maximum" in result3.action_result.error
+
+
+def test_replace_file_lines_at_boundary(text_editor, sample_file, initial_state):
+ state1 = text_editor.open_file(initial_state, sample_file, 1).new_state
+ state2 = text_editor.open_file(state1, sample_file, 51).new_state
+ state3 = text_editor.open_file(state2, sample_file, 101).new_state
+
+ # Replace 5 lines with 5 new lines (no net change)
+ result = text_editor.replace_file_lines(
+ state3,
+ sample_file,
+ [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 5}],
+ )
+ assert result.action_result.success
+
+ # Try to replace 5 lines with 6 new lines (net increase of 1, should fail)
+ result = text_editor.replace_file_lines(
+ state3,
+ sample_file,
+ [{"start_line": 1, "n_lines": 5, "lines": "New Line\n" * 6}],
+ )
+ assert not result.action_result.success
+ assert "exceeding the maximum" in result.action_result.error
+
+
+def test_replace_file_lines_middle(text_editor, sample_file, initial_state):
+ state1 = text_editor.open_file(initial_state, sample_file, 1).new_state
+
+ # Replace 5 lines with 5 new lines (no net change)
+ result = text_editor.replace_file_lines(
+ state1,
+ sample_file,
+ [{"start_line": 5, "n_lines": 5, "lines": "A\nB\n"}],
+ )
+ assert result.action_result.success
+
+ # Try to replace 5 lines with 6 new lines (net increase of 1, should fail)
+ file_info = result.new_state.get_open_file_info()
+ assert file_info.open_file_buffers[sample_file].total_lines == 197
+ assert len(file_info.open_file_buffers[sample_file].buffers) == 1
+ buffer0 = file_info.open_file_buffers[sample_file].buffers[0]
+ assert buffer0.line_range.start_line == 1
+ assert buffer0.line_range.n_lines == 50
+
+
+def test_close_file_range(text_editor, sample_file, initial_state):
+ state1 = text_editor.open_file(initial_state, sample_file, 1).new_state
+ result = text_editor.close_file_range(state1, sample_file, 1, 25)
+ assert result.new_state.open_files[sample_file].total_lines() == 25
+
+
+def test_get_open_file_info(text_editor, sample_file, initial_state):
+ state1 = text_editor.open_file(initial_state, sample_file, 1).new_state
+ info = state1.get_open_file_info()
+ assert sample_file in info.open_file_buffers
+ assert info.open_file_buffers[sample_file].total_lines == 200
+ assert len(info.open_file_buffers[sample_file].buffers) == 1
+ assert info.open_file_buffers[sample_file].buffers[0].line_range.start_line == 1
+ assert info.open_file_buffers[sample_file].buffers[0].line_range.n_lines == 50
+
+
+def test_open_file_multiple_ranges(text_editor, sample_file, initial_state):
+ state1 = text_editor.open_file(initial_state, sample_file, 1).new_state
+ state2 = text_editor.open_file(state1, sample_file, 51).new_state
+ assert len(state2.open_files[sample_file].ranges) == 1
+ assert state2.open_files[sample_file].ranges[0].start_line == 1
+ assert state2.open_files[sample_file].ranges[0].n_lines == 100
+
+
+def test_open_file_beyond_end(text_editor, sample_file, initial_state):
+ result = text_editor.open_file(initial_state, sample_file, 201)
+ assert isinstance(result.action_result, OpenFileResult)
+ assert not result.action_result.success
+ assert "beyond the end of the file" in result.action_result.error
+
+
+def test_open_file_at_end(text_editor, sample_file, initial_state):
+ result = text_editor.open_file(initial_state, sample_file, 199)
+ assert isinstance(result.action_result, OpenFileResult)
+ assert result.action_result.success
+ assert result.new_state.open_files[sample_file].total_lines() == 1
+
+
+def test_open_file_near_end(text_editor, sample_file, initial_state):
+ result = text_editor.open_file(initial_state, sample_file, 190)
+ assert isinstance(result.action_result, OpenFileResult)
+ assert result.action_result.success
+ assert result.new_state.open_files[sample_file].total_lines() == 10
diff --git a/programmer/tests/test_tool_calling.py b/programmer/tests/test_tool_calling.py
new file mode 100644
index 0000000..600297a
--- /dev/null
+++ b/programmer/tests/test_tool_calling.py
@@ -0,0 +1,94 @@
+from enum import Enum
+from typing import TypedDict
+
+import weave
+
+from programmer.tool_calling import generate_json_schema
+
+
+class Range(TypedDict):
+ start: int
+ end: int
+
+
+@weave.op
+def merge_ranges(ranges: list[Range]) -> list[Range]:
+ """Merge a list of ranges into a single range.
+
+ Args:
+ ranges: A list of ranges to merge.
+
+ Returns:
+ A list of merged ranges.
+ """
+ return ranges
+
+
+def test_list_of_typeddict_schema():
+ schema = generate_json_schema(merge_ranges)
+ assert schema == {
+ "function": {
+ "description": "Merge a list of ranges into a single range.",
+ "name": "merge_ranges",
+ "parameters": {
+ "properties": {
+ "ranges": {
+ "description": "A list of ranges to merge.",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "start": {"type": "integer"},
+ "end": {"type": "integer"},
+ },
+ "required": ["start", "end"],
+ },
+ }
+ },
+ "required": ["ranges"],
+ "type": "object",
+ },
+ },
+ "type": "function",
+ }
+
+
+class Color(Enum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+@weave.op
+def color_name(color: Color) -> str:
+ """Get the name of a color.
+
+ Args:
+ color: The color to get the name of.
+
+ Returns:
+ The name of the color.
+ """
+ return color.name
+
+
+def test_enum_schema():
+ schema = generate_json_schema(color_name)
+ assert schema == {
+ "function": {
+ "description": "Get the name of a color.",
+ "name": "color_name",
+ "parameters": {
+ "properties": {
+ "color": {
+ "description": "The color to get the name of.",
+ "enum": [1, 2, 3],
+ "type": "integer",
+ }
+ },
+ "required": ["color"],
+ "type": "object",
+ },
+ },
+ "type": "function",
+ }
diff --git a/programmer/text_editor.py b/programmer/text_editor.py
new file mode 100644
index 0000000..171bbc2
--- /dev/null
+++ b/programmer/text_editor.py
@@ -0,0 +1,507 @@
+from typing import Optional, Generic, TypeVar
+from dataclasses import dataclass, field
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Optional, TypedDict
+
+import weave
+
+from .io_context import get_io_context
+
+
+@dataclass(frozen=True)
+class LineRange:
+ start_line: int
+ n_lines: int
+
+
+@dataclass(frozen=True)
+class OpenFileState:
+ # Invariant: ranges must be non-overlapping and non-adjacent
+ # and must be in sorted order
+ ranges: tuple[LineRange, ...] = field(default_factory=tuple)
+
+ def add_range(self, range: LineRange) -> "OpenFileState":
+ # Create a new list of ranges
+ new_ranges = list(self.ranges)
+
+ # Find the correct position to insert the new range
+ insert_index = 0
+ for i, existing_range in enumerate(new_ranges):
+ if range.start_line < existing_range.start_line:
+ insert_index = i
+ break
+ insert_index = i + 1
+
+ # Insert the new range
+ new_ranges.insert(insert_index, range)
+
+ # Merge overlapping or adjacent ranges
+ i = 0
+ while i < len(new_ranges) - 1:
+ current_range = new_ranges[i]
+ next_range = new_ranges[i + 1]
+
+ if (
+ current_range.start_line + current_range.n_lines
+ >= next_range.start_line
+ ):
+ # Merge the ranges
+ merged_end = max(
+ current_range.start_line + current_range.n_lines,
+ next_range.start_line + next_range.n_lines,
+ )
+ new_ranges[i] = LineRange(
+ current_range.start_line, merged_end - current_range.start_line
+ )
+ new_ranges.pop(i + 1)
+ else:
+ i += 1
+
+ # Return a new OpenFileState with the updated ranges
+ return OpenFileState(ranges=tuple(new_ranges))
+
+ def subtract_range(self, range: LineRange) -> "OpenFileState":
+ new_ranges = []
+ for existing_range in self.ranges:
+ if range.start_line >= existing_range.start_line + existing_range.n_lines:
+ # The subtracted range is after this range, keep it as is
+ new_ranges.append(existing_range)
+ elif range.start_line + range.n_lines <= existing_range.start_line:
+ # The subtracted range is before this range, keep it as is
+ new_ranges.append(existing_range)
+ else:
+ # The ranges overlap, we need to split or adjust
+ if range.start_line > existing_range.start_line:
+ # Keep the part before the subtracted range
+ new_ranges.append(
+ LineRange(
+ existing_range.start_line,
+ range.start_line - existing_range.start_line,
+ )
+ )
+ if (
+ range.start_line + range.n_lines
+ < existing_range.start_line + existing_range.n_lines
+ ):
+ # Keep the part after the subtracted range
+ new_ranges.append(
+ LineRange(
+ range.start_line + range.n_lines,
+ (existing_range.start_line + existing_range.n_lines)
+ - (range.start_line + range.n_lines),
+ )
+ )
+
+ return OpenFileState(ranges=tuple(new_ranges))
+
+ def total_lines(self) -> int:
+ return sum(r.n_lines for r in self.ranges)
+
+ def is_range_open(self, start_line: int, n_lines: int) -> bool:
+ end_line = start_line + n_lines
+ for range in self.ranges:
+ if (
+ range.start_line <= start_line
+ and range.start_line + range.n_lines >= end_line
+ ):
+ return True
+ return False
+
+
+@dataclass(frozen=True)
+class TextEditorState:
+ open_files: dict[str, OpenFileState] = field(default_factory=dict)
+
+ def total_lines(self) -> int:
+ return sum(file.total_lines() for file in self.open_files.values())
+
+ def get_open_file_info(self) -> "OpenFileInfoResult":
+ file_io_context = get_io_context()
+ open_file_buffers = {}
+ for path, open_file in self.open_files.items():
+ contents = file_io_context.read_file(path)
+ lines = contents.split("\n")
+ buffers = []
+ for range in open_file.ranges:
+ buffer = Buffer(
+ line_range=range,
+ lines=lines[
+ range.start_line - 1 : range.start_line - 1 + range.n_lines
+ ],
+ )
+ buffers.append(buffer)
+ open_file_info = OpenFileInfo(
+ buffers=tuple(buffers), total_lines=len(lines)
+ )
+ open_file_buffers[path] = open_file_info
+ return OpenFileInfoResult(open_file_buffers=open_file_buffers)
+
+
+@dataclass(frozen=True)
+class Buffer:
+ line_range: LineRange
+ lines: list[str]
+
+
+@dataclass(frozen=True)
+class OpenFileInfo:
+ buffers: tuple[Buffer, ...] = field(default_factory=tuple)
+ total_lines: int = 0
+
+ def n_lines(self) -> int:
+ return sum(buffer.line_range.n_lines for buffer in self.buffers)
+
+
+@dataclass(frozen=True)
+class OpenFileInfoResult:
+ open_file_buffers: dict[str, OpenFileInfo] = field(default_factory=dict)
+
+ def format_for_messages(self) -> str:
+ lines = [
+ "Visible file buffers. These are the latest states of any previously opened file ranges, and reflect the results of all prior edits."
+ ]
+ for path, open_file_info in self.open_file_buffers.items():
+ lines.append(f"")
+ # lines.append(f"")
+ for buffer in open_file_info.buffers:
+ lines.append("")
+ for i, line in enumerate(buffer.lines):
+ lines.append(f"{buffer.line_range.start_line + i}: {line}")
+ lines.append("")
+ lines.append("")
+ return "\n".join(lines)
+
+
+@dataclass(frozen=True)
+class ClosedFileRange:
+ path: str
+ start_line: int
+ n_lines: int
+
+
+@dataclass(frozen=True)
+class OpenFileResult:
+ success: bool
+ error: str
+
+
+@dataclass(frozen=True)
+class WriteFileResult:
+ success: bool
+ error: str
+
+
+T = TypeVar("T")
+
+
+@dataclass(frozen=True)
+class TextEditorMutationResult(Generic[T]):
+ new_state: TextEditorState
+ action_result: T
+
+
+class LineRangeReplacement(TypedDict):
+ start_line: int
+ n_lines: int
+ lines: str
+
+
+class TextEditor:
+ def __init__(
+ self,
+ max_open_size: int = 1500,
+ open_chunk_size: int = 500,
+ ):
+ self.MAX_OPEN_SIZE = max_open_size
+ self.OPEN_CHUNK_SIZE = open_chunk_size
+
+ def open_file(
+ self, state: TextEditorState, path: str, start_line: int
+ ) -> TextEditorMutationResult[OpenFileResult]:
+ file_io_context = get_io_context()
+ try:
+ file_contents = file_io_context.read_file(path)
+ except FileNotFoundError:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=OpenFileResult(success=False, error="File not found"),
+ )
+
+ file_lines = file_contents.split("\n")
+ file_lines_count = len(file_lines)
+
+ if start_line < 1:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=OpenFileResult(
+ success=False,
+ error=f"Start line {start_line} is before the start of the file.",
+ ),
+ )
+
+ if start_line - 1 >= file_lines_count:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=OpenFileResult(
+ success=False,
+ error=f"Start line {start_line} is beyond the end of the file (which has {file_lines_count} lines).",
+ ),
+ )
+
+ orig_open_file_state = state.open_files.get(path, OpenFileState())
+ new_buffer = LineRange(
+ start_line, min(self.OPEN_CHUNK_SIZE, file_lines_count - start_line)
+ )
+ new_open_file_state = orig_open_file_state.add_range(new_buffer)
+ added_lines = (
+ new_open_file_state.total_lines() - orig_open_file_state.total_lines()
+ )
+
+ if state.total_lines() + added_lines > self.MAX_OPEN_SIZE:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=OpenFileResult(
+ success=False,
+ error=f"This request would result in {state.total_lines() + added_lines} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.",
+ ),
+ )
+
+ new_open_files = dict(state.open_files)
+ new_open_files[path] = new_open_file_state
+ new_state = TextEditorState(open_files=new_open_files)
+
+ return TextEditorMutationResult(
+ new_state=new_state,
+ action_result=OpenFileResult(success=True, error=""),
+ )
+
+ def close_file_range(
+ self, state: TextEditorState, path: str, start_line: int, n_lines: int
+ ) -> TextEditorMutationResult[None]:
+ open_file_state = state.open_files[path]
+ new_open_file_state = open_file_state.subtract_range(
+ LineRange(start_line, n_lines)
+ )
+
+ new_open_files = dict(state.open_files)
+ if new_open_file_state.total_lines() == 0:
+ del new_open_files[path]
+ else:
+ new_open_files[path] = new_open_file_state
+
+ new_state = TextEditorState(open_files=new_open_files)
+ return TextEditorMutationResult(new_state=new_state, action_result=None)
+
+ def replace_file_lines(
+ self,
+ state: TextEditorState,
+ path: str,
+ replacements: list[LineRangeReplacement],
+ ) -> TextEditorMutationResult[WriteFileResult]:
+ file_io_context = get_io_context()
+
+ # Check if the file is open
+ open_file_state = state.open_files.get(path)
+ if not open_file_state:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(
+ success=False,
+ error=f"The file {path} is not open.",
+ ),
+ )
+
+ # Check if all ranges are open
+ missing_ranges = []
+ for replacement in replacements:
+ if not open_file_state.is_range_open(
+ replacement["start_line"], replacement["n_lines"]
+ ):
+ missing_ranges.append(replacement)
+ if missing_ranges:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(
+ success=False,
+ error=f"The following ranges are not open: {missing_ranges}",
+ ),
+ )
+
+ # Sort replacements by start line
+ replacements.sort(key=lambda x: x["start_line"])
+
+ # Ensure replacements are non-overlapping
+ for i in range(len(replacements) - 1):
+ if (
+ replacements[i]["start_line"] + replacements[i]["n_lines"]
+ > replacements[i + 1]["start_line"]
+ ):
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(
+ success=False,
+ error=f"The following replacements are overlapping: {replacements[i]}, {replacements[i+1]}",
+ ),
+ )
+
+ all_new_lines = [l["lines"].rstrip("\n").split("\n") for l in replacements]
+
+ net_change = sum(len(l) for l in all_new_lines) - sum(
+ l["n_lines"] for l in replacements
+ )
+ if state.total_lines() + net_change > self.MAX_OPEN_SIZE:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(
+ success=False,
+ error=f"This edit would result in {state.total_lines() + net_change} open lines exceeding the maximum of {self.MAX_OPEN_SIZE} lines.",
+ ),
+ )
+
+ file_io_context = get_io_context()
+ try:
+ file_contents = file_io_context.read_file(path)
+ file_lines = file_contents.split("\n")
+ except Exception as e:
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(
+ success=False,
+ error=f"Failed to write to file: {str(e)}",
+ ),
+ )
+
+ # Apply replacements in reverse order to indexes don't change while iterating
+ for i, replacement in reversed(list(enumerate(replacements))):
+ start_line = replacement["start_line"]
+ n_lines = replacement["n_lines"]
+ file_lines[start_line - 1 : start_line - 1 + n_lines] = all_new_lines[i]
+
+ new_contents = "\n".join(file_lines)
+
+ file_io_context.write_file(path, new_contents)
+ return TextEditorMutationResult(
+ new_state=state,
+ action_result=WriteFileResult(success=True, error=""),
+ )
+
+
+class TextEditorStateful:
+ def __init__(self, text_editor: TextEditor, initial_state: TextEditorState):
+ self.text_editor = text_editor
+ self.state = initial_state
+
+ def open_file(self, path: str, start_line: int) -> OpenFileResult:
+ result = self.text_editor.open_file(self.state, path, start_line)
+ self.state = result.new_state
+ return result.action_result
+
+ def close_file_range(self, path: str, start_line: int, n_lines: int) -> None:
+ result = self.text_editor.close_file_range(
+ self.state, path, start_line, n_lines
+ )
+ self.state = result.new_state
+ return result.action_result
+
+ def replace_file_lines(
+ self,
+ path: str,
+ replacements: list[LineRangeReplacement],
+ ) -> WriteFileResult:
+ result = self.text_editor.replace_file_lines(self.state, path, replacements)
+ self.state = result.new_state
+ return result.action_result
+
+
+_text_editor_context: ContextVar[Optional[TextEditorStateful]] = ContextVar(
+ "_text_editor_context", default=None
+)
+
+
+@contextmanager
+def text_editor(context: TextEditorStateful):
+ token = _text_editor_context.set(context)
+ try:
+ yield context
+ finally:
+ _text_editor_context.reset(token)
+
+
+def require_text_editor() -> TextEditorStateful:
+ context = _text_editor_context.get()
+ assert context is not None
+ return context
+
+
+@weave.op
+def open_file(path: str, start_line: int) -> str:
+ """Open a buffer of lines from the given file.
+
+ Args:
+ path: The path to the file.
+ start_line: The line number to start reading from (1-indexed).
+
+ Returns:
+ "success" if the file was opened successfully,
+ "error: " if the file was not opened successfully.
+ """
+ text_editor = require_text_editor()
+ response = text_editor.open_file(path, start_line)
+ if response.success:
+ return "success"
+ else:
+ return f"error: {response.error}"
+
+
+@weave.op
+def close_file_range(path: str, start_line: int, n_lines: int) -> str:
+ """Close a buffer of lines from the given file.
+
+ Args:
+ path: The path to the file.
+ start_line: The line number to start reading from (1-indexed).
+ n_lines: The number of lines to close.
+
+ Returns:
+ "success" if the file was closed successfully.
+ """
+ text_editor = require_text_editor()
+ response = text_editor.close_file_range(path, start_line, n_lines)
+ return "success"
+
+
+class LineRangeReplacementStartEnd(TypedDict):
+ start_line: int
+ remove_up_to_line: int
+ lines: str
+
+
+@weave.op
+def replace_file_lines(
+ path: str, replacements: list[LineRangeReplacementStartEnd]
+) -> str:
+ """Replace ranges of lines within a file. Changes must be made to open ranges, and will be reflected immediately on the filesystem. First, existing lines are removed starting at start line, up to but not including replace_up_to_line. Then the new lines are added in that position.
+
+ Args:
+ path: The path to the file.
+ replacements: A list of replacements to make. Each replacement is a dictionary with keys: start_line (1-indexed, inclusive), remove_up_to_line (1-indexed, exclusive), lines (a string of newline separated lines to insert)
+
+ Returns:
+ "success" if the file was replaced successfully,
+ "error: " if the file was not replaced successfully.
+ """
+ text_editor = require_text_editor()
+ replacements_list = [
+ LineRangeReplacement(
+ start_line=r["start_line"],
+ n_lines=r["remove_up_to_line"] - r["start_line"],
+ lines=r["lines"],
+ )
+ for r in replacements
+ ]
+ response = text_editor.replace_file_lines(path, replacements_list)
+ if response.success:
+ return "success"
+ else:
+ return f"error: {response.error}"
diff --git a/programmer/tool_calling.py b/programmer/tool_calling.py
index c9a9f50..3207ebc 100644
--- a/programmer/tool_calling.py
+++ b/programmer/tool_calling.py
@@ -1,12 +1,53 @@
import inspect
import json
-from typing import Callable, get_type_hints
+import traceback
+import typing_extensions
+
+from typing import Any, Callable, get_type_hints, TypedDict
from openai.types.chat import ChatCompletionMessageToolCall, ChatCompletionToolParam
from .console import Console
+class TypedDictLike:
+ __required_keys__: frozenset[str]
+
+
+def is_typed_dict_like(t: type) -> typing_extensions.TypeGuard[TypedDictLike]:
+ return hasattr(t, "__required_keys__")
+
+
+def pytype_to_jsonschema(pytype: Any) -> dict:
+ if pytype.__name__ == "str":
+ return {"type": "string"}
+ elif pytype.__name__ == "int":
+ return {"type": "integer"}
+ elif is_typed_dict_like(pytype):
+ return {
+ "type": "object",
+ "properties": {
+ k: pytype_to_jsonschema(v) for k, v in pytype.__annotations__.items()
+ },
+ "required": list(pytype.__annotations__.keys()),
+ }
+ elif pytype.__name__ == "list":
+ return {"type": "array", "items": pytype_to_jsonschema(pytype.__args__[0])}
+ elif hasattr(pytype, "__members__"):
+ member_types = [
+ pytype_to_jsonschema(type(v.value)) for v in pytype.__members__.values()
+ ]
+ t0 = member_types[0]
+ for t in member_types[1:]:
+ if t != t0:
+ raise ValueError("All member types must be the same")
+ mem_type = t0["type"]
+ if mem_type != "string" and mem_type != "integer":
+ raise ValueError(f"Enum member type {mem_type} is not supported")
+ return {"type": mem_type, "enum": [e.value for e in pytype]}
+ raise ValueError(f"Unsupported type: {pytype.__name__}")
+
+
def generate_json_schema(func: Callable) -> dict:
"""Given a function, generate an OpenAI tool compatible JSON schema.
@@ -40,27 +81,21 @@ def generate_json_schema(func: Callable) -> dict:
is_required = param.default == inspect.Parameter.empty
# Extract parameter type and description
- param_type = type_hints[name].__name__ if name in type_hints else "string"
- if param_type == "str":
- param_type = "string"
- elif param_type == "int":
- param_type = "integer"
- param_desc = ""
+ param_schema = pytype_to_jsonschema(type_hints[name])
# Attempt to extract description from docstring
+ param_desc = ""
if func.__doc__:
doc_lines = func.__doc__.split("\n")[1:]
for line in doc_lines:
if name in line:
param_desc = line.strip().split(":")[-1].strip()
break
-
- # Populate schema for this parameter
- param_schema = {"type": param_type, "description": param_desc}
-
- # Handle special case for enums
- if hasattr(type_hints[name], "__members__"): # Check if it's an Enum
- param_schema["enum"] = [e.value for e in type_hints[name]]
+ if not param_desc:
+ raise ValueError(
+ f"Function {func.__name__} description for parameter {name} is missing"
+ )
+ param_schema["description"] = param_desc
schema["function"]["parameters"]["properties"][name] = param_schema # type: ignore
@@ -96,12 +131,15 @@ def perform_tool_calls(
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
- function_response = str(e)
+ print(f"Tool call {tool_call_s} failed to parse arguments: {e}")
+ function_response = f"Argument parse error: {str(e)}"
if not function_response:
try:
function_response = tool(**function_args)
except Exception as e:
- function_response = str(e)
+ print(f"Error occurred in tool {function_name}:")
+ traceback.print_exc()
+ function_response = f"Error: {str(e)}"
additional_message = None
if isinstance(function_response, tuple):
diff --git a/programmer/tools.py b/programmer/tools.py
index 44e155b..fb49272 100644
--- a/programmer/tools.py
+++ b/programmer/tools.py
@@ -1,14 +1,8 @@
import base64
-import json
import os
-import subprocess
import weave
-import contextlib
-import shlex
-from contextvars import ContextVar
-from contextlib import contextmanager
-from typing import Protocol, Union, TypedDict, Optional
-import requests
+
+from .io_context import get_io_context
LENGTH_LIMIT = 30000
@@ -17,161 +11,6 @@
# - must return FileNotFoundError in read_file in Remote
-class RunCommandResult(TypedDict):
- exit_code: int
- output: str
-
-
-class ToolContext(Protocol):
- def write_file(self, path: str, content: str) -> None: ...
-
- def read_file(self, path: str) -> str: ...
-
- def run_command(self, command: str) -> RunCommandResult: ...
-
- def resolve_path(self, path: str) -> str: ...
-
-
-class LocalToolContext(ToolContext):
- def __init__(self, directory):
- self.directory = os.path.abspath(directory)
-
- def write_file(self, path: str, content: str) -> None:
- full_path = self.resolve_path(path)
- with open(full_path, "w") as f:
- f.write(content)
-
- def read_file(self, path: str) -> str:
- full_path = self.resolve_path(path)
- with open(full_path, "r") as f:
- return f.read()
-
- def run_command(self, command: str) -> RunCommandResult:
- completed_process = subprocess.run(
- command,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- text=True,
- shell=True,
- cwd=self.directory,
- )
- exit_code = completed_process.returncode
- output = completed_process.stdout.strip()
-
- return {
- "exit_code": exit_code,
- "output": output,
- }
-
- def resolve_path(self, path: str) -> str:
- return os.path.join(self.directory, path)
-
-
-class RemoteContainerToolContext(ToolContext):
- def __init__(self, base_url: str, directory: str, command_prefix: str):
- self.base_url = base_url
- self.container_id = None
- self.directory = directory
- self.command_prefix = command_prefix
-
- @contextmanager
- def context(self, image_id: str):
- self.start_container(image_id)
- try:
- with tool_context(self):
- yield
- finally:
- self.stop_container()
-
- def start_container(self, image_id):
- response = requests.post(
- f"{self.base_url}/container/start", json={"image_id": image_id}
- )
- if response.status_code == 200:
- self.container_id = response.json().get("container_id")
- else:
- print(f"Failed to start container: {response.text}")
-
- def stop_container(self):
- response = requests.post(
- f"{self.base_url}/container/stop",
- json={"container_id": self.container_id, "delete": True},
- )
- if response.status_code == 200:
- self.container_id = None
- else:
- print(f"Failed to stop container: {response.text}")
-
- def write_file(self, path: str, content: str) -> None:
- full_path = os.path.join(self.directory, path)
- response = requests.post(
- f"{self.base_url}/container/write_file",
- json={
- "container_id": self.container_id,
- "file_path": full_path,
- "file_content": content,
- },
- )
- if response.status_code != 200:
- raise Exception(f"Failed to write file: {response.text}")
-
- def read_file(self, path: str) -> str:
- full_path = os.path.join(self.directory, path)
- response = requests.post(
- f"{self.base_url}/container/read_file",
- json={"container_id": self.container_id, "file_path": full_path},
- )
- if response.status_code == 200:
- return response.json().get("file_content")
- else:
- raise Exception(f"Failed to read file: {response.text}")
-
- def run_command(self, command: str) -> RunCommandResult:
- command = self.command_prefix + command
- command = f"bash -c {shlex.quote(command)}"
- response = requests.post(
- f"{self.base_url}/container/run",
- json={
- "container_id": self.container_id,
- "workdir": self.directory,
- "command": command,
- },
- )
- if response.status_code == 200:
- json = response.json()
- return {
- "exit_code": json["exit_code"],
- "output": json["output"],
- }
- else:
- raise Exception(f"Failed to run command: {response.text}")
-
- def resolve_path(self, path: str) -> str:
- return path # For remote containers, we assume paths are already resolved
-
-
-# Create a ContextVar to store the current ToolContext
-current_context: ContextVar[
- Optional[Union[LocalToolContext, RemoteContainerToolContext]]
-] = ContextVar("current_context", default=None)
-
-
-@contextlib.contextmanager
-def tool_context(context: Union[LocalToolContext, RemoteContainerToolContext]):
- token = current_context.set(context)
- try:
- yield context
- finally:
- current_context.reset(token)
-
-
-def get_current_context() -> Union[LocalToolContext, RemoteContainerToolContext]:
- context = current_context.get()
- if context is None:
- return LocalToolContext(".")
- return context
-
-
def read_image_as_base64(path: str):
ext = os.path.splitext(path)[1]
if ext not in [".jpg", ".jpeg", ".png"]:
@@ -201,7 +40,7 @@ def view_image(path: str):
Returns:
A message indicating that the image was displayed successfully.
"""
- context = get_current_context()
+ context = get_io_context()
full_path = context.resolve_path(path)
base64_image = read_image_as_base64(full_path)
@@ -226,7 +65,7 @@ def list_files(directory: str) -> str:
Returns:
The list of files in the directory.
"""
- context = get_current_context()
+ context = get_io_context()
# full_path = context.resolve_path(directory)
result = context.run_command(f"ls {directory}")
exit_code = result["exit_code"]
@@ -252,7 +91,7 @@ def write_to_file(path: str, content: str) -> str:
Returns:
A message indicating whether the file was written successfully.
"""
- context = get_current_context()
+ context = get_io_context()
if len(content) > LENGTH_LIMIT:
content = content[:LENGTH_LIMIT]
content += "\n... (truncated)"
@@ -270,7 +109,7 @@ def read_from_file(path: str) -> str:
Returns:
The content of the file.
"""
- context = get_current_context()
+ context = get_io_context()
result = context.read_file(path)
if len(result) > LENGTH_LIMIT:
result = result[:LENGTH_LIMIT]
@@ -288,7 +127,7 @@ def run_command(command: str) -> str:
Returns:
The output of the command.
"""
- context = get_current_context()
+ context = get_io_context()
result = context.run_command(command)
exit_code = result["exit_code"]
@@ -318,7 +157,7 @@ def read_lines_from_file(file_path: str, start_line: int) -> str:
Raises:
Exception: If the file does not exist or start_line is invalid.
"""
- context = get_current_context()
+ context = get_io_context()
full_path = context.resolve_path(file_path)
content = context.read_file(full_path)
lines = content.splitlines()
@@ -358,7 +197,7 @@ def replace_lines_in_file(
Raises:
Exception: If the line range is invalid or file cannot be accessed.
"""
- context = get_current_context()
+ context = get_io_context()
full_path = context.resolve_path(file_path)
try:
content = context.read_file(full_path)
@@ -423,7 +262,7 @@ def splice_lines_in_file(
Raises:
Exception: If the line range is invalid or file cannot be accessed.
"""
- context = get_current_context()
+ context = get_io_context()
full_path = context.resolve_path(file_path)
try:
content = context.read_file(full_path)
diff --git a/programmer/weave_next/weave_query.py b/programmer/weave_next/weave_query.py
index 54c0adb..28b90b9 100644
--- a/programmer/weave_next/weave_query.py
+++ b/programmer/weave_next/weave_query.py
@@ -167,7 +167,7 @@ def expand_refs(wc: WeaveClient, refs: Sequence[str]):
return Objs(wc, refs)
-def call(wc: WeaveClient, call_id: str):
+def get_call(wc: WeaveClient, call_id: str):
"""Return a raw Weave call."""
response = wc.server.calls_query(
CallsQueryReq(