Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions ee/hogai/funnels/nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI

from ee.hogai.funnels.prompts import funnel_system_prompt, react_system_prompt
from ee.hogai.funnels.prompts import FUNNEL_SYSTEM_PROMPT, REACT_SYSTEM_PROMPT
from ee.hogai.funnels.toolkit import FUNNEL_SCHEMA, FunnelsTaxonomyAgentToolkit
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
Expand All @@ -16,42 +15,35 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
toolkit = FunnelsTaxonomyAgentToolkit(self._team)
prompt = ChatPromptTemplate.from_messages(
[
("system", react_system_prompt),
("system", REACT_SYSTEM_PROMPT),
],
template_format="mustache",
)
return super()._run(state, prompt, toolkit, config=config)
return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config)


class FunnelPlannerToolsNode(TaxonomyAgentPlannerToolsNode):
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
toolkit = FunnelsTaxonomyAgentToolkit(self._team)
return super()._run(state, toolkit, config=config)
return super()._run_with_toolkit(state, toolkit, config=config)


FunnelsSchemaGeneratorOutput = SchemaGeneratorOutput[AssistantFunnelsQuery]


class FunnelGeneratorNode(SchemaGeneratorNode[AssistantFunnelsQuery]):
insight_name = "Funnels"
output_model = FunnelsSchemaGeneratorOutput
INSIGHT_NAME = "Funnels"
OUTPUT_MODEL = FunnelsSchemaGeneratorOutput
OUTPUT_SCHEMA = FUNNEL_SCHEMA

def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
prompt = ChatPromptTemplate.from_messages(
[
("system", funnel_system_prompt),
("system", FUNNEL_SYSTEM_PROMPT),
],
template_format="mustache",
)
return super()._run(state, prompt, config=config)

@property
def _model(self):
return ChatOpenAI(model="gpt-4o", temperature=0.2, streaming=True).with_structured_output(
FUNNEL_SCHEMA,
method="function_calling",
include_raw=False,
)
return super()._run_with_prompt(state, prompt, config=config)


class FunnelGeneratorToolsNode(SchemaGeneratorToolsNode):
Expand Down
10 changes: 5 additions & 5 deletions ee/hogai/funnels/prompts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ee.hogai.taxonomy_agent.prompts import react_format_prompt, react_format_reminder_prompt
from ee.hogai.taxonomy_agent.prompts import REACT_FORMAT_PROMPT, REACT_FORMAT_REMINDER_PROMPT

react_system_prompt = f"""
REACT_SYSTEM_PROMPT = f"""
You're a product analyst agent. Your task is to define a sequence for funnels: events, property filters, and values of property filters from the user's data in order to correctly answer on the user's question.

The product being analyzed is described as follows:
{{{{product_description}}}}

{react_format_prompt}
{REACT_FORMAT_PROMPT}

Below you will find information on how to correctly discover the taxonomy of the user's data.

Expand Down Expand Up @@ -88,10 +88,10 @@

---

{react_format_reminder_prompt}
{REACT_FORMAT_REMINDER_PROMPT}
"""

funnel_system_prompt = """
FUNNEL_SYSTEM_PROMPT = """
Act as an expert product manager. Your task is to generate a JSON schema of funnel insights. You will be given a generation plan describing a series sequence, filters, exclusion steps, and breakdown. Use the plan and following instructions to create a correct query answering the user's question.

Below is the additional context.
Expand Down
18 changes: 9 additions & 9 deletions ee/hogai/router/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,28 @@
from pydantic import BaseModel, Field

from ee.hogai.router.prompts import (
router_insight_description_prompt,
router_system_prompt,
router_user_prompt,
ROUTER_INSIGHT_DESCRIPTION_PROMPT,
ROUTER_SYSTEM_PROMPT,
ROUTER_USER_PROMPT,
)
from ee.hogai.utils import AssistantNode, AssistantState
from ee.hogai.utils import AssistantState, AssistantNode
from posthog.schema import HumanMessage, RouterMessage

RouteName = Literal["trends", "funnel"]


class RouterOutput(BaseModel):
visualization_type: Literal["trends", "funnel"] = Field(..., description=router_insight_description_prompt)
visualization_type: Literal["trends", "funnel"] = Field(..., description=ROUTER_INSIGHT_DESCRIPTION_PROMPT)


class RouterNode(AssistantNode):
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
prompt = ChatPromptTemplate.from_messages(
[
("system", router_system_prompt),
("system", ROUTER_SYSTEM_PROMPT),
],
template_format="mustache",
) + self._reconstruct_conversation(state)
) + self._construct_messages(state)
chain = prompt | self._model
output: RouterOutput = chain.invoke({}, config)
return {"messages": [RouterMessage(content=output.visualization_type)]}
Expand All @@ -45,12 +45,12 @@ def _model(self):
RouterOutput
)

def _reconstruct_conversation(self, state: AssistantState):
def _construct_messages(self, state: AssistantState):
history: list[BaseMessage] = []
for message in state["messages"]:
if isinstance(message, HumanMessage):
history += ChatPromptTemplate.from_messages(
[("user", router_user_prompt.strip())], template_format="mustache"
[("user", ROUTER_USER_PROMPT.strip())], template_format="mustache"
).format_messages(question=message.content)
elif isinstance(message, RouterMessage):
history += [
Expand Down
6 changes: 3 additions & 3 deletions ee/hogai/router/prompts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
router_system_prompt = """
ROUTER_SYSTEM_PROMPT = """
Act as an expert product manager. Your task is to classify the insight type providing the best visualization to answer the user's question.
"""

