Skip to content

Commit ee985fd

Browse files
authored
A2A streaming using Dapr pub sub (#4)
* feat: A2A streaming works using Dapr pub-sub but things need to be cleaned up. * fix: Corrected LaTeX math-mode display in Gradio chatbot by adding escaped square braces characters as LaTeX delimiters. * chore: Added an environment variable to optionally specify a remote A2A MHQA endpoint. * fix: Corrected parsing issues in the A2A CLI client for both streaming and non-streaming modes. note: Read the A2A protocol documentation for details. * chore: Reduced the number of messages the actor posts to the pub-sub queue. fix: Corrected message statuses to _complete_ in the message history. * chore: Added yfmcp MCP server for stock tickers. chore: Minor UI modifications. * fix: Corrected test condition for total number of MCP tools.
1 parent 920cdf0 commit ee985fd

File tree

16 files changed

+957
-482
lines changed

16 files changed

+957
-482
lines changed

.dapr/components/pubsub.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ spec:
1010
value: localhost:6379
1111
- name: redisPassword
1212
value: ""
13+
scopes:
14+
- dapr-srv
15+
- mhqa-a2a-srv

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ The following environment variables are also relevant but not essential, except
4747
- `BROWSER_STATE_SECRET`: This is the secret used by Gradio to encrypt the browser state data. The default value is `a2a_dapr_bstate_secret`.
4848
- `BROWSER_STATE_CHAT_HISTORIES`: This is the key in browser state used by Gradio to store the chat histories (local values). The default value is `a2a_dapr_chat_histories`.
4949
- `APP_DAPR_SVC_HOST` and `APP_DAPR_SVC_PORT`: The host and port at which Dapr actor service will listen on. These default to `127.0.0.1` and `32768`. Should you change these, you must change the corresponding information in `dapr.yaml`.
50+
- `APP_DAPR_PUBSUB_STALE_MSG_SECS`: This specifies how old a message should be on the Dapr publish-subscribe topic queue before it will be considered too old, and dropped. The default value is 60 seconds.
5051
- `DAPR_PUBSUB_NAME`: The configured name of the publish-subscribe component at `.dapr/components/pubsub.yaml`. Change this environment variable only if you change the corresponding pub-sub component configuration.
5152
- `APP_A2A_SRV_HOST` and `APP_MHQA_A2A_SRV_PORT`: The host and port at which A2A endpoint will be available. These default to `127.0.0.1` and `32770`. Should you change these, you must change the corresponding information in `dapr.yaml`.
53+
- `APP_MHQA_A2A_REMOTE_URL`: This environment variable can be used to specify the full remote URL including the protocol, i.e., `http` or `https` where the MHQA A2A endpoint is available. This is useful in a scenario where the web app is deployed on a machine that is different from where the MHQA A2A endpoint and Dapr service are. Default value is `None`.
5254

5355
## Usage
5456

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ authors = [
99
requires-python = ">=3.12.0"
1010
dependencies = [
1111
"a2a-sdk[http-server]>=0.3.6",
12+
"anyio>=4.11.0",
1213
"dapr>=1.16.0",
1314
"dapr-ext-fastapi>=1.16.0",
1415
"environs>=14.3.0",
@@ -19,6 +20,7 @@ dependencies = [
1920
"llama-index-tools-mcp>=0.4.1",
2021
"ollama>=0.6.0",
2122
"typer>=0.19.1",
23+
"yfmcp>=0.4.8",
2224
]
2325

2426
[project.scripts]

src/dqa/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ class ParsedEnvVars:
3030
MCP_CONFIG_FILE: str = env.str("MCP_CONFIG_FILE", default="conf/mcp.json")
3131
APP_DAPR_SVC_HOST: str = env.str("APP_DAPR_SVC_HOST", default="127.0.0.1")
3232
APP_DAPR_SVC_PORT: int = env.int("APP_DAPR_SVC_PORT", default=32768)
33+
APP_DAPR_PUBSUB_STALE_MSG_SECS: int = env.int(
34+
"APP_DAPR_PUBSUB_STALE_MSG_SECS", default=60
35+
)
3336
APP_A2A_SRV_HOST: str = env.str("APP_A2A_SRV_HOST", default="127.0.0.1")
3437
APP_MHQA_A2A_SRV_PORT: int = env.int("APP_MHQA_A2A_SRV_PORT", default=32770)
38+
APP_MHQA_A2A_REMOTE_URL: str = env.str("APP_MHQA_A2A_REMOTE_URL", default=None)
3539
APP_ECHO_A2A_SRV_PORT: int = env.int("APP_ECHO_A2A_SRV_PORT", default=32769)
3640
DAPR_PUBSUB_NAME: str = env.str("DAPR_PUBSUB_NAME", default="pubsub")
3741
MCP_SERVER_HOST: str = env.str("FASTMCP_HOST", default="localhost")
@@ -57,5 +61,5 @@ def __new__(cls: type["ParsedEnvVars"]) -> "ParsedEnvVars":
5761
level=ParsedEnvVars().APP_LOG_LEVEL,
5862
format="%(message)s",
5963
datefmt="[%X]",
60-
handlers=[RichHandler()],
64+
handlers=[RichHandler(show_time=True, show_level=True, show_path=True)],
6165
)

src/dqa/actor/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from enum import StrEnum, auto
1+
from enum import StrEnum
22

33

44
class MHQAActorMethods(StrEnum):
5-
Respond = auto()
6-
GetChatHistory = auto()
7-
ResetChatHistory = auto()
8-
Cancel = auto()
5+
Respond = "Respond"
6+
GetChatHistory = "GetChatHistory"
7+
ResetChatHistory = "ResetChatHistory"
8+
Cancel = "Cancel"

src/dqa/actor/mhqa.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from llama_index.core.base.llms.types import ChatMessage, MessageRole
1212
from llama_index.core.tools.types import ToolOutput
1313

14+
1415
from llama_index.core.agent.workflow import (
1516
AgentOutput,
1617
ToolCall,
@@ -19,7 +20,6 @@
1920
AgentWorkflow,
2021
FunctionAgent,
2122
)
22-
2323
from abc import abstractmethod
2424
import os
2525
from dapr.actor import Actor, ActorInterface, actormethod
@@ -28,7 +28,8 @@
2828

2929
from dqa import ParsedEnvVars
3030
from dqa.actor import MHQAActorMethods
31-
from dqa.model.mhqa import MCPToolInvocation, MHQAResponse
31+
from dqa.actor.pubsub_topics import PubSubTopics
32+
from dqa.model.mhqa import MCPToolInvocation, MHQAResponse, MHQAResponseStatus
3233

3334

3435
logger = logging.getLogger(__name__)
@@ -221,6 +222,7 @@ async def respond(self, data: dict) -> dict:
221222
)
222223
full_response = ""
223224
tool_invocations: List[MCPToolInvocation] = []
225+
pubsub_topic_name = f"{PubSubTopics.MHQA_RESPONSE}/{self.id}"
224226
with DaprClient() as dc:
225227
async for ev in wf_handler.stream_events():
226228
if isinstance(ev, AgentStream):
@@ -260,11 +262,24 @@ async def respond(self, data: dict) -> dict:
260262
agent_output=full_response,
261263
tool_invocations=tool_invocations,
262264
)
263-
dc.publish_event(
264-
pubsub_name=ParsedEnvVars().DAPR_PUBSUB_NAME,
265-
topic_name=f"topic-{self.__class__.__name__}-{self.id}-respond",
266-
data=response.model_dump_json().encode(),
267-
)
265+
if (
266+
isinstance(ev, AgentStream)
267+
and ev.delta.strip() != ""
268+
and response
269+
and response.agent_output.strip() != ""
270+
):
271+
# logger.info(f"Publishing: {response.agent_output}")
272+
dc.publish_event(
273+
pubsub_name=ParsedEnvVars().DAPR_PUBSUB_NAME,
274+
topic_name=pubsub_topic_name,
275+
data=response.model_dump_json().encode(),
276+
)
277+
response.status = MHQAResponseStatus.completed
278+
dc.publish_event(
279+
pubsub_name=ParsedEnvVars().DAPR_PUBSUB_NAME,
280+
topic_name=pubsub_topic_name,
281+
data=response.model_dump_json().encode(),
282+
)
268283
memory_messages = await self.workflow_memory.aget_all()
269284
await self._state_manager.set_state(
270285
self._chat_memory_key,
@@ -293,6 +308,7 @@ async def get_chat_history(self) -> list:
293308
user_input=user_input,
294309
agent_output=agent_output,
295310
tool_invocations=tool_invocations,
311+
status=MHQAResponseStatus.completed,
296312
)
297313
)
298314
tool_name = ""

