diff --git a/tests/trace/conftest.py b/tests/trace/conftest.py new file mode 100644 index 000000000000..864981b59765 --- /dev/null +++ b/tests/trace/conftest.py @@ -0,0 +1,18 @@ +"""Shared fixtures for trace tests.""" + +from __future__ import annotations + +import pytest + +from tests.trace.test_utils import FailingSaveType, failing_load, failing_save +from weave.trace.serialization import serializer + + +@pytest.fixture +def failing_serializer(): + """Register a serializer that always fails, and clean up after the test.""" + serializer.register_serializer(FailingSaveType, failing_save, failing_load) + yield FailingSaveType + serializer.SERIALIZERS[:] = [ + s for s in serializer.SERIALIZERS if s.target_class is not FailingSaveType + ] diff --git a/tests/trace/test_custom_objs.py b/tests/trace/test_custom_objs.py index 8bdec4402d99..13f418231935 100644 --- a/tests/trace/test_custom_objs.py +++ b/tests/trace/test_custom_objs.py @@ -1,9 +1,12 @@ +from __future__ import annotations + from datetime import datetime, timezone import rich.markdown from PIL import Image import weave +from tests.trace.test_utils import FailingSaveType from weave.trace.serialization.custom_objs import ( KNOWN_TYPES, decode_custom_obj, @@ -90,3 +93,37 @@ def make_datetime(): # due to deserializing a custom object calls = client.get_calls() assert len(calls) == 1 + + +def test_encode_custom_obj_save_exception_returns_none(client, failing_serializer): + """Requirement: Type handler save exceptions should not crash user code + Interface: encode_custom_obj function + Given: A serializer is registered whose save function raises an exception + When: encode_custom_obj is called with an object of that type + Then: Returns None (graceful degradation) + """ + obj = FailingSaveType("test_value") + + # This should NOT raise - if it does, the test fails + result = encode_custom_obj(obj) + + # Should return None instead of raising + assert result is None + + +def test_encode_custom_obj_save_exception_does_not_propagate( + client, failing_serializer +): + """Requirement: Type handler save exceptions must not propagate to user code + Interface: encode_custom_obj function + Given: A serializer is registered whose save function raises RuntimeError + When: encode_custom_obj is called + Then: No exception is raised to the caller + """ + obj = FailingSaveType("test_value") + + # This should NOT raise - if it does, the test fails + result = encode_custom_obj(obj) + + # We expect None as the graceful degradation + assert result is None diff --git a/tests/trace/test_utils.py b/tests/trace/test_utils.py new file mode 100644 index 000000000000..e06e504d441c --- /dev/null +++ b/tests/trace/test_utils.py @@ -0,0 +1,23 @@ +"""Shared test utilities for trace tests.""" + +from __future__ import annotations + + +class FailingSaveType: + """A type whose serializer save function always raises an exception.""" + + def __init__(self, value: str): + self.value = value + + def __repr__(self) -> str: + return f"FailingSaveType({self.value!r})" + + +def failing_save(obj, artifact, name): + """A save function that always raises an exception.""" + raise RuntimeError("Intentional failure in save function") + + +def failing_load(artifact, name, val): + """A load function (not used in these tests).""" + return FailingSaveType(val) diff --git a/tests/trace/type_handlers/test_type_handler_safety.py b/tests/trace/type_handlers/test_type_handler_safety.py new file mode 100644 index 000000000000..89395fc31e5f --- /dev/null +++ b/tests/trace/type_handlers/test_type_handler_safety.py @@ -0,0 +1,124 @@ +"""Tests to verify that type handlers never crash user code. + +Requirement: The weave op decorator is complex but should never crash user code. +Type handlers that fail during serialization should gracefully degrade without +affecting the user's program execution. +""" + +from __future__ import annotations + +import weave +from tests.trace.test_utils import FailingSaveType + + +def test_op_output_with_failing_serializer_does_not_raise(client, failing_serializer): + """Requirement: Op functions must return their values even when serialization fails + Interface: @weave.op decorated function returning an object with failing type handler + Given: An @weave.op function returns an object whose type handler save raises an exception + When: The function is called + Then: The function returns the correct value to the user (not None, not an exception) + """ + + @weave.op + def return_failing_type(value: str) -> FailingSaveType: + return FailingSaveType(value) + + # This should NOT raise - the user should get their return value + result = return_failing_type("hello") + + # The user must receive the actual object they created + assert isinstance(result, FailingSaveType) + assert result.value == "hello" + + +def test_op_input_with_failing_serializer_does_not_raise(client, failing_serializer): + """Requirement: Op functions must execute normally even when input serialization fails + Interface: @weave.op decorated function accepting an object with failing type handler + Given: An @weave.op function accepts an object whose type handler save raises an exception + When: The function is called + Then: The function executes normally and returns the expected result + """ + + @weave.op + def process_failing_type(obj: FailingSaveType) -> str: + return f"processed: {obj.value}" + + failing_obj = FailingSaveType("test_input") + + # This should NOT raise - the function should execute normally + result = process_failing_type(failing_obj) + + # The function must return its computed result + assert result == "processed: test_input" + + +def test_op_with_multiple_args_one_failing_serializer_does_not_raise( + client, failing_serializer +): + """Requirement: A failing serializer for one argument should not affect other arguments + Interface: @weave.op decorated function with multiple arguments + Given: An @weave.op function has multiple args, one with a failing type handler + When: The function is called + Then: The function executes normally and non-failing arguments are serialized properly + """ + + @weave.op + def mixed_args(normal_arg: str, failing_arg: FailingSaveType) -> str: + return f"{normal_arg}: {failing_arg.value}" + + failing_obj = FailingSaveType("failing_value") + + # This should NOT raise + result = mixed_args("normal", failing_obj) + + # Function should execute normally + assert result == "normal: failing_value" + + # Verify call was recorded + client.flush() + calls = mixed_args.calls() + assert len(calls) == 1 + + call = calls[0] + # The normal_arg should be serialized properly + assert call.inputs["normal_arg"] == "normal" + + +def test_op_with_failing_serializer_call_is_recorded(client, failing_serializer): + """Requirement: Calls should still be recorded even when serialization fails (with stringified fallback) + Interface: @weave.op decorated function and call record retrieval + Given: An @weave.op function returns an object whose type handler save fails + When: The function is called and we fetch the call record + Then: The call is recorded with a stringified representation of the failed object + """ + + @weave.op + def return_failing_for_record(value: str) -> FailingSaveType: + return FailingSaveType(value) + + # Call the function + result = return_failing_for_record("record_test") + + # Ensure the result was returned correctly + assert isinstance(result, FailingSaveType) + assert result.value == "record_test" + + # Flush to ensure the call is recorded + client.flush() + + # Get the call record + calls = return_failing_for_record.calls() + assert len(calls) == 1 + + call = calls[0] + + # The output should be recorded - either as the actual object (if serialization + # worked on the second try or there's a fallback) or as a stringified version + # The key assertion is that the call record exists and has an output + assert call.output is not None + + # If it fell back to stringify, it would be a string representation + # If serialization succeeded elsewhere, it might be the actual object + # Either way, the call should be recorded + output_str = str(call.output) + assert "record_test" in output_str or "FailingSaveType" in output_str diff --git a/weave/trace/serialization/custom_objs.py b/weave/trace/serialization/custom_objs.py index 122418205dc6..a737aa15fbdf 100644 --- a/weave/trace/serialization/custom_objs.py +++ b/weave/trace/serialization/custom_objs.py @@ -104,7 +104,21 @@ def encode_custom_obj(obj: Any) -> EncodedCustomObjDict | None: } art = MemTraceFilesArtifact() - val = serializer.save(obj, art, "obj") + try: + val = serializer.save(obj, art, "obj") + # TODO: In future, this should raise a specific WeaveException that can be caught + # and managed. A higher level handler will then catch that exception and ignore + # it by default, leading to the current behaviour. + except Exception: + # Type handler save functions should never crash user code. + # If a serializer fails, we log a warning and return None, + # which will cause the caller to fall back to stringify(). + logger.warning( + f"Failed to serialize object of type {type(obj).__name__}. " + "Falling back to string representation.", + exc_info=True, + ) + return None if art.path_contents: encoded_path_contents = { k: (v.encode("utf-8") if isinstance(v, str) else v) # type: ignore