Skip to content

Commit 3eaf42c

Browse files
authored
Add support for new event types (#86)
* Add support for new event types * Add tests for new event types
1 parent c577b2b commit 3eaf42c

File tree

3 files changed

+176
-27
lines changed

3 files changed

+176
-27
lines changed

durabletask/internal/helpers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def new_orchestrator_started_event(timestamp: Optional[datetime] = None) -> pb.H
2020
return pb.HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent())
2121

2222

23+
def new_orchestrator_completed_event() -> pb.HistoryEvent:
24+
return pb.HistoryEvent(eventId=-1, timestamp=timestamp_pb2.Timestamp(),
25+
orchestratorCompleted=pb.OrchestratorCompletedEvent())
26+
27+
2328
def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None,
2429
tags: Optional[dict[str, str]] = None) -> pb.HistoryEvent:
2530
return pb.HistoryEvent(
@@ -119,6 +124,18 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails:
119124
)
120125

121126

127+
def new_event_sent_event(event_id: int, instance_id: str, input: str):
128+
return pb.HistoryEvent(
129+
eventId=event_id,
130+
timestamp=timestamp_pb2.Timestamp(),
131+
eventSent=pb.EventSentEvent(
132+
name="",
133+
input=get_string_value(input),
134+
instanceId=instance_id
135+
)
136+
)
137+
138+
122139
def new_event_raised_event(name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent:
123140
return pb.HistoryEvent(
124141
eventId=-1,

durabletask/worker.py

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from threading import Event, Thread
1313
from types import GeneratorType
1414
from enum import Enum
15-
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
15+
from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
1616
import uuid
1717
from packaging.version import InvalidVersion, parse
1818

@@ -832,6 +832,7 @@ def __init__(self, instance_id: str, registry: _Registry):
832832
self._pending_tasks: dict[int, task.CompletableTask] = {}
833833
# Maps entity ID to task ID
834834
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
835+
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
835836
# Maps criticalSectionId to task ID
836837
self._entity_lock_id_map: dict[str, int] = {}
837838
self._sequence_number = 0
@@ -1606,33 +1607,40 @@ def process_event(
16061607
else:
16071608
raise TypeError("Unexpected sub-orchestration task type")
16081609
elif event.HasField("eventRaised"):
1609-
# event names are case-insensitive
1610-
event_name = event.eventRaised.name.casefold()
1611-
if not ctx.is_replaying:
1612-
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1613-
task_list = ctx._pending_events.get(event_name, None)
1614-
decoded_result: Optional[Any] = None
1615-
if task_list:
1616-
event_task = task_list.pop(0)
1617-
if not ph.is_empty(event.eventRaised.input):
1618-
decoded_result = shared.from_json(event.eventRaised.input.value)
1619-
event_task.complete(decoded_result)
1620-
if not task_list:
1621-
del ctx._pending_events[event_name]
1622-
ctx.resume()
1610+
if event.eventRaised.name in ctx._entity_task_id_map:
1611+
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1612+
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
1613+
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
1614+
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
1615+
self._handle_entity_event_raised(ctx, event, entity_id, task_id, True)
16231616
else:
1624-
# buffer the event
1625-
event_list = ctx._received_events.get(event_name, None)
1626-
if not event_list:
1627-
event_list = []
1628-
ctx._received_events[event_name] = event_list
1629-
if not ph.is_empty(event.eventRaised.input):
1630-
decoded_result = shared.from_json(event.eventRaised.input.value)
1631-
event_list.append(decoded_result)
1617+
# event names are case-insensitive
1618+
event_name = event.eventRaised.name.casefold()
16321619
if not ctx.is_replaying:
1633-
self._logger.info(
1634-
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1635-
)
1620+
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1621+
task_list = ctx._pending_events.get(event_name, None)
1622+
decoded_result: Optional[Any] = None
1623+
if task_list:
1624+
event_task = task_list.pop(0)
1625+
if not ph.is_empty(event.eventRaised.input):
1626+
decoded_result = shared.from_json(event.eventRaised.input.value)
1627+
event_task.complete(decoded_result)
1628+
if not task_list:
1629+
del ctx._pending_events[event_name]
1630+
ctx.resume()
1631+
else:
1632+
# buffer the event
1633+
event_list = ctx._received_events.get(event_name, None)
1634+
if not event_list:
1635+
event_list = []
1636+
ctx._received_events[event_name] = event_list
1637+
if not ph.is_empty(event.eventRaised.input):
1638+
decoded_result = shared.from_json(event.eventRaised.input.value)
1639+
event_list.append(decoded_result)
1640+
if not ctx.is_replaying:
1641+
self._logger.info(
1642+
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1643+
)
16361644
elif event.HasField("executionSuspended"):
16371645
if not self._is_suspended and not ctx.is_replaying:
16381646
self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -1760,6 +1768,21 @@ def process_event(
17601768
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
17611769
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
17621770
pass
1771+
elif event.HasField("orchestratorCompleted"):
1772+
# Added in Functions only (for some reason) and does not affect orchestrator flow
1773+
pass
1774+
elif event.HasField("eventSent"):
1775+
# Check if this eventSent corresponds to an entity operation call after being translated to the old
1776+
# entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
1777+
# entityOperationCalled and remove the pending action. Also store the entity id and event id for later
1778+
action = ctx._pending_actions.pop(event.eventId, None)
1779+
if action and action.HasField("sendEntityMessage"):
1780+
if action.sendEntityMessage.HasField("entityOperationCalled"):
1781+
entity_id, event_id = self._parse_entity_event_sent_input(event)
1782+
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
1783+
elif action.sendEntityMessage.HasField("entityLockRequested"):
1784+
entity_id, event_id = self._parse_entity_event_sent_input(event)
1785+
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
17631786
else:
17641787
eventType = event.WhichOneof("eventType")
17651788
raise task.OrchestrationStateError(
@@ -1769,6 +1792,44 @@ def process_event(
17691792
# The orchestrator generator function completed
17701793
ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED)
17711794

1795+
def _parse_entity_event_sent_input(self, event: pb.HistoryEvent) -> Tuple[EntityInstanceId, str]:
1796+
try:
1797+
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1798+
except ValueError:
1799+
raise RuntimeError(f"Could not parse entity ID from instanceId '{event.eventSent.instanceId}'")
1800+
try:
1801+
event_id = json.loads(event.eventSent.input.value)["id"]
1802+
except (json.JSONDecodeError, KeyError, TypeError) as ex:
1803+
raise RuntimeError(f"Could not parse event ID from eventSent input '{event.eventSent.input.value}'") from ex
1804+
return entity_id, event_id
1805+
1806+
def _handle_entity_event_raised(self,
1807+
ctx: _RuntimeOrchestrationContext,
1808+
event: pb.HistoryEvent,
1809+
entity_id: Optional[EntityInstanceId],
1810+
task_id: Optional[int],
1811+
is_lock_event: bool):
1812+
# This eventRaised represents the result of an entity operation after being translated to the old
1813+
# entity protocol by the Durable WebJobs extension
1814+
if entity_id is None:
1815+
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1816+
if task_id is None:
1817+
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1818+
entity_task = ctx._pending_tasks.pop(task_id, None)
1819+
if not entity_task:
1820+
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1821+
result = None
1822+
if not ph.is_empty(event.eventRaised.input):
1823+
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1824+
result = shared.from_json(event.eventRaised.input.value)["result"]
1825+
if is_lock_event:
1826+
ctx._entity_context.complete_acquire(event.eventRaised.name)
1827+
entity_task.complete(EntityLock(ctx))
1828+
else:
1829+
ctx._entity_context.recover_lock_after_call(entity_id)
1830+
entity_task.complete(result)
1831+
ctx.resume()
1832+
17721833
def evaluate_orchestration_versioning(self, versioning: Optional[VersioningOptions], orchestration_version: Optional[str]) -> Optional[pb.TaskFailureDetails]:
17731834
if versioning is None:
17741835
return None

tests/durabletask/test_orchestration_executor.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import durabletask.internal.helpers as helpers
1111
import durabletask.internal.orchestrator_service_pb2 as pb
12-
from durabletask import task, worker
12+
from durabletask import task, worker, entities
1313

1414
logging.basicConfig(
1515
format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s',
@@ -1183,6 +1183,77 @@ def orchestrator(ctx: task.OrchestrationContext, _):
11831183
assert str(ex) in complete_action.failureDetails.errorMessage
11841184

11851185

1186+
def test_orchestrator_completed_no_effect():
1187+
def dummy_activity(ctx, _):
1188+
pass
1189+
1190+
def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
1191+
yield ctx.call_activity(dummy_activity, input=orchestrator_input)
1192+
1193+
registry = worker._Registry()
1194+
name = registry.add_orchestrator(orchestrator)
1195+
1196+
encoded_input = json.dumps(42)
1197+
new_events = [
1198+
helpers.new_orchestrator_started_event(),
1199+
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input),
1200+
helpers.new_orchestrator_completed_event()]
1201+
executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
1202+
result = executor.execute(TEST_INSTANCE_ID, [], new_events)
1203+
actions = result.actions
1204+
1205+
assert len(actions) == 1
1206+
assert type(actions[0]) is pb.OrchestratorAction
1207+
assert actions[0].id == 1
1208+
assert actions[0].HasField("scheduleTask")
1209+
assert actions[0].scheduleTask.name == task.get_name(dummy_activity)
1210+
assert actions[0].scheduleTask.input.value == encoded_input
1211+
1212+
1213+
def test_entity_lock_created_as_event():
1214+
test_entity_id = entities.EntityInstanceId("Counter", "myCounter")
1215+
1216+
def orchestrator(ctx: task.OrchestrationContext, _):
1217+
entity_id = test_entity_id
1218+
with (yield ctx.lock_entities([entity_id])):
1219+
return (yield ctx.call_entity(entity_id, "set", 1))
1220+
1221+
registry = worker._Registry()
1222+
name = registry.add_orchestrator(orchestrator)
1223+
1224+
new_events = [
1225+
helpers.new_orchestrator_started_event(),
1226+
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, None),
1227+
]
1228+
1229+
executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
1230+
result1 = executor.execute(TEST_INSTANCE_ID, [], new_events)
1231+
actions = result1.actions
1232+
assert len(actions) == 1
1233+
assert type(actions[0]) is pb.OrchestratorAction
1234+
assert actions[0].id == 1
1235+
assert actions[0].HasField("sendEntityMessage")
1236+
assert actions[0].sendEntityMessage.HasField("entityLockRequested")
1237+
1238+
old_events = new_events
1239+
event_sent_input = {
1240+
"id": actions[0].sendEntityMessage.entityLockRequested.criticalSectionId,
1241+
}
1242+
new_events = [
1243+
helpers.new_event_sent_event(1, str(test_entity_id), json.dumps(event_sent_input)),
1244+
helpers.new_event_raised_event(event_sent_input["id"], None),
1245+
]
1246+
result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
1247+
actions = result.actions
1248+
1249+
assert len(actions) == 1
1250+
assert type(actions[0]) is pb.OrchestratorAction
1251+
assert actions[0].id == 2
1252+
assert actions[0].HasField("sendEntityMessage")
1253+
assert actions[0].sendEntityMessage.HasField("entityOperationCalled")
1254+
assert actions[0].sendEntityMessage.entityOperationCalled.targetInstanceId.value == str(test_entity_id)
1255+
1256+
11861257
def get_and_validate_complete_orchestration_action_list(expected_action_count: int, actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction:
11871258
assert len(actions) == expected_action_count
11881259
assert type(actions[-1]) is pb.OrchestratorAction

0 commit comments

Comments
 (0)