Skip to content

Commit 119e53e

Browse files
committed
Simplify UI.
1 parent fc245fc commit 119e53e

File tree

1 file changed

+43
-78
lines changed

1 file changed

+43
-78
lines changed

ui.py

Lines changed: 43 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,29 @@ def cached_calls(
5656
op_names: Optional[Union[str, Sequence[str]]] = None,
5757
parent_ids: Optional[Union[str, Sequence[str]]] = None,
5858
limit: Optional[int] = None,
59+
expand_refs: Optional[list[str]] = None,
5960
):
60-
return calls(wc, op_names=op_names, parent_ids=parent_ids, limit=limit).to_pandas()
61+
return calls(
62+
wc,
63+
op_names=op_names,
64+
parent_ids=parent_ids,
65+
limit=limit,
66+
expand_refs=expand_refs,
67+
).to_pandas()
6168

6269

6370
@st.cache_data(hash_funcs=ST_HASH_FUNCS)
6471
def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]):
6572
return expand_refs(wc, refs).to_pandas()
6673

6774

68-
def print_step_call(call, start_state, end_state):
69-
if isinstance(end_state.history, float):
75+
def print_step_call(call):
76+
start_history = call["inputs.state.history"]
77+
end_history = call["output.history"]
78+
if isinstance(end_history, float):
7079
st.write("STEP WITH NO OUTPUT")
7180
return
72-
step_messages = list(end_state.history)[len(start_state.history) :]
81+
step_messages = list(end_history)[len(start_history) :]
7382
assistant_message = step_messages[0]
7483
tool_response_messages = step_messages[1:]
7584

@@ -96,113 +105,69 @@ def print_step_call(call, start_state, end_state):
96105
def set_focus_step_closure():
97106
set_focus_step_id(call.id)
98107

99-
# if start_snapshot_key is not None and end_snapshot_key is not None:
100-
# if (
101-
# start_snapshot_key["snapshot_info.commit"]
102-
# != end_snapshot_key["snapshot_info.commit"]
103-
# ):
104-
# st.text(
105-
# f'git diff {start_snapshot_key["snapshot_info.commit"]} {end_snapshot_key["snapshot_info.commit"]}'
106-
# )
108+
try:
109+
start_snapshot_commit = call[
110+
"inputs.state.env_snapshot_key.snapshot_info.commit"
111+
]
112+
end_snapshot_commit = call["output.env_snapshot_key.snapshot_info.commit"]
113+
114+
if start_snapshot_commit is not None and end_snapshot_commit is not None:
115+
if start_snapshot_commit != end_snapshot_commit:
116+
st.text(f"git diff {start_snapshot_commit} {end_snapshot_commit}")
117+
except KeyError:
118+
pass
107119

108120
# st.button("focus", key=f"focus-{call.id}", on_click=set_focus_step_closure)
109121

110122

111123
def print_run_call(
112124
call,
113125
steps_df,
114-
step_inputs_state,
115-
step_outputs,
116-
# steps_input_snapshot_key,
117-
# steps_output_snapshot_key,
118126
):
119127
st.write("RUN CALL", call.id)
120-
start_state = step_inputs_state.iloc[0]
121-
user_input = start_state["history"][-1]["content"]
128+
start_history = steps_df.iloc[0]["inputs.state.history"]
129+
user_input = start_history[-1]["content"]
122130
with st.chat_message("user"):
123131
st.write(user_input)
124132
for _, step in steps_df.iterrows():
125-
step_input_state = step_inputs_state.loc[step["inputs.state"]]
126-
step_output = step_outputs.loc[step["output"]]
127-
# step_input_snapshot_key = steps_input_snapshot_key.loc[step["inputs.state"]]
128-
# step_output_snapshot_key = steps_output_snapshot_key.loc[step["output"]]
129-
print_step_call(
130-
step,
131-
step_input_state,
132-
step_output,
133-
# step_input_snapshot_key,
134-
# step_output_snapshot_key,
135-
)
133+
print_step_call(step)
136134

