|
11 | 11 | from a2a.types import TaskState |
12 | 12 |
|
13 | 13 |
|
14 | | -from dqa import EnvVars |
| 14 | +from dqa import EnvVars, ic |
15 | 15 | from dqa.actor.mhqa import MHQAActor, MHQAActorInterface, MHQAActorMethods |
16 | 16 | from dqa.actor.pubsub_topics import PubSubTopics |
17 | 17 | from dqa.model.mhqa import ( |
@@ -43,6 +43,8 @@ async def do_mhqa_respond(self, data: MHQAInput): |
43 | 43 | EnvVars.APP_DAPR_PUBSUB_MEMORY_STREAM_BUFFER_SIZE |
44 | 44 | ) |
45 | 45 |
|
| 46 | + dc = DaprClient() |
| 47 | + |
46 | 48 | def pubsub_message_handler(message: SubscriptionMessage) -> TopicEventResponse: |
47 | 49 | # TODO: Is this a reasonable way to drop stale messages? |
48 | 50 | parsed_timestamp = message.extensions().get("time", None) |
@@ -73,20 +75,28 @@ async def invoke_actor(): |
73 | 75 | raw_body=data.model_dump_json().encode(), |
74 | 76 | ) |
75 | 77 |
|
76 | | - with DaprClient() as dc: |
77 | | - async with anyio.create_task_group() as tg: |
78 | | - pubsub_topic_name = f"{PubSubTopics.MHQA_RESPONSE}/{data.thread_id}" |
| 78 | + pubsub_topic_name = f"{PubSubTopics.MHQA_RESPONSE}/{data.thread_id}" |
| 79 | + |
| 80 | + async with anyio.create_task_group() as tg: |
| 81 | + try: |
79 | 82 | dc.subscribe_with_handler( |
80 | 83 | pubsub_name=EnvVars.DAPR_PUBSUB_NAME, |
81 | 84 | topic=pubsub_topic_name, |
82 | 85 | handler_fn=pubsub_message_handler, |
83 | 86 | ) |
84 | | - |
85 | | - tg.start_soon(invoke_actor) |
| 87 | + tg.start_soon(invoke_actor, name=invoke_actor.__name__) |
86 | 88 | # FIXME: Error "Attempted to exit cancel scope in a different task than it was entered in". |
87 | | - async with receive_stream: |
88 | | - async for item in receive_stream: |
89 | | - yield item |
| 89 | + async for item in receive_stream: |
| 90 | + yield item |
| 91 | + tg.cancel_scope.cancel() |
| 92 | + ic("Cancelled anyio task group") |
| 93 | + except Exception as e: |
| 94 | + logger.exception(e) |
| 95 | + finally: |
| 96 | + await receive_stream.aclose() |
| 97 | + await send_stream.aclose() |
| 98 | + dc.close() |
| 99 | + ic("Receive and send streams closed") |
90 | 100 |
|
91 | 101 | async def do_mhqa_get_history(self, data: MHQAHistoryInput) -> str: |
92 | 102 | proxy = ActorProxy.create( |
|
0 commit comments