Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Text editor + Playground UI #21

Merged
merged 16 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 286 additions & 26 deletions programmer-ui/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
import streamlit as st
import weave
import os
import openai
import copy
from weave.trace.weave_client import WeaveClient

from programmer.weave_next.api import init_local_client
from programmer.weave_next.weave_query import calls, expand_refs
from programmer.weave_next.weave_query import (
calls,
expand_refs,
get_call,
expand_json_refs,
)
from programmer.settings_manager import SettingsManager

st.set_page_config(layout="wide")
Expand Down Expand Up @@ -47,7 +54,46 @@ def init_from_settings() -> WeaveClient:
raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}")


client = init_from_settings()
# Add sidebar for Weave project configuration
with st.sidebar:
st.header("Weave Project Configuration")

# Initialize from settings
initial_weave_logging = SettingsManager.get_setting("weave_logging")
initial_project_type = "local" if initial_weave_logging == "local" else "cloud"
initial_project_path = (
os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")
if initial_weave_logging == "local"
else ""
)
initial_project_name = (
f"programmer-{os.path.basename(os.path.abspath(os.curdir))}"
if initial_weave_logging == "cloud"
else ""
)

project_type = st.radio(
"Project Type",
["local", "cloud"],
index=0 if initial_project_type == "local" else 1,
)

if project_type == "local":
project_path = st.text_input("Local DB Path", value=initial_project_path)
# SettingsManager.set_setting("weave_logging", "local")
# SettingsManager.set_setting("weave_db_path", project_path)
client = init_local_weave(project_path)
print("C2", client._project_id())
else:
# SettingsManager.set_setting("weave_logging", "cloud")
# SettingsManager.set_setting("weave_project_name", project_name)
project_name = st.text_input("Cloud Project Name", value=initial_project_name)
client = init_remote_weave(project_name)
print("C3", client._project_id())

# Initialize client based on current settings
# client = init_from_settings()
print("CLIENT", client._project_id())


def set_focus_step_id(call_id):
Expand Down Expand Up @@ -76,6 +122,16 @@ def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]):
return expand_refs(wc, refs).to_pandas()


@st.cache_data(hash_funcs=ST_HASH_FUNCS)
def cached_get_call(wc: WeaveClient, call_id: str):
return get_call(wc, call_id)


@st.cache_data(hash_funcs=ST_HASH_FUNCS)
def cached_expand_json_refs(wc: WeaveClient, json: dict):
return expand_json_refs(wc, json)


def print_step_call(call):
start_history = call["inputs.state.history"]
end_history = call["output.history"]
Expand Down Expand Up @@ -174,30 +230,234 @@ def print_session_call(session_id):
)


session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
if len(session_calls_df) == 0:
st.error("No programmer sessions found.")
st.stop()
session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
lambda v: v[-1]["content"]
)

def sessions_page():
session_calls_df = cached_calls(
client, "session", expand_refs=["inputs.agent_state"]
)
if len(session_calls_df) == 0:
st.error("No programmer sessions found.")
st.stop()
session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
lambda v: v[-1]["content"]
)
with st.sidebar:
st.header("Session Selection")
if st.button("Refresh"):
st.cache_data.clear()
st.rerun()
message_ids = {
f"{cid[-5:]}: {m}": cid
for cid, m in reversed(
list(zip(session_calls_df["id"], session_user_message_df))
)
}
sel_message = st.radio("Session", options=message_ids.keys())
sel_id = None
if sel_message:
sel_id = message_ids.get(sel_message)
if sel_id:
st.header(f"Session: {sel_id}")
print_session_call(sel_id)


sessions_pg = st.Page(sessions_page, title="Sessions")


# def write_chat_message(m, key):
# with st.chat_message(m["role"]):
# if "content" in m:
# st.text_area(
# "", value=str(m["content"]), label_visibility="collapsed", key=key
# )
def write_chat_message(m, key, readonly=False):
def on_change_content():
new_value = st.session_state[key]
st.session_state.playground_state["editable_call"]["inputs"]["messages"][
m["original_index"]
]["content"] = new_value

with st.chat_message(m["role"]):
if m.get("content"):
if readonly:
st.code(m["content"])
else:
st.text_area(
"",
value=m["content"],
label_visibility="collapsed",
key=key,
on_change=on_change_content,
)
if m.get("tool_calls"):
for t in m["tool_calls"]:
st.write(t["function"]["name"])
st.json(
{
"arguments": t["function"]["arguments"],
"response": t.get("response", {}).get("content"),
},
expanded=True,
)


def attach_tool_call_responses(messages):
new_messages = []
for i, m in enumerate(messages):
new_m = copy.deepcopy(m)
new_m["original_index"] = i
if new_m["role"] == "assistant" and "tool_calls" in new_m:
new_m["tool_call_responses"] = []
for t in new_m["tool_calls"]:
t_id = t["id"]
for j, t_response in enumerate(messages):
if t_response.get("tool_call_id") == t_id:
t["response"] = t_response
t["response"]["original_index"] = j
break
if "tool_call_id" not in new_m:
new_messages.append(new_m)
return new_messages


def playground_page():
with st.sidebar:
if not st.session_state.get("playground_state"):
st.session_state.playground_state = {
"call_id": None,
"call": None,
"expanded_call": None,
"editable_call": None,
}
playground_state = st.session_state.playground_state
call_id = st.text_input("Call ID")
if not call_id:
st.error("Please set call ID")
st.stop()

