Commit 943cfe6

more unit tests
jeanluciano committed Nov 12, 2024
1 parent c03c73c commit 943cfe6
Showing 4 changed files with 145 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/prefect/client/schemas/objects.py
@@ -555,6 +555,11 @@ class FlowRun(ObjectBaseModel):
description="A list of tags on the flow run",
examples=[["tag-1", "tag-2"]],
)
labels: Dict[str, str] = Field(
default_factory=dict,
description="A dictionary of labels on the flow run",
examples=[{"my-label": "my-value"}],
)
parent_task_run_id: Optional[UUID] = Field(
default=None,
description=(
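The new labels field defaults to an empty dict via default_factory, so flow runs created without labels are unchanged. A minimal sketch of the intended behavior (label values are illustrative, and flow_id is assumed here to be the only required argument):

from uuid import uuid4

from prefect.client.schemas.objects import FlowRun

# Illustrative labels; any str-to-str mapping should validate.
run = FlowRun(flow_id=uuid4(), labels={"env": "test", "team": "data"})
assert run.labels == {"env": "test", "team": "data"}

# Omitted labels fall back to default_factory=dict, giving each
# instance its own empty dict rather than a shared mutable default.
assert FlowRun(flow_id=uuid4()).labels == {}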
7 changes: 6 additions & 1 deletion src/prefect/task_engine.py
@@ -115,6 +115,8 @@ def digest_task_inputs(inputs, parameters) -> Tuple[Dict[str, str], list[Link]]:
     links = []
     for key, value in inputs.items():
         for input in value:
+            if input in ("__parents__", "wait_for"):
+                continue
             if isinstance(input, TaskRunInput):
                 parameter_attributes[f"prefect.run.parameter.{key}"] = type(
                     parameters[key]
@@ -716,7 +718,6 @@ def initialize_run(
         self._is_started = True
         flow_run_context = FlowRunContext.get()
         parent_task_run_context = TaskRunContext.get()
-        self.logger.info(f"parameters {self.parameters}")
 
         try:
             if not self.task_run:
@@ -754,6 +755,9 @@ def initialize_run(
                     span_id=0,
                     is_remote=False,
                 )
+            labels = {}
+            if flow_run_context:
+                labels = flow_run_context.flow_run.labels
 
             self._span = self._tracer.start_span(
                 name=self.task_run.name,
@@ -762,6 +766,7 @@
"prefect.run.id": str(self.task_run.id),
"prefect.tags": self.task_run.tags,
**parameter_attributes,
**labels,
},
links=links,
context=context,
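With these changes, initialize_run copies any labels from the ambient flow run and unpacks them into the task span's attributes after the engine's own prefect.* keys. A standalone sketch of that merge using plain dicts (the values are stand-ins for what the engine derives from flow_run_context and digest_task_inputs at runtime):

# Stand-ins for values the engine computes at runtime.
labels = {"env": "test"}  # flow_run_context.flow_run.labels
parameter_attributes = {"prefect.run.parameter.x": "int"}  # from digest_task_inputs

attributes = {
    "prefect.run.type": "task",
    "prefect.run.id": "00000000-0000-0000-0000-000000000000",  # str(task_run.id)
    "prefect.tags": ("tag-1",),
    **parameter_attributes,
    **labels,  # flow-run labels land as top-level span attributes
}
assert attributes["env"] == "test"

Because **labels is unpacked last, a label key that collides with one of the prefect.* attributes would silently overwrite it; nothing in this hunk namespaces the labels.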
135 changes: 134 additions & 1 deletion tests/telemetry/test_instrumentation.py
@@ -1,5 +1,6 @@
 import os
-from uuid import UUID
+from unittest.mock import Mock, patch
+from uuid import UUID, uuid4
 
 import pytest
 from opentelemetry import metrics, trace
@@ -12,6 +13,10 @@
 from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
 from opentelemetry.sdk.trace import TracerProvider
 
+from prefect import flow, task
+from prefect.client.schemas import TaskRun
+from prefect.states import Completed, Running
+from prefect.task_engine import AsyncTaskRunEngine, digest_task_inputs
 from prefect.telemetry.bootstrap import setup_telemetry
 from prefect.telemetry.instrumentation import extract_account_and_workspace_id
 from prefect.telemetry.logging import get_log_handler
@@ -160,3 +165,131 @@ def test_logger_provider(
     log_handler = get_log_handler()
     assert isinstance(log_handler, LoggingHandler)
     assert log_handler._logger_provider == logger_provider
+
+
+class TestTaskRunInstrumentation:
+    @pytest.fixture
+    def mock_tracer(self):
+        trace_provider, _, _ = setup_telemetry()
+        return trace_provider.get_tracer("prefect.test")
+
+    @pytest.fixture
+    async def task_run_engine(self, mock_tracer):
+        @task
+        async def test_task(x: int, y: int):
+            return x + y
+
+        task_run = TaskRun(
+            id=uuid4(),
+            task_key="test_task",
+            flow_run_id=uuid4(),
+            state=Running(),
+            dynamic_key="test_task-1",
+        )
+
+        engine = AsyncTaskRunEngine(
+            task=test_task,
+            task_run=task_run,
+            parameters={"x": 1, "y": 2},
+            _tracer=mock_tracer,
+        )
+        return engine
+
+    def test_digest_task_inputs(self):
+        inputs = {"x": 1, "y": 2}
+        parameters = {"x": int, "y": int}
+        otel_params, otel_inputs = digest_task_inputs(inputs, parameters)
+        assert otel_params == {
+            "prefect.run.parameter.x": "int",
+            "prefect.run.parameter.y": "int",
+        }
+        assert otel_inputs == []
+
+    @pytest.mark.asyncio
+    async def test_span_creation(self, task_run_engine, mock_tracer):
+        async with task_run_engine.start():
+            assert task_run_engine._span is not None
+            assert task_run_engine._span.name == task_run_engine.task_run.name
+            assert task_run_engine._span.attributes["prefect.run.type"] == "task"
+            assert task_run_engine._span.attributes["prefect.run.id"] == str(
+                task_run_engine.task_run.id
+            )
+
+    @pytest.mark.asyncio
+    async def test_span_attributes(self, task_run_engine):
+        async with task_run_engine.start():
+            assert "prefect.run.parameter.x" in task_run_engine._span.attributes
+            assert "prefect.run.parameter.y" in task_run_engine._span.attributes
+            assert task_run_engine._span.attributes["prefect.run.parameter.x"] == "int"
+            assert task_run_engine._span.attributes["prefect.run.parameter.y"] == "int"
+
+    @pytest.mark.asyncio
+    async def test_span_events(self, task_run_engine):
+        async with task_run_engine.start():
+            await task_run_engine.set_state(Running())
+            await task_run_engine.set_state(Completed())
+
+            events = task_run_engine._span.events
+            assert len(events) == 2
+            assert events[0].name == "Running"
+            assert events[1].name == "Completed"
+
+    @pytest.mark.asyncio
+    async def test_span_status_on_success(self, task_run_engine):
+        async with task_run_engine.start():
+            async with task_run_engine.run_context():
+                await task_run_engine.handle_success(3, Mock())
+
+        assert task_run_engine._span.status.status_code == trace.StatusCode.OK
+
+    @pytest.mark.asyncio
+    async def test_span_status_on_failure(self, task_run_engine):
+        async with task_run_engine.start():
+            async with task_run_engine.run_context():
+                await task_run_engine.handle_exception(ValueError("Test error"))
+
+        assert task_run_engine._span.status.status_code == trace.StatusCode.ERROR
+        assert "Test error" in task_run_engine._span.status.description
+
+    @pytest.mark.asyncio
+    async def test_span_exception_recording(self, task_run_engine):
+        test_exception = ValueError("Test error")
+        async with task_run_engine.start():
+            async with task_run_engine.run_context():
+                await task_run_engine.handle_exception(test_exception)
+
+        events = task_run_engine._span.events
+        assert any(event.name == "exception" for event in events)
+        exception_event = next(event for event in events if event.name == "exception")
+        assert exception_event.attributes["exception.type"] == "ValueError"
+        assert exception_event.attributes["exception.message"] == "Test error"
+
+    @pytest.mark.asyncio
+    async def test_span_links(self, task_run_engine):
+        # Simulate a parent task run
+        parent_task_run_id = uuid4()
+        task_run_engine.task_run.task_inputs = {
+            "x": [{"id": parent_task_run_id}],
+            "y": [2],
+        }
+
+        async with task_run_engine.start():
+            pass
+
+        assert len(task_run_engine._span.links) == 1
+        link = task_run_engine._span.links[0]
+        assert link.context.trace_id == int(parent_task_run_id)
+        assert link.attributes["prefect.run.id"] == str(parent_task_run_id)
+
+    @pytest.mark.asyncio
+    async def test_flow_run_labels(self, task_run_engine):
+        @flow
+        async def test_flow():
+            return await task_run_engine.task()
+
+        with patch("prefect.context.FlowRunContext.get") as mock_flow_run_context:
+            mock_flow_run_context.return_value.flow_run.labels = {"env": "test"}
+            async with task_run_engine.start():
+                pass
+
+        assert task_run_engine._span.attributes["env"] == "test"
Empty file removed: tests/test_instrumentation.py
