diff --git a/src/prefect/blocks/core.py b/src/prefect/blocks/core.py
index 620d87e80b70..59097417e46a 100644
--- a/src/prefect/blocks/core.py
+++ b/src/prefect/blocks/core.py
@@ -2,6 +2,7 @@
 import html
 import inspect
 import sys
+import uuid
 import warnings
 from abc import ABC
 from functools import partial
@@ -790,6 +791,48 @@ async def _get_block_document(
 
         return block_document, block_document_name
 
+    @classmethod
+    @sync_compatible
+    @inject_client
+    async def _get_block_document_by_id(
+        cls,
+        block_document_id: Union[str, uuid.UUID],
+        client: Optional["PrefectClient"] = None,
+    ):
+        """
+        Retrieves a block document by its ID.
+
+        Args:
+            block_document_id: The ID of the block document, as a UUID or a string.
+            client: The client to use to read the block document. If not provided,
+                the default client will be injected.
+
+        Raises:
+            ValueError: If the ID is not a valid UUID.
+            ValueError: If no block document with the given ID is found.
+
+        Returns:
+            A tuple of the block document and its name.
+        """
+        if isinstance(block_document_id, str):
+            try:
+                block_document_id = uuid.UUID(block_document_id)
+            except ValueError as exc:
+                raise ValueError(
+                    f"Block document ID {block_document_id!r} is not a valid UUID"
+                ) from exc
+
+        try:
+            block_document = await client.read_block_document(
+                block_document_id=block_document_id
+            )
+        except prefect.exceptions.ObjectNotFound as exc:
+            raise ValueError(
+                f"Unable to find block document with ID {block_document_id!r}"
+            ) from exc
+
+        return block_document, block_document.name
+
     @classmethod
     @sync_compatible
     @inject_client
@@ -876,6 +919,106 @@ class Custom(Block):
         """
         block_document, block_document_name = await cls._get_block_document(name)
 
+        return cls._load_from_block_document(block_document, validate=validate)
+
+    @classmethod
+    @sync_compatible
+    @inject_client
+    async def load_from_ref(
+        cls,
+        ref: Union[str, uuid.UUID, Dict[str, Any]],
+        validate: bool = True,
+        client: Optional["PrefectClient"] = None,
+    ) -> "Self":
+        """
+        Retrieves data from the block document by given reference for the block type
+        that corresponds with the current class and returns an instantiated version of
+        the current class with the data stored in the block document.
+
+        Provided reference can be a block document ID, or a reference data in dictionary format.
+        Supported dictionary reference formats are:
+        - {"block_document_id": <block_document_id>}
+        - {"block_document_slug": <block_document_slug>}
+
+        If a block document for a given block type is saved with a different schema
+        than the current class calling `load`, a warning will be raised.
+
+        If the current class schema is a subset of the block document schema, the block
+        can be loaded as normal using the default `validate = True`.
+
+        If the current class schema is a superset of the block document schema, `load`
+        must be called with `validate` set to False to prevent a validation error. In
+        this case, the block attributes will default to `None` and must be set manually
+        and saved to a new block document before the block can be used as expected.
+
+        Args:
+            ref: The reference to the block document. This can be a block document ID,
+                or one of supported dictionary reference formats.
+            validate: If False, the block document will be loaded without Pydantic
+                validating the block schema. This is useful if the block schema has
+                changed client-side since the block document referred to by `ref` was saved.
+            client: The client to use to load the block document. If not provided, the
+                default client will be injected.
+
+        Raises:
+            ValueError: If invalid reference format is provided.
+            ValueError: If the requested block document is not found.
+
+        Returns:
+            An instance of the current class hydrated with the data stored in the
+            referenced block document.
+
+        """
+        block_document = None
+        if isinstance(ref, (str, uuid.UUID)):
+            block_document, _ = await cls._get_block_document_by_id(ref, client=client)
+        elif isinstance(ref, dict):
+            if block_document_id := ref.get("block_document_id"):
+                block_document, _ = await cls._get_block_document_by_id(
+                    block_document_id, client=client
+                )
+            elif block_document_slug := ref.get("block_document_slug"):
+                block_document, _ = await cls._get_block_document(
+                    block_document_slug, client=client
+                )
+
+        if block_document is None:
+            raise ValueError(f"Invalid reference format {ref!r}.")
+
+        return cls._load_from_block_document(block_document, validate=validate)
+
+    @classmethod
+    def _load_from_block_document(
+        cls, block_document: BlockDocument, validate: bool = True
+    ) -> "Self":
+        """
+        Loads a block from a given block document.
+
+        If a block document for a given block type is saved with a different schema
+        than the current class calling `load`, a warning will be raised.
+
+        If the current class schema is a subset of the block document schema, the block
+        can be loaded as normal using the default `validate = True`.
+
+        If the current class schema is a superset of the block document schema, `load`
+        must be called with `validate` set to False to prevent a validation error. In
+        this case, the block attributes will default to `None` and must be set manually
+        and saved to a new block document before the block can be used as expected.
+
+        Args:
+            block_document: The block document used to instantiate a block.
+            validate: If False, the block document will be loaded without Pydantic
+                validating the block schema. This is useful if the block schema has
+                changed client-side since the block document was saved.
+
+        Raises:
+            RuntimeError: If the block document fails validation and `validate` is True.
+
+        Returns:
+            An instance of the current class hydrated with the data stored in the
+            given block document.
+
+        """
         try:
             return cls._from_block_document(block_document)
         except ValidationError as e:
@@ -883,18 +1026,18 @@ class Custom(Block):
                 missing_fields = tuple(err["loc"][0] for err in e.errors())
                 missing_block_data = {field: None for field in missing_fields}
                 warnings.warn(
-                    f"Could not fully load {block_document_name!r} of block type"
+                    f"Could not fully load {block_document.name!r} of block type"
                     f" {cls.get_block_type_slug()!r} - this is likely because one or more"
                     " required fields were added to the schema for"
                     f" {cls.__name__!r} that did not exist on the class when this block"
                     " was last saved. Please specify values for new field(s):"
                     f" {listrepr(missing_fields)}, then run"
-                    f' `{cls.__name__}.save("{block_document_name}", overwrite=True)`,'
+                    f' `{cls.__name__}.save("{block_document.name}", overwrite=True)`,'
                     " and load this block again before attempting to use it."
                 )
                 return cls.model_construct(**block_document.data, **missing_block_data)
             raise RuntimeError(
-                f"Unable to load {block_document_name!r} of block type"
+                f"Unable to load {block_document.name!r} of block type"
                 f" {cls.get_block_type_slug()!r} due to failed validation. To load without"
                 " validation, try loading again with `validate=False`."
             ) from e
diff --git a/src/prefect/flows.py b/src/prefect/flows.py
index 042378df0a71..857e2e493fb6 100644
--- a/src/prefect/flows.py
+++ b/src/prefect/flows.py
@@ -95,7 +95,7 @@
     parameters_to_args_kwargs,
     raise_for_reserved_arguments,
 )
-from prefect.utilities.collections import listrepr
+from prefect.utilities.collections import listrepr, visit_collection
 from prefect.utilities.filesystem import relative_path_to_current_platform
 from prefect.utilities.hashing import file_hash
 from prefect.utilities.importtools import import_object, safe_load_namespace
@@ -535,6 +535,21 @@ def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
         Raises:
             ParameterTypeError: if the provided parameters are not valid
         """
+
+        def resolve_block_reference(data: Any) -> Any:
+            if isinstance(data, dict) and "$ref" in data:
+                return Block.load_from_ref(data["$ref"])
+            return data
+
+        try:
+            parameters = visit_collection(
+                parameters, resolve_block_reference, return_data=True
+            )
+        except (ValueError, RuntimeError) as exc:
+            raise ParameterTypeError(
+                "Failed to resolve block references in parameters."
+            ) from exc
+
         args, kwargs = parameters_to_args_kwargs(self.fn, parameters)
 
         with warnings.catch_warnings():
diff --git a/src/prefect/utilities/schema_tools/validation.py b/src/prefect/utilities/schema_tools/validation.py
index 6e67d7e1065f..e94204331125 100644
--- a/src/prefect/utilities/schema_tools/validation.py
+++ b/src/prefect/utilities/schema_tools/validation.py
@@ -253,5 +253,35 @@ def preprocess_schema(
                 process_properties(
                     definition["properties"], required_fields, allow_none_with_default
                 )
+            # Allow block types to be referenced by their id
+            if "block_type_slug" in definition:
+                schema["definitions"][definition["title"]] = {
+                    "oneOf": [
+                        definition,
+                        {
+                            "type": "object",
+                            "properties": {
+                                "$ref": {
+                                    "oneOf": [
+                                        {
+                                            "type": "string",
+                                            "format": "uuid",
+                                        },
+                                        {
+                                            "type": "object",
+                                            "additionalProperties": {
+                                                "type": "string",
+                                            },
+                                            "minProperties": 1,
+                                        },
+                                    ]
+                                }
+                            },
+                            "required": [
+                                "$ref",
+                            ],
+                        },
+                    ]
+                }
 
     return schema
diff --git a/tests/blocks/test_block_reference.py b/tests/blocks/test_block_reference.py
new file mode 100644
index 000000000000..06de8143685c
--- /dev/null
+++ b/tests/blocks/test_block_reference.py
@@ -0,0 +1,191 @@
+import warnings
+from typing import Type
+from uuid import UUID, uuid4
+
+import pydantic
+import pytest
+
+from prefect.blocks.core import Block
+from prefect.exceptions import ParameterTypeError
+from prefect.flows import flow
+
+
+class TestBlockReference:
+    class ReferencedBlock(Block):
+        a: int
+        b: str
+
+    class SimilarReferencedBlock(Block):
+        a: int
+        b: str
+
+    class OtherReferencedBlock(Block):
+        c: int
+        d: str
+
+    @pytest.fixture
+    def block_document_id(self, prefect_client) -> UUID:
+        block = self.ReferencedBlock(a=1, b="foo")
+        block.save("block-reference", client=prefect_client)
+        return block._block_document_id
+
+    def test_block_load_from_reference(
+        self,
+        block_document_id: UUID,
+    ):
+        block = self.ReferencedBlock.load_from_ref(block_document_id)
+        assert block.a == 1
+        assert block.b == "foo"
+
+    def test_base_block_load_from_reference(
+        self,
+        block_document_id: UUID,
+    ):
+        block = Block.load_from_ref(block_document_id)
+        assert isinstance(block, self.ReferencedBlock)
+        assert block.a == 1
+        assert block.b == "foo"
+
+    def test_block_load_from_reference_string(
+        self,
+        block_document_id: UUID,
+    ):
+        block = self.ReferencedBlock.load_from_ref(str(block_document_id))
+        assert block.a == 1
+        assert block.b == "foo"
+
+    def test_block_load_from_bad_reference(self):
+        with pytest.raises(ValueError, match="is not a valid UUID"):
+            self.ReferencedBlock.load_from_ref("non-valid-uuid")
+
+        with pytest.raises(ValueError, match="Unable to find block document with ID"):
+            self.ReferencedBlock.load_from_ref(uuid4())
+
+    def test_block_load_from_similar_block_reference_type(self):
+        block = self.SimilarReferencedBlock(a=1, b="foo")
+        block.save("other-block")
+
+        block = self.ReferencedBlock.load_from_ref(block._block_document_id)
+        assert block.a == 1
+        assert block.b == "foo"
+
+    def test_block_load_from_invalid_block_reference_type(self):
+        block = self.OtherReferencedBlock(c=1, d="foo")
+        block.save("other-block")
+
+        with pytest.raises(RuntimeError):
+            self.ReferencedBlock.load_from_ref(block._block_document_id)
+
+    def test_block_load_from_nested_block_reference(self):
+        ReferencedBlock = self.ReferencedBlock
+
+        class NestedReferencedBlock(Block):
+            inner_block: ReferencedBlock
+
+        nested_block = NestedReferencedBlock(inner_block=ReferencedBlock(a=1, b="foo"))
+        nested_block.save("nested-block")
+
+        loaded_block = NestedReferencedBlock.load_from_ref(
+            nested_block._block_document_id
+        )
+        assert getattr(loaded_block, "inner_block", None) is not None
+        assert loaded_block.inner_block.a == 1
+        assert loaded_block.inner_block.b == "foo"
+
+
+class TestFlowWithBlockParam:
+    @pytest.fixture
+    def ParamBlock(self) -> Type[Block]:
+        # Ignore warning caused by matching key in registry due to block fixture
+        warnings.filterwarnings("ignore", category=UserWarning)
+
+        class ParamBlock(Block):
+            a: int
+            b: str
+
+        return ParamBlock
+
+    @pytest.fixture
+    def OtherParamBlock(self) -> Type[Block]:
+        # Ignore warning caused by matching key in registry due to block fixture
+        warnings.filterwarnings("ignore", category=UserWarning)
+
+        class OtherParamBlock(Block):
+            a: int
+            b: str
+
+        return OtherParamBlock
+
+    def test_flow_with_block_params(self, ParamBlock):
+        ref_block = ParamBlock(a=10, b="foo")
+        ref_block.save("param-block")
+
+        @flow
+        def flow_with_block_param(block: ParamBlock) -> int:
+            return block.a
+
+        assert (
+            flow_with_block_param({"$ref": str(ref_block._block_document_id)})
+            == ref_block.a
+        )
+        assert (
+            flow_with_block_param(
+                {"$ref": {"block_document_id": str(ref_block._block_document_id)}}
+            )
+            == ref_block.a
+        )
+
+    def test_flow_with_invalid_block_param_type(self, ParamBlock, OtherParamBlock):
+        ref_block = OtherParamBlock(a=10, b="foo")
+        ref_block.save("other-param-block")
+
+        @flow
+        def flow_with_block_param(block: ParamBlock) -> int:
+            return block.a
+
+        with pytest.raises(
+            ParameterTypeError, match="Flow run received invalid parameters"
+        ):
+            flow_with_block_param({"$ref": str(ref_block._block_document_id)})
+
+    def test_flow_with_nested_block_params(self, ParamBlock):
+        class NestedParamBlock(Block):
+            inner_block: ParamBlock
+
+        nested_block = NestedParamBlock(inner_block=ParamBlock(a=12, b="foo"))
+        nested_block.save("nested-block")
+
+        @flow
+        def flow_with_nested_block_param(block: NestedParamBlock):
+            return block.inner_block.a
+
+        assert (
+            flow_with_nested_block_param(
+                {"$ref": {"block_document_id": str(nested_block._block_document_id)}}
+            )
+            == nested_block.inner_block.a
+        )
+
+    def test_flow_with_block_param_in_basemodel(self, ParamBlock):
+        class ParamModel(pydantic.BaseModel):
+            block: ParamBlock
+
+        param_block = ParamBlock(a=12, b="foo")
+        param_block.save("param-block")
+
+        @flow
+        def flow_with_block_param_in_basemodel(param: ParamModel):
+            return param.block.a
+
+        assert (
+            flow_with_block_param_in_basemodel(
+                {
+                    "block": {
+                        "$ref": {
+                            "block_document_id": str(param_block._block_document_id)
+                        }
+                    }
+                }
+            )
+            == param_block.a
+        )
+