## Description

Proposal for custom serialisation and deserialisation.

A largely common-sense design: we add `loader` and `dumper` fields to `Parameter` and `Artifact`, extending the existing `loader: Optional[ArtifactLoader] = None` field in `artifact.py` to also accept a `Callable`.
The attributes will be:

```python
dumper: Optional[Callable[[Any], str]] = None
"""used to specify a dumper function to serialise the parameter value for Annotated parameters"""

loader: Optional[Callable[[str], Any]] = None
"""used to specify a loader function to load the parameter value for Annotated parameters"""
```
Then, I've split the problem into two main cases (these are duplicated for Artifacts, but are identical to Parameters from the user's perspective):

- The value to be (de)serialised is a plain data type or `BaseModel`. This is already possible without any further implementation (tests exist).
- The value to be (de)serialised is a custom class that the user may not control, e.g. a `pandas.DataFrame` from the original issue motivation. The (de)serialisation function should be passed in the annotation to the Parameter/Artifact's `loader`/`dumper`.
The third "case" worth pointing out is when the script function uses the Pydantic IO feature (i.e. a class inheriting from `Input` or `Output` from `hera.workflows.io`): any attributes needing custom dumpers/loaders should have their (de)serialisation function passed in the annotation of the attributes within the class, not in the annotation of the function. The class will then need to set `arbitrary_types_allowed = True` in its model config.
The examples below should demonstrate the intention from a user perspective.
### Preamble code (imports and custom classes)
```python
import json

from hera.shared import global_config
from hera.workflows import Parameter, script
from hera.workflows.io.v1 import Input, Output

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

try:
    from pydantic.v1 import BaseModel
except ImportError:
    from pydantic import BaseModel

global_config.experimental_features["script_pydantic_io"] = True
```
```python
# Classes used in examples:
class MyBaseModel(BaseModel):
    a: str = "a"
    b: str = "b"


class NonUserNonBaseModelClass:
    """Represents a non-user-defined class (e.g. pandas DataFrame) that does not inherit from BaseModel."""

    def __init__(self, a: str, b: str):
        self.a = a
        self.b = b

    @classmethod
    def from_json(cls, json_str: str) -> "NonUserNonBaseModelClass":
        return cls(**json.loads(json_str))

    def to_json(self) -> str:
        self_dict = {
            "a": self.a,
            "b": self.b,
        }
        return json.dumps(self_dict)


class MyInput(Input):
    non_user_defined_class: Annotated[
        NonUserNonBaseModelClass, Parameter(name="my-parameter", loader=NonUserNonBaseModelClass.from_json)
    ]

    class Config:
        arbitrary_types_allowed = True


class MyOutput(Output):
    non_user_defined_class: Annotated[
        NonUserNonBaseModelClass, Parameter(name="my-output", dumper=NonUserNonBaseModelClass.to_json)
    ]

    class Config:
        arbitrary_types_allowed = True
```
### Example code
```python
# Deserialisation case 1: Basic types and BaseModel
# These already work with the current implementation (see the
# annotated-param-no-name test in test_runner_annotated_parameter_inputs)
@script(constructor="runner")
def base_model_auto_load(
    a_parameter: Annotated[MyBaseModel, Parameter(name="my-parameter")],
    another_param: Annotated[int, Parameter(name="int-parameter")],
) -> str:
    print(another_param)
    return a_parameter.a + a_parameter.b


# Deserialisation case 2: A class that does not inherit from BaseModel (representing a
# non-user-defined class). The user needs to be able to specify their own serialisation and
# deserialisation methods. This could be by specifying an existing function from the
# non-user-defined class, or by passing a lambda function.
@script(constructor="runner")
def non_base_model_with_class_loader(
    a_parameter: Annotated[
        NonUserNonBaseModelClass,
        Parameter(name="my-parameter", loader=NonUserNonBaseModelClass.from_json),
    ],
) -> str:
    return a_parameter.a + a_parameter.b


@script(constructor="runner")
def non_base_model_with_lambda_function_loader(
    a_parameter: Annotated[
        NonUserNonBaseModelClass,
        Parameter(name="my-parameter", loader=lambda json_str: NonUserNonBaseModelClass(**json.loads(json_str))),
    ],
) -> str:
    return a_parameter.a + a_parameter.b


# Deserialisation case 3: A subclass of the special Pydantic IO classes; custom
# deserialisation method provided in the attribute annotation
@script(constructor="runner")
def pydantic_input_with_loader_on_attribute(
    my_input: MyInput,
) -> str:
    return my_input.non_user_defined_class.a + my_input.non_user_defined_class.b


# Serialisation case 1: Basic types and BaseModel
# These already work with the current implementation (see the
# return_base_model test in test_script_annotations_outputs)
@script(constructor="runner")
def base_model_auto_save(
    a: str,
    b: str,
) -> tuple[
    Annotated[MyBaseModel, Parameter(name="my-output")],
    Annotated[int, Parameter(name="int-output")],
]:
    return (MyBaseModel(a=a, b=b), 42)


# Serialisation case 2: A class that does not inherit from BaseModel (representing a non-user-defined class)
@script(constructor="runner")
def non_base_model_with_class_serialiser(
    a: str,
    b: str,
) -> Annotated[
    NonUserNonBaseModelClass,
    Parameter(name="my-output", dumper=NonUserNonBaseModelClass.to_json),
]:
    return NonUserNonBaseModelClass(a=a, b=b)


# Serialisation case 3: A subclass of the special Pydantic IO classes; custom
# serialisation method provided in the attribute annotation
@script(constructor="runner")
def pydantic_output_with_dumper_on_attribute(
    a: str,
    b: str,
) -> MyOutput:
    return MyOutput(non_user_defined_class=NonUserNonBaseModelClass(a=a, b=b))
```
## Original issue details (motivating example)
**Is your feature request related to a problem? Please describe.**

Tried to use `pandas.DataFrame` for outputs, got this error:
```text
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 222, in _runner
    output = _save_annotated_return_outputs(function(**kwargs), output_annotations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 250, in _save_annotated_return_outputs
    _write_to_path(path, value)
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 326, in _write_to_path
    output_string = serialize(output_value)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/shared/serialization.py", line 51, in serialize
    if value == MISSING:
       ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/pandas/core/generic.py", line 1577, in __nonzero__
    raise ValueError(
ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
```
Python:

```python
global_config.experimental_features["script_pydantic_io"] = True


class Datasets(Output):
    X_train: pd.DataFrame
    X_test: pd.DataFrame
    y_train: pd.Series
    y_test: pd.Series

    class Config:
        arbitrary_types_allowed = True


# Load dataset
@script(constructor="runner")
def load_and_split_dataset(
    dataset_path: Annotated[
        Path,
        S3Artifact(...),
    ],
) -> Datasets:
    data = pd.read_csv(dataset_path)

    # Split into features and target
    X = data.drop("Outcome", axis=1)
    y = data["Outcome"]

    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    return Datasets(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
    )
```
Pandas DataFrames have a `to_json` method which would make things easier, but I have no way to tell the `serialize` function in `hera.shared.serialization` what to do with DataFrames. I also can't change the class code, hence it is a "non-user" type (though I could subclass it).
**Describe the solution you'd like**

An easy way to plug in the "how" for serialising custom types in the runner, e.g. as part of the type annotation, a global setter such as `global_config.serializer = my_serializer`, or maybe in the `RunnerScriptConstructor`? (Needs some more thought.)
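For comparison, the global-setter variant could look roughly like this. The `set_serializer` hook and the None-means-unhandled convention are assumptions for illustration, not existing hera API:

```python
import json
from typing import Any, Callable, Optional

# Hypothetical module-level hook standing in for global_config.serializer.
_custom_serializer: Optional[Callable[[Any], Optional[str]]] = None


def set_serializer(fn: Callable[[Any], Optional[str]]) -> None:
    """Register a single user-supplied serializer, tried before the default."""
    global _custom_serializer
    _custom_serializer = fn


def serialize(value: Any) -> str:
    # Try the user hook first; returning None signals "not handled,
    # fall through to the default JSON serialisation".
    if _custom_serializer is not None:
        handled = _custom_serializer(value)
        if handled is not None:
            return handled
    return json.dumps(value)
```

A drawback of this shape is that a single global hook has to dispatch on type itself, whereas per-annotation loaders/dumpers keep the mapping local to each parameter.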
**Describe alternatives you've considered**

- Just use `str`s and use the `DataFrame.to_json` method
- Create a subclass (and associated runtime files) of `RunnerScriptConstructor` so I can use my own serialize function
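The first alternative works today but pushes the (de)serialisation into every script body. A rough sketch, with plain `json` standing in for `pandas` so the example stays dependency-free (`drop_outcome_column` is an illustrative name, not part of the original code):

```python
import json


# Sketch of alternative 1: the script takes and returns plain strings,
# doing the conversion by hand (DataFrame.from_json / DataFrame.to_json
# in the real pandas case; json here to stay self-contained).
def drop_outcome_column(table_json: str) -> str:
    table = json.loads(table_json)
    features = {k: v for k, v in table.items() if k != "Outcome"}
    return json.dumps(features)
```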