Custom serialization for non-user types and non-serializable types for Hera runner (Parameter/Artifact inputs and outputs) #1166

Closed
@elliotgunton

Description

Proposal for custom serialisation and deserialisation

A largely common-sense design: we add loader and dumper fields to Parameter and Artifact. For Artifact, this extends the existing loader: Optional[ArtifactLoader] = None in artifact.py to also accept a Callable.

The attributes will be:

    dumper: Optional[Callable[[Any], str]] = None
    """used to specify a dumper function to serialise the parameter value for Annotated parameters"""

    loader: Optional[Callable[[str], Any]] = None
    """used to specify a loader function to load the parameter value for Annotated parameters"""

I've then split the problem into two main cases (these are duplicated for Artifacts, but from the user's perspective behave identically to Parameters):

  1. The value to be (de)serialised is a plain data type or BaseModel - this is already possible without any further implementation (tests exist).
  2. The value to be (de)serialised is a custom class that the user may not control, e.g. a pandas.DataFrame from the original issue motivation. The (de)serialisation function should be passed to the Parameter/Artifact's loader/dumper within the annotation.

A third "case" worth pointing out is when the script function uses the Pydantic IO feature (i.e. a class inheriting from Input or Output from hera.workflows.io): any attributes needing custom dumpers/loaders should have their (de)serialisation function passed in the annotation of the attribute within the class, not in the annotation of the function. The class will then also need arbitrary_types_allowed = True in its model config.

The examples below should demonstrate the intention from a user perspective.

Preamble code (imports and custom classes)

import json
from typing import Tuple

from hera.shared import global_config
from hera.workflows import Parameter, script
from hera.workflows.io.v1 import Input, Output

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

try:
    from pydantic.v1 import BaseModel
except ImportError:
    from pydantic import BaseModel

global_config.experimental_features["script_pydantic_io"] = True


# Classes used in examples:
class MyBaseModel(BaseModel):
    a: str = "a"
    b: str = "b"


class NonUserNonBaseModelClass:
    """Represents a non-user-defined class (e.g. pandas DataFrame) that does not inherit from BaseModel."""

    def __init__(self, a: str, b: str):
        self.a = a
        self.b = b

    @classmethod
    def from_json(cls, json_str: str) -> "NonUserNonBaseModelClass":
        return cls(**json.loads(json_str))

    def to_json(self) -> str:
        self_dict = {
            "a": self.a,
            "b": self.b,
        }
        return json.dumps(self_dict)


class MyInput(Input):
    non_user_defined_class: Annotated[
        NonUserNonBaseModelClass, Parameter(name="my-parameter", loader=NonUserNonBaseModelClass.from_json)
    ]

    class Config:
        arbitrary_types_allowed = True


class MyOutput(Output):
    non_user_defined_class: Annotated[
        NonUserNonBaseModelClass, Parameter(name="my-output", dumper=NonUserNonBaseModelClass.to_json)
    ]

    class Config:
        arbitrary_types_allowed = True

Example code:

# Deserialisation case 1: Basic types and BaseModel
# These already work with the current implementation (see the
# annotated-param-no-name test in test_runner_annotated_parameter_inputs)
@script(constructor="runner")
def base_model_auto_load(
    a_parameter: Annotated[MyBaseModel, Parameter(name="my-parameter")],
    another_param: Annotated[int, Parameter(name="int-parameter")],
) -> str:
    print(another_param)
    return a_parameter.a + a_parameter.b


# Deserialisation case 2: A class that does not inherit from BaseModel (representing a non-user-defined class).
# The user needs to be able to specify their own serialisation and deserialisation methods. This could be done
# by specifying an existing function from the non-user-defined class, or by passing a lambda function.
@script(constructor="runner")
def non_base_model_with_class_loader(
    a_parameter: Annotated[
        NonUserNonBaseModelClass,
        Parameter(name="my-parameter", loader=NonUserNonBaseModelClass.from_json),
    ],
) -> str:
    return a_parameter.a + a_parameter.b


@script(constructor="runner")
def non_base_model_with_lambda_function_loader(
    a_parameter: Annotated[
        NonUserNonBaseModelClass,
        Parameter(name="my-parameter", loader=lambda json_str: NonUserNonBaseModelClass(**json.loads(json_str))),
    ],
) -> str:
    return a_parameter.a + a_parameter.b


# Deserialisation case 3: A subclass of the special Pydantic IO classes; custom
# deserialisation method provided in the attribute annotation
@script(constructor="runner")
def pydantic_input_with_loader_on_attribute(
    my_input: MyInput,
) -> str:
    return my_input.non_user_defined_class.a + my_input.non_user_defined_class.b


# Serialisation case 1: Basic types and BaseModel
# These already work with the current implementation (see the
# return_base_model test in test_script_annotations_outputs)
@script(constructor="runner")
def base_model_auto_save(
    a: str,
    b: str,
) -> Tuple[
    Annotated[MyBaseModel, Parameter(name="my-output")],
    Annotated[int, Parameter(name="int-output")],
]:
    return (MyBaseModel(a=a, b=b), 42)


# Serialisation case 2: A class that does not inherit from BaseModel (representing a non-user-defined class)
@script(constructor="runner")
def non_base_model_with_class_serialiser(
    a: str,
    b: str,
) -> Annotated[
    NonUserNonBaseModelClass,
    Parameter(name="my-output", dumper=NonUserNonBaseModelClass.to_json),
]:
    return NonUserNonBaseModelClass(a=a, b=b)


# Serialisation case 3: A subclass of the special Pydantic IO classes; custom
# serialisation method provided in the attribute annotation
@script(constructor="runner")
def pydantic_output_with_dumper_on_attribute(
    a: str,
    b: str,
) -> MyOutput:
    return MyOutput(non_user_defined_class=NonUserNonBaseModelClass(a=a, b=b))

Original issue details (motivating example)

Is your feature request related to a problem? Please describe.
Tried to use pandas.DataFrame for outputs, got error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 222, in _runner
    output = _save_annotated_return_outputs(function(**kwargs), output_annotations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 250, in _save_annotated_return_outputs
    _write_to_path(path, value)
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 326, in _write_to_path
    output_string = serialize(output_value)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/shared/serialization.py", line 51, in serialize
    if value == MISSING:
       ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/pandas/core/generic.py", line 1577, in __nonzero__
    raise ValueError(
ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

Python:

from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

from hera.shared import global_config
from hera.workflows import S3Artifact, script
from hera.workflows.io.v1 import Output

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

global_config.experimental_features["script_pydantic_io"] = True


class Datasets(Output):
    X_train: pd.DataFrame
    X_test: pd.DataFrame
    y_train: pd.Series
    y_test: pd.Series

    class Config:
        arbitrary_types_allowed = True


# Load dataset
@script(constructor="runner")
def load_and_split_dataset(
    dataset_path: Annotated[
        Path,
        S3Artifact(...),
    ],
) -> Datasets:
    data = pd.read_csv(dataset_path)

    # Split into features and target
    X = data.drop("Outcome", axis=1)
    y = data["Outcome"]

    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    return Datasets(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
    )

Pandas DataFrames have a to_json method which would make things easier, but I have no way to tell the serialize function in hera.shared.serialization what to do with DataFrames. I also can't change the class code, hence "non-user" type (I could subclass it though?).
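
Under this proposal, the motivating example could express both directions directly in the annotations. A minimal sketch, assuming the loader/dumper fields land as described (illustrative only; recent pandas wants a buffer for read_json, hence the StringIO wrapper):

from io import StringIO

import pandas as pd

from hera.workflows import Parameter, script

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated


@script(constructor="runner")
def summarise_dataset(
    df: Annotated[
        pd.DataFrame,
        # hypothetical usage of the proposed loader field
        Parameter(name="dataset", loader=lambda raw: pd.read_json(StringIO(raw))),
    ],
) -> Annotated[pd.DataFrame, Parameter(name="summary", dumper=pd.DataFrame.to_json)]:
    return df.describe()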

Describe the solution you'd like

An easy way to plug in the "how" for serializing custom types in the runner, e.g. as part of the type annotation, or a global setter such as global_config.serializer = my_serializer, or maybe in the RunnerScriptConstructor? (Needs some more thought)
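
For the global-setter idea, one purely hypothetical shape (global_config.serializer is not an existing hera setting; this just illustrates the kind of hook meant):

import pandas as pd


def my_serializer(value) -> str:
    # handle the types hera's default serialisation can't
    if isinstance(value, pd.DataFrame):
        return value.to_json()
    raise TypeError(f"no custom serialisation for {type(value)!r}")


# hypothetical API, not implemented today:
# global_config.serializer = my_serializer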

Describe alternatives you've considered

  • Just use str types and call the DataFrame.to_json method manually
  • Create a subclass of RunnerScriptConstructor (and associated runtime files) so I can use my own serialize function