src/dqa/actor/pubsub_topics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class PubSubTopics:
2+
MHQA_RESPONSE = "ps-topics/mhqa-response"

src/dqa/cli/a2a.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def _initialize(self):
7373
self.echo_base_url = f"http://{a2a_asgi_host}:{echo_a2a_asgi_port}"
7474

7575
mhqa_a2a_asgi_port = ParsedEnvVars().APP_MHQA_A2A_SRV_PORT
76-
self.mhqa_base_url = f"http://{a2a_asgi_host}:{mhqa_a2a_asgi_port}"
76+
self.mhqa_base_url = (
77+
ParsedEnvVars().APP_MHQA_A2A_REMOTE_URL
78+
or f"http://{a2a_asgi_host}:{mhqa_a2a_asgi_port}"
79+
)
7780
logger.debug(f"Echo A2A base URL: {self.echo_base_url}")
7881
logger.debug(f"MHQA A2A base URL: {self.mhqa_base_url}")
7982

@@ -277,8 +280,8 @@ async def _mhqa_chat(
277280
logger.info("Parsing streaming response from the A2A endpoint")
278281
full_message_content = ""
279282
async for response in streaming_response:
280-
if isinstance(response, Message):
281-
full_message_content += get_message_text(response)
283+
if response[0].status.message:
284+
full_message_content = get_message_text(response[0].status.message)
282285
validated_response = MHQAResponse.model_validate_json(full_message_content)
283286
return validated_response
284287

@@ -324,8 +327,8 @@ async def _mhqa_get_history(
324327
logger.info("Parsing streaming response from the A2A endpoint")
325328
full_message_content = ""
326329
async for response in streaming_response:
327-
if isinstance(response, Message):
328-
full_message_content += get_message_text(response)
330+
if response[0].status.message:
331+
full_message_content = get_message_text(response[0].status.message)
329332
response_adapter = TypeAdapter(List[MHQAResponse])
330333
validated_response = response_adapter.validate_json(full_message_content)
331334
validated_response = validated_response[
@@ -374,8 +377,8 @@ async def _mhqa_delete_history(
374377
logger.info("Parsing streaming response from the A2A endpoint")
375378
full_message_content = ""
376379
async for response in streaming_response:
377-
if isinstance(response, Message):
378-
full_message_content += get_message_text(response)
380+
if response[0].status.message:
381+
full_message_content = get_message_text(response[0].status.message)
379382
return full_message_content
380383

381384
async def run_mhqa_delete_history(
@@ -387,7 +390,9 @@ async def run_mhqa_delete_history(
387390
response = await self._mhqa_delete_history(
388391
thread_id=thread_id,
389392
)
390-
print(response)
393+
print(
394+
f"Deletion of thread '{thread_id}': {'successful' if response == 'true' else 'failed; maybe the thread does not exist?'}"
395+
)
391396
except Exception as e:
392397
logger.error(f"Error in MHQA delete history. {e}")
393398
finally:
@@ -486,7 +491,7 @@ def mhqa_get_history(
486491

487492

488493
@app.command()
489-
async def mhqa_delete_history(
494+
def mhqa_delete_history(
490495
thread_id: str = typer.Option(
491496
help="A thread ID to identify your conversation.",
492497
),

src/dqa/client/a2a_mixin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ async def obtain_a2a_client(
2424
final_agent_card_to_use: AgentCard | None = None
2525

2626
logger.info(
27-
f"Attempting to fetch public agent card from: {base_url}{AGENT_CARD_WELL_KNOWN_PATH}"
27+
f"Fetching A2A agent card from: {base_url}{AGENT_CARD_WELL_KNOWN_PATH}"
2828
)
2929
_public_card = (
3030
await resolver.get_agent_card()
3131
) # Fetches from default public path
3232
logger.info("Successfully fetched public agent card.")
33-
logger.info(_public_card.model_dump_json(indent=2, exclude_none=True))
33+
logger.debug(_public_card.model_dump_json(indent=2, exclude_none=True))
3434
final_agent_card_to_use = _public_card
3535

3636
client = ClientFactory(
3737
config=ClientConfig(streaming=True, polling=True, httpx_client=httpx_client)
3838
).create(card=final_agent_card_to_use)
39-
logger.info("A2A client initialised.")
39+
logger.debug("A2A client initialised.")
4040
return client, final_agent_card_to_use

0 commit comments

Comments
 (0)