Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instrument task runs #15955

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
138 changes: 124 additions & 14 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@

import anyio
import pendulum
from opentelemetry import trace
from opentelemetry.trace import (
Status,
StatusCode,
Tracer,
get_tracer,
)
from typing_extensions import ParamSpec

import prefect
from prefect import Task
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
from prefect.client.schemas import TaskRun
Expand Down Expand Up @@ -99,6 +107,12 @@
BACKOFF_MAX = 10


def get_labels_from_context(context: Optional["FlowRunContext"]) -> Dict[str, Any]:
    """Return the labels of the flow run attached to ``context``.

    The returned mapping is merged into the attributes of the task-run span,
    so it must always be a dictionary.

    Args:
        context: The active flow run context, or ``None`` when the task is
            running outside of a flow.

    Returns:
        The ``labels`` mapping of ``context.flow_run``, or an empty dict when
        there is no context or the context has no flow run attached.
    """
    if context is None:
        return {}
    # A context can exist without a flow run attached; dereferencing
    # ``.labels`` on ``None`` would raise AttributeError while the span is
    # being created, so treat that case the same as "no context".
    flow_run = context.flow_run
    if flow_run is None:
        return {}
    return flow_run.labels


class TaskRunTimeoutError(TimeoutError):
"""Raised when a task run exceeds its timeout."""

Expand All @@ -120,6 +134,9 @@ class BaseTaskRunEngine(Generic[P, R]):
_is_started: bool = False
_task_name_set: bool = False
_last_event: Optional[PrefectEvent] = None
_tracer: Tracer = field(
default_factory=lambda: get_tracer("prefect", prefect.__version__)
)

def __post_init__(self):
if self.parameters is None:
Expand Down Expand Up @@ -460,7 +477,15 @@ def set_state(self, state: State, force: bool = False) -> State:
validated_state=self.task_run.state,
follows=self._last_event,
)

self._span.add_event(
new_state.name,
{
"prefect.state.message": new_state.message or "",
"prefect.state.type": new_state.type,
"prefect.state.name": new_state.name or new_state.type,
"prefect.state.id": str(new_state.id),
},
)
return new_state

def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -514,6 +539,11 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
self.record_terminal_state_timing(terminal_state)
self.set_state(terminal_state)
self._return_value = result

self._span.set_status(Status(StatusCode.OK), terminal_state.message)
self._span.end(time.time_ns())
self._span = None

return result

def handle_retry(self, exc: Exception) -> bool:
Expand Down Expand Up @@ -562,6 +592,7 @@ def handle_retry(self, exc: Exception) -> bool:

def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
self._span.record_exception(exc)
if not self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
state = run_coro_as_sync(
Expand All @@ -576,6 +607,10 @@ def handle_exception(self, exc: Exception) -> None:
self.set_state(state)
self._raised = exc

self._span.set_status(Status(StatusCode.ERROR, state.message))
self._span.end(time.time_ns())
self._span = None

def handle_timeout(self, exc: TimeoutError) -> None:
if not self.handle_retry(exc):
if isinstance(exc, TaskRunTimeoutError):
Expand All @@ -599,6 +634,11 @@ def handle_crash(self, exc: BaseException) -> None:
self.set_state(state, force=True)
self._raised = exc

self._span.record_exception(exc)
self._span.set_status(Status(StatusCode.ERROR, state.message))
self._span.end(time.time_ns())
self._span = None

@contextmanager
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
from prefect.utilities.engine import (
Expand Down Expand Up @@ -655,14 +695,17 @@ def initialize_run(
with SyncClientContext.get_or_create() as client_ctx:
self._client = client_ctx.client
self._is_started = True
flow_run_context = FlowRunContext.get()
parent_task_run_context = TaskRunContext.get()

try:
if not self.task_run:
self.task_run = run_coro_as_sync(
self.task.create_local_run(
id=task_run_id,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
parent_task_run_context=TaskRunContext.get(),
flow_run_context=flow_run_context,
parent_task_run_context=parent_task_run_context,
wait_for=self.wait_for,
extra_task_inputs=dependencies,
)
Expand All @@ -679,6 +722,23 @@ def initialize_run(
self.logger.debug(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
)

labels = get_labels_from_context(flow_run_context)
parameter_attributes = {
f"prefect.run.parameter.{k}": type(v).__name__
for k, v in self.parameters.items()
}
self._span = self._tracer.start_span(
name=self.task_run.name,
attributes={
"prefect.run.type": "task",
"prefect.run.id": str(self.task_run.id),
"prefect.tags": self.task_run.tags,
**parameter_attributes,
**labels,
},
)

yield self

except TerminationSignal as exc:
Expand Down Expand Up @@ -730,11 +790,12 @@ def start(
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
) -> Generator[None, None, None]:
with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
self.begin_run()
try:
yield
finally:
self.call_hooks()
with trace.use_span(self._span):
self.begin_run()
try:
yield
finally:
self.call_hooks()

@contextmanager
def transaction_context(self) -> Generator[Transaction, None, None]:
Expand Down Expand Up @@ -977,6 +1038,16 @@ async def set_state(self, state: State, force: bool = False) -> State:
follows=self._last_event,
)

self._span.add_event(
new_state.name,
{
"prefect.state.message": new_state.message or "",
"prefect.state.type": new_state.type,
"prefect.state.name": new_state.name or new_state.type,
"prefect.state.id": str(new_state.id),
},
)

return new_state

async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -1025,6 +1096,11 @@ async def handle_success(self, result: R, transaction: Transaction) -> R:
self.record_terminal_state_timing(terminal_state)
await self.set_state(terminal_state)
self._return_value = result

self._span.set_status(Status(StatusCode.OK, terminal_state.message))
self._span.end(time.time_ns())
self._span = None

return result

async def handle_retry(self, exc: Exception) -> bool:
Expand Down Expand Up @@ -1073,6 +1149,7 @@ async def handle_retry(self, exc: Exception) -> bool:

async def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
self._span.record_exception(exc)
if not await self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
state = await exception_to_failed_state(
Expand All @@ -1083,8 +1160,12 @@ async def handle_exception(self, exc: Exception) -> None:
self.record_terminal_state_timing(state)
await self.set_state(state)
self._raised = exc
self._span.set_status(Status(StatusCode.ERROR, state.message))
self._span.end(time.time_ns())
self._span = None

async def handle_timeout(self, exc: TimeoutError) -> None:
self._span.record_exception(exc)
if not await self.handle_retry(exc):
if isinstance(exc, TaskRunTimeoutError):
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
Expand All @@ -1098,6 +1179,9 @@ async def handle_timeout(self, exc: TimeoutError) -> None:
)
await self.set_state(state)
self._raised = exc
self._span.set_status(Status(StatusCode.ERROR, state.message))
self._span.end(time.time_ns())
self._span = None

async def handle_crash(self, exc: BaseException) -> None:
state = await exception_to_crashed_state(exc)
Expand All @@ -1107,6 +1191,11 @@ async def handle_crash(self, exc: BaseException) -> None:
await self.set_state(state, force=True)
self._raised = exc

self._span.record_exception(exc)
self._span.set_status(Status(StatusCode.ERROR, state.message))
self._span.end(time.time_ns())
self._span = None

@asynccontextmanager
async def setup_run_context(self, client: Optional[PrefectClient] = None):
from prefect.utilities.engine import (
Expand Down Expand Up @@ -1162,12 +1251,14 @@ async def initialize_run(
async with AsyncClientContext.get_or_create():
self._client = get_client()
self._is_started = True
flow_run_context = FlowRunContext.get()

try:
if not self.task_run:
self.task_run = await self.task.create_local_run(
id=task_run_id,
parameters=self.parameters,
flow_run_context=FlowRunContext.get(),
flow_run_context=flow_run_context,
parent_task_run_context=TaskRunContext.get(),
wait_for=self.wait_for,
extra_task_inputs=dependencies,
Expand All @@ -1184,6 +1275,24 @@ async def initialize_run(
self.logger.debug(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
)

labels = get_labels_from_context(flow_run_context)

parameter_attributes = {
f"prefect.run.parameter.{k}": type(v).__name__
for k, v in self.parameters.items()
}
self._span = self._tracer.start_span(
name=self.task_run.name,
attributes={
"prefect.run.type": "task",
"prefect.run.id": str(self.task_run.id),
"prefect.tags": self.task_run.tags,
**parameter_attributes,
**labels,
},
)

yield self

except TerminationSignal as exc:
Expand Down Expand Up @@ -1237,11 +1346,12 @@ async def start(
async with self.initialize_run(
task_run_id=task_run_id, dependencies=dependencies
):
await self.begin_run()
try:
yield
finally:
await self.call_hooks()
with trace.use_span(self._span):
await self.begin_run()
try:
yield
finally:
await self.call_hooks()

@asynccontextmanager
async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
Expand Down
Loading
Loading