router_insight_description_prompt = f"""
ROUTER_INSIGHT_DESCRIPTION_PROMPT = f"""
Pick the most suitable visualization type for the user's question.

## `trends`
Expand All @@ -27,6 +27,6 @@
- If product changes are improving their funnel over time.
"""

router_user_prompt = """
ROUTER_USER_PROMPT = """
Question: {{question}}
"""
6 changes: 2 additions & 4 deletions ee/hogai/router/test/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def test_node_runs(self):
def test_node_reconstructs_conversation(self):
node = RouterNode(self.team)
state: Any = {"messages": [HumanMessage(content="generate trends")]}
self.assertEqual(
node._reconstruct_conversation(state), [LangchainHumanMessage(content="Question: generate trends")]
)
self.assertEqual(node._construct_messages(state), [LangchainHumanMessage(content="Question: generate trends")])
state = {
"messages": [
HumanMessage(content="generate trends"),
Expand All @@ -53,6 +51,6 @@ def test_node_reconstructs_conversation(self):
]
}
self.assertEqual(
node._reconstruct_conversation(state),
node._construct_messages(state),
[LangchainHumanMessage(content="Question: generate trends"), LangchainAIMessage(content="trends")],
)
68 changes: 36 additions & 32 deletions ee/hogai/schema_generator/nodes.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,64 @@
import itertools
import xml.etree.ElementTree as ET
from abc import abstractmethod
from functools import cached_property
from typing import Generic, Optional, TypeVar

from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAssistantMessage, BaseMessage, merge_message_runs
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, ValidationError

from ee.hogai.schema_generator.parsers import (
PydanticOutputParserException,
parse_pydantic_structured_output,
)
from ee.hogai.schema_generator.prompts import (
failover_output_prompt,
failover_prompt,
group_mapping_prompt,
new_plan_prompt,
plan_prompt,
question_prompt,
FAILOVER_OUTPUT_PROMPT,
FAILOVER_PROMPT,
GROUP_MAPPING_PROMPT,
NEW_PLAN_PROMPT,
PLAN_PROMPT,
QUESTION_PROMPT,
)
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
from ee.hogai.utils import AssistantNode, AssistantState, filter_visualization_conversation
from ee.hogai.utils import AssistantState, AssistantNode, filter_visualization_conversation
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.schema import (
FailureMessage,
VisualizationMessage,
)

T = TypeVar("T", bound=BaseModel)
Q = TypeVar("Q", bound=BaseModel)


class SchemaGeneratorNode(AssistantNode, Generic[T]):
insight_name: str
class SchemaGeneratorNode(AssistantNode, Generic[Q]):
INSIGHT_NAME: str
"""
Name of the insight type used in the exception messages.
"""
output_model: type[SchemaGeneratorOutput[T]]
OUTPUT_MODEL: type[SchemaGeneratorOutput[Q]]
"""Pydantic model of the output to be generated by the LLM."""
OUTPUT_SCHEMA: dict
"""JSON schema of OUTPUT_MODEL for LLM's use."""

@property
@abstractmethod
def _model(self) -> Runnable:
raise NotImplementedError
def _model(self):
return ChatOpenAI(model="gpt-4o", temperature=0.2, streaming=True).with_structured_output(
self.OUTPUT_SCHEMA,
method="function_calling",
include_raw=False,
)

@classmethod
def parse_output(cls, output: dict) -> Optional[SchemaGeneratorOutput[T]]:
def parse_output(cls, output: dict) -> Optional[SchemaGeneratorOutput[Q]]:
try:
return cls.output_model.model_validate(output)
return cls.OUTPUT_MODEL.model_validate(output)
except ValidationError:
return None

