diff --git a/backend/app/api/routes/teams.py b/backend/app/api/routes/teams.py index 4abf7d4f..3f64796c 100644 --- a/backend/app/api/routes/teams.py +++ b/backend/app/api/routes/teams.py @@ -243,6 +243,8 @@ async def stream( member.skills = member.skills return StreamingResponse( - generator(team, members, team_chat.messages, thread_id), + generator( + team, members, team_chat.messages, thread_id, team_chat.interrupt_decision + ), media_type="text/event-stream", ) diff --git a/backend/app/core/graph/build.py b/backend/app/core/graph/build.py index 1a4ba1df..636f0577 100644 --- a/backend/app/core/graph/build.py +++ b/backend/app/core/graph/build.py @@ -5,8 +5,9 @@ from functools import partial from typing import Any -from langchain_core.messages import AIMessage, AnyMessage, HumanMessage +from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableLambda +from langchain_core.runnables.config import RunnableConfig from langgraph.checkpoint import BaseCheckpointSaver from langgraph.graph import END, StateGraph from langgraph.graph.graph import CompiledGraph @@ -27,7 +28,7 @@ WorkerNode, ) from app.core.graph.skills import all_skills -from app.models import ChatMessage, Member, Team +from app.models import ChatMessage, InterruptDecision, Member, Team def convert_hierarchical_team_to_dict( @@ -184,7 +185,6 @@ def exit_chain(state: TeamState) -> dict[str, list[AnyMessage]]: Pass the final response back to the top-level graph's state. """ answer = state["messages"][-1] - # Add human message at the end to prevent consecutive AI message which will cause error for some models return {"messages": [answer]} @@ -397,7 +397,11 @@ def convert_messages_and_tasks_to_dict(data: Any) -> Any: async def generator( - team: Team, members: list[Member], messages: list[ChatMessage], thread_id: str + team: Team, + members: list[Member], + messages: list[ChatMessage], + thread_id: str, + interrupt_decision: InterruptDecision | None = None, ) -> AsyncGenerator[Any, Any]: """Create the graph and stream responses as JSON.""" formatted_messages = [ @@ -417,7 +421,7 @@ async def generator( root = create_hierarchical_graph( teams, leader_name=team_leader, memory=memory ) - state = { + state: dict[str, Any] | None = { "messages": formatted_messages, "team": teams[team_leader], "main_task": formatted_messages, @@ -439,11 +443,32 @@ async def generator( ), "next": first_member.name, } + + config: RunnableConfig = { + "configurable": {"thread_id": thread_id}, + "recursion_limit": 25, + } + # Handle interrupt logic by orriding state + if interrupt_decision == InterruptDecision.APPROVED: + state = None + elif interrupt_decision == InterruptDecision.REJECTED: + current_values = await root.aget_state(config) + tool_calls = current_values.values["messages"][-1].tool_calls + state = { + "messages": [ + ToolMessage( + tool_call_id=tool_call["id"], + content="API call denied by user. Continue assisting.", + ) + for tool_call in tool_calls + ] + } + async for output in root.astream_events( state, version="v1", include_names=["work", "delegate", "summarise"], - config={"configurable": {"thread_id": thread_id}, "recursion_limit": 25}, + config=config, ): if output["event"] == "on_chain_end": output_data = output["data"]["output"] @@ -453,6 +478,10 @@ async def generator( formatted_output = f"data: {json.dumps(transformed_output_data)}\n\n" if formatted_output != "data: null\n\n": yield formatted_output + snapshot = await root.aget_state(config) + if snapshot.next: + # Interrupt occured + yield f"data: {json.dumps({'interrupt': True})}\n\n" except Exception as e: error_message = { "error": str(e), diff --git a/backend/app/models.py b/backend/app/models.py index 6efea6b5..9d8bb8d9 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -112,8 +112,18 @@ class ChatMessage(BaseModel): content: str +class InterruptDecision(Enum): + APPROVED = "approved" + REJECTED = "rejected" + + +class Interrupt(BaseModel): + decision: InterruptDecision + + class TeamChat(BaseModel): messages: list[ChatMessage] + interrupt_decision: InterruptDecision | None = None class Team(TeamBase, table=True): diff --git a/frontend/src/client/index.ts b/frontend/src/client/index.ts index 73b04c5d..c96fd6ff 100644 --- a/frontend/src/client/index.ts +++ b/frontend/src/client/index.ts @@ -13,6 +13,7 @@ export type { ChatMessageType } from './models/ChatMessageType'; export type { CheckpointOut } from './models/CheckpointOut'; export type { CreateThreadOut } from './models/CreateThreadOut'; export type { HTTPValidationError } from './models/HTTPValidationError'; +export type { InterruptDecision } from './models/InterruptDecision'; export type { MemberCreate } from './models/MemberCreate'; export type { MemberOut } from './models/MemberOut'; export type { MembersOut } from './models/MembersOut'; @@ -47,6 +48,7 @@ export { $ChatMessageType } from './schemas/$ChatMessageType'; export { $CheckpointOut } from './schemas/$CheckpointOut'; export { $CreateThreadOut } from './schemas/$CreateThreadOut'; export { $HTTPValidationError } from './schemas/$HTTPValidationError'; +export { $InterruptDecision } from './schemas/$InterruptDecision'; export { $MemberCreate } from './schemas/$MemberCreate'; export { $MemberOut } from './schemas/$MemberOut'; export { $MembersOut } from './schemas/$MembersOut'; diff --git a/frontend/src/client/models/CreateThreadOut.ts b/frontend/src/client/models/CreateThreadOut.ts index 620b49b1..bb9a6b0e 100644 --- a/frontend/src/client/models/CreateThreadOut.ts +++ b/frontend/src/client/models/CreateThreadOut.ts @@ -9,6 +9,6 @@ export type CreateThreadOut = { id: string; query: string; updated_at: string; - last_checkpoint: CheckpointOut; + last_checkpoint: (CheckpointOut | null); }; diff --git a/frontend/src/client/models/InterruptDecision.ts b/frontend/src/client/models/InterruptDecision.ts new file mode 100644 index 00000000..63a82e6b --- /dev/null +++ b/frontend/src/client/models/InterruptDecision.ts @@ -0,0 +1,6 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type InterruptDecision = 'approved' | 'rejected'; diff --git a/frontend/src/client/models/TeamChat.ts b/frontend/src/client/models/TeamChat.ts index f0f6dc66..3e28fdb2 100644 --- a/frontend/src/client/models/TeamChat.ts +++ b/frontend/src/client/models/TeamChat.ts @@ -4,8 +4,10 @@ /* eslint-disable */ import type { ChatMessage } from './ChatMessage'; +import type { InterruptDecision } from './InterruptDecision'; export type TeamChat = { messages: Array; + interrupt_decision?: (InterruptDecision | null); }; diff --git a/frontend/src/client/schemas/$CreateThreadOut.ts b/frontend/src/client/schemas/$CreateThreadOut.ts index 78fd938e..2c71ae80 100644 --- a/frontend/src/client/schemas/$CreateThreadOut.ts +++ b/frontend/src/client/schemas/$CreateThreadOut.ts @@ -19,7 +19,12 @@ export const $CreateThreadOut = { format: 'date-time', }, last_checkpoint: { - type: 'CheckpointOut', + type: 'any-of', + contains: [{ + type: 'CheckpointOut', + }, { + type: 'null', + }], isRequired: true, }, }, diff --git a/frontend/src/client/schemas/$InterruptDecision.ts b/frontend/src/client/schemas/$InterruptDecision.ts new file mode 100644 index 00000000..e0547da7 --- /dev/null +++ b/frontend/src/client/schemas/$InterruptDecision.ts @@ -0,0 +1,7 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $InterruptDecision = { + type: 'Enum', +} as const; diff --git a/frontend/src/client/schemas/$TeamChat.ts b/frontend/src/client/schemas/$TeamChat.ts index b177c5cf..90a59e9f 100644 --- a/frontend/src/client/schemas/$TeamChat.ts +++ b/frontend/src/client/schemas/$TeamChat.ts @@ -11,5 +11,13 @@ export const $TeamChat = { }, isRequired: true, }, + interrupt_decision: { + type: 'any-of', + contains: [{ + type: 'InterruptDecision', + }, { + type: 'null', + }], + }, }, } as const; diff --git a/frontend/src/components/Members/EditMember.tsx b/frontend/src/components/Members/EditMember.tsx index 74d151b6..4fe966f7 100644 --- a/frontend/src/components/Members/EditMember.tsx +++ b/frontend/src/components/Members/EditMember.tsx @@ -1,5 +1,6 @@ import { Button, + Checkbox, FormControl, FormErrorMessage, FormLabel, @@ -51,7 +52,11 @@ const customSelectOption = { // TODO: Place this somewhere else. const AVAILABLE_MODELS = { ChatOpenAI: ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"], - ChatAnthropic: ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"], + ChatAnthropic: [ + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + ], // ChatCohere: ["command"], // ChatGoogleGenerativeAI: ["gemini-pro"], } @@ -238,6 +243,14 @@ export function EditMember({ )} /> + {member.type.startsWith("freelancer") ? ( + + Human In The Loop + + Require approval before executing skills. + + + ) : null} Provider