137135

138136
def print_session_call(session_id):
139137
runs_df = cached_calls(client, "Agent.run", parent_ids=session_id)
140-
steps_df = cached_calls(client, "Agent.step", parent_ids=runs_df["id"].tolist())
141-
step_input_state = cached_expand_refs(client, steps_df["inputs.state"].tolist())
142-
if "env_snapshot_key" in step_input_state.columns:
143-
step_input_snapshot_key = cached_expand_refs(
144-
client, step_input_state["env_snapshot_key"].tolist()
145-
)
146-
else:
147-
step_input_snapshot_key = pd.DataFrame()
148-
# Make step_input_snapshot_key unique by index
149-
step_input_snapshot_key = step_input_snapshot_key.groupby(level=0).first()
150-
step_output_state = cached_expand_refs(client, steps_df["output"].tolist())
151-
if "env_snapshot_key" in step_output_state.columns:
152-
step_output_snapshot_key = cached_expand_refs(
153-
client, step_output_state["env_snapshot_key"].tolist()
154-
)
155-
else:
156-
step_output_snapshot_key = pd.DataFrame()
157-
# Make step_output_snapshot_key unique by index
158-
step_output_snapshot_key = step_output_snapshot_key.groupby(level=0).first()
138+
steps_df = cached_calls(
139+
client,
140+
"Agent.step",
141+
parent_ids=runs_df["id"].tolist(),
142+
expand_refs=[
143+
"inputs.state",
144+
"inputs.state.env_snapshot_key",
145+
"output",
146+
"output.env_snapshot_key",
147+
],
148+
)
159149

160150
for _, run_call_data in runs_df.iterrows():
161151
run_steps_df = steps_df[steps_df["parent_id"] == run_call_data["id"]]
162-
run_steps_inputs_state = step_input_state.loc[run_steps_df["inputs.state"]]
163-
run_steps_output = step_output_state.loc[run_steps_df["output"]]
164-
# if "env_snapshot_key" in run_steps_inputs_state.columns:
165-
# run_steps_input_snapshot_key = run_steps_inputs_state[
166-
# "env_snapshot_key"
167-
# ].apply(lambda x: None if pd.isna(x) else step_input_snapshot_key.loc[x])
168-
# else:
169-
# run_steps_input_snapshot_key = pd.Series(
170-
# [None] * len(run_steps_inputs_state), index=run_steps_inputs_state.index
171-
# )
172-
# if "env_snapshot_key" in run_steps_output.columns:
173-
# run_steps_output_snapshot_key = run_steps_output["env_snapshot_key"].apply(
174-
# lambda x: None if pd.isna(x) else step_output_snapshot_key.loc[x]
175-
# )
176-
# else:
177-
# run_steps_output_snapshot_key = pd.Series(
178-
# [None] * len(run_steps_output), index=run_steps_output.index
179-
# )
180152

181153
print_run_call(
182154
run_call_data,
183155
run_steps_df,
184-
run_steps_inputs_state,
185-
run_steps_output,
186-
# run_steps_input_snapshot_key,
187-
# run_steps_output_snapshot_key,
188156
)
189157

190158

191-
session_calls_df = cached_calls(client, "session")
192-
193-
194-
session_agent_state_df = cached_expand_refs(
195-
client, session_calls_df["inputs.agent_state"].tolist()
196-
)
197-
session_user_message_df = session_agent_state_df["history"].apply(
159+
session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
160+
session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
198161
lambda v: v[-1]["content"]
199162
)
200163

201164

202165
with st.sidebar:
203166
message_ids = {
204167
f"{cid[-5:]}: {m}": cid
205-
for cid, m in zip(session_calls_df["id"], session_user_message_df)
168+
for cid, m in reversed(
169+
list(zip(session_calls_df["id"], session_user_message_df))
170+
)
206171
}
207172
sel_message = st.radio("Session", options=message_ids.keys())
208173
sel_id = message_ids.get(sel_message)

0 commit comments

Comments
 (0)