def _run(
def _run_with_prompt(
self,
state: AssistantState,
prompt: ChatPromptTemplate,
Expand All @@ -62,23 +68,21 @@ def _run(
intermediate_steps = state.get("intermediate_steps") or []
validation_error_message = intermediate_steps[-1][1] if intermediate_steps else None

generation_prompt = prompt + self._reconstruct_conversation(
state, validation_error_message=validation_error_message
)
generation_prompt = prompt + self._construct_messages(state, validation_error_message=validation_error_message)
merger = merge_message_runs()
parser = parse_pydantic_structured_output(self.output_model)
parser = parse_pydantic_structured_output(self.OUTPUT_MODEL)

chain = generation_prompt | merger | self._model | parser

try:
message: SchemaGeneratorOutput[T] = chain.invoke({}, config)
message: SchemaGeneratorOutput[Q] = chain.invoke({}, config)
except PydanticOutputParserException as e:
# Generation step is expensive. After a second unsuccessful attempt, it's better to send a failure message.
if len(intermediate_steps) >= 2:
return {
"messages": [
FailureMessage(
content=f"Oops! It looks like I’m having trouble generating this {self.insight_name} insight. Could you please try again?"
content=f"Oops! It looks like I’m having trouble generating this {self.INSIGHT_NAME} insight. Could you please try again?"
)
],
"intermediate_steps": None,
Expand Down Expand Up @@ -119,7 +123,7 @@ def _group_mapping_prompt(self) -> str:
)
return ET.tostring(root, encoding="unicode")

def _reconstruct_conversation(
def _construct_messages(
self, state: AssistantState, validation_error_message: Optional[str] = None
) -> list[BaseMessage]:
"""
Expand All @@ -132,7 +136,7 @@ def _reconstruct_conversation(
return []

conversation: list[BaseMessage] = [
HumanMessagePromptTemplate.from_template(group_mapping_prompt, template_format="mustache").format(
HumanMessagePromptTemplate.from_template(GROUP_MAPPING_PROMPT, template_format="mustache").format(
group_mapping=self._group_mapping_prompt
)
]
Expand All @@ -144,22 +148,22 @@ def _reconstruct_conversation(
if ai_message:
conversation.append(
HumanMessagePromptTemplate.from_template(
plan_prompt if first_ai_message else new_plan_prompt,
PLAN_PROMPT if first_ai_message else NEW_PLAN_PROMPT,
template_format="mustache",
).format(plan=ai_message.plan or "")
)
first_ai_message = False
elif generated_plan:
conversation.append(
HumanMessagePromptTemplate.from_template(
plan_prompt if first_ai_message else new_plan_prompt,
PLAN_PROMPT if first_ai_message else NEW_PLAN_PROMPT,
template_format="mustache",
).format(plan=generated_plan)
)

if human_message:
conversation.append(
HumanMessagePromptTemplate.from_template(question_prompt, template_format="mustache").format(
HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
question=human_message.content
)
)
Expand All @@ -171,7 +175,7 @@ def _reconstruct_conversation(

if validation_error_message:
conversation.append(
HumanMessagePromptTemplate.from_template(failover_prompt, template_format="mustache").format(
HumanMessagePromptTemplate.from_template(FAILOVER_PROMPT, template_format="mustache").format(
validation_error_message=validation_error_message
)
)
Expand All @@ -191,7 +195,7 @@ def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:

action, _ = intermediate_steps[-1]
prompt = (
ChatPromptTemplate.from_template(failover_output_prompt, template_format="mustache")
ChatPromptTemplate.from_template(FAILOVER_OUTPUT_PROMPT, template_format="mustache")
.format_messages(output=action.tool_input, exception_message=action.log)[0]
.content
)
Expand Down
12 changes: 6 additions & 6 deletions ee/hogai/schema_generator/prompts.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
group_mapping_prompt = """
GROUP_MAPPING_PROMPT = """
Here is the group mapping:
{{group_mapping}}
"""

plan_prompt = """
PLAN_PROMPT = """
Here is the plan:
{{plan}}
"""

new_plan_prompt = """
NEW_PLAN_PROMPT = """
Here is the new plan:
{{plan}}
"""

question_prompt = """
QUESTION_PROMPT = """
Answer to this question: {{question}}
"""

failover_output_prompt = """
FAILOVER_OUTPUT_PROMPT = """
Generation output:
```
{{output}}
Expand All @@ -29,7 +29,7 @@
```
"""

failover_prompt = """
FAILOVER_PROMPT = """
The result of the previous generation raised the Pydantic validation exception.

{{validation_error_message}}
Expand Down
Loading