# st.write(playground_state)
if playground_state["expanded_call"] != playground_state["editable_call"]:
st.warning("Call has been modified")
if st.button("Restore original call"):
st.session_state.playground_state["editable_call"] = copy.deepcopy(
playground_state["expanded_call"]
)
st.rerun()

if call_id != st.session_state.playground_state["call_id"]:
st.spinner("Loading call...")
call = cached_get_call(client, call_id)
editable_call = cached_expand_json_refs(client, call)
st.session_state.playground_state = {
"call_id": call_id,
"call": call,
"expanded_call": editable_call,
"editable_call": copy.deepcopy(editable_call),
}
st.rerun()

call = st.session_state.playground_state["call"]
editable_call = st.session_state.playground_state["editable_call"]
if call is None or editable_call is None:
st.warning("call not yet loaded")
st.stop()

st.write(call["op_name"])
# st.json(call["inputs"])
# st.json(call["inputs"]["tools"])

def on_change_temperature():
st.session_state.playground_state["editable_call"]["inputs"][
"temperature"
] = st.session_state["temperature"]

st.slider(
"Temperature",
min_value=0.0,
max_value=1.0,
value=editable_call["inputs"]["temperature"],
key="temperature",
on_change=on_change_temperature,
)

with st.sidebar:
if st.button("Refresh"):
st.cache_data.clear()
st.rerun()
message_ids = {
f"{cid[-5:]}: {m}": cid
for cid, m in reversed(
list(zip(session_calls_df["id"], session_user_message_df))
tools = call["inputs"].get("tools", [])
if tools:
st.write("Tools")
for tool_idx, t in enumerate(tools):
with st.expander(t["function"]["name"]):

def on_change_tool():
st.session_state.playground_state["editable_call"]["inputs"][
"tools"
][tool_idx] = json.loads(st.session_state[f"tool-{tool_idx}"])
st.rerun()

st.text_area(
"json",
value=json.dumps(t, indent=2),
height=300,
key=f"tool-{tool_idx}",
on_change=on_change_tool,
)

def on_change_parallel_tool_calls():
st.session_state.playground_state["editable_call"]["inputs"][
"parallel_tool_calls"
] = st.session_state["parallel_tool_calls"]

st.checkbox(
"Parallel tool calls",
value=editable_call["inputs"].get("parallel_tool_calls", True),
key="parallel_tool_calls",
on_change=on_change_parallel_tool_calls,
)

inputs = editable_call["inputs"]
all_input_messages = inputs["messages"]
other_inputs = {
k: v
for k, v in inputs.items()
if (k != "messages" and k != "self" and k != "stream")
}
sel_message = st.radio("Session", options=message_ids.keys())
sel_id = None
if sel_message:
sel_id = message_ids.get(sel_message)

if sel_id:
st.header(f"Session: {sel_id}")
print_session_call(sel_id)

tool_call_attached_messages = attach_tool_call_responses(all_input_messages)
for i, m in enumerate(tool_call_attached_messages):
write_chat_message(m, f"message-{i}")
# output = editable_call["output"]["choices"][0]["message"]
n_choices = st.number_input(
"Number of choices", value=1, min_value=1, max_value=100
)
if st.button("Generate"):
chat_inputs = {**editable_call["inputs"]}
# st.json(chat_inputs, expanded=False)
del chat_inputs["stream"]
del chat_inputs["self"]
chat_inputs["n"] = n_choices
call_resp = openai.chat.completions.create(**chat_inputs).model_dump()

editable_call["output"] = call_resp
st.rerun()
# st.json(response, expanded=False)
# output = response["choices"][0]["message"]
# st.json(output)
response = editable_call["output"]
st.write("full response")
st.json(response, expanded=False)
st.write("**system fingerprint**", response["system_fingerprint"])
st.write("**usage**", response["usage"])
for i, choice in enumerate(response["choices"]):
output = choice["message"]
st.write(f"Choice {i+1}")
write_chat_message(output, f"output_message-{i}", readonly=True)

# all_messages = [*all_input_messages, output]
# st.json(st.session_state.playground_state, expanded=False)
# st.json(all_messages, expanded=False)

# st.write(expanded_call)


playground_pg = st.Page(playground_page, title="Playground")


pg = st.navigation([sessions_pg, playground_pg])
pg.run()
15 changes: 10 additions & 5 deletions programmer/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class AgentState(weave.Object):
history: list[Any] = Field(default_factory=list)
env_snapshot_key: Optional[EnvironmentSnapshotKey] = None

def with_history(self, history: list[Any]) -> "AgentState":
environment = get_current_environment()
msg = get_commit_message(history)
snapshot_key = environment.make_snapshot(msg)
return self.__class__(history=history, env_snapshot_key=snapshot_key)


def unweavify(v: Any) -> Any:
if isinstance(v, list):
Expand All @@ -55,6 +61,9 @@ class Agent(weave.Object):
system_message: str
tools: list[Any] = Field(default_factory=list)

def initial_state(self, history: list[Any]) -> AgentState:
return AgentState().with_history(history)

@weave.op()
def step(self, state: AgentState) -> AgentState:
"""Run a step of the agent.
Expand Down Expand Up @@ -118,12 +127,8 @@ def step(self, state: AgentState) -> AgentState:

# new_history = state.history + new_messages
new_history = weavelist_add(state.history, new_messages)
msg = get_commit_message(new_history)

environment = get_current_environment()
snapshot_key = environment.make_snapshot(msg)

return AgentState(history=new_history, env_snapshot_key=snapshot_key)
return state.with_history(new_history)

@weave.op()
def run(self, state: AgentState, max_runtime_seconds: int = -1):
Expand Down
Loading
Loading