Skip to content

Commit

Permalink
multi types schema format unmarshal fix
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Feb 11, 2024
1 parent e666357 commit 97775cf
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 52 deletions.
18 changes: 0 additions & 18 deletions openapi_core/unmarshalling/schemas/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,3 @@ class FormatterNotFoundError(UnmarshallerError):

def __str__(self) -> str:
return f"Formatter not found for {self.type_format} format"


@dataclass
class FormatUnmarshalError(UnmarshallerError):
"""Unable to unmarshal value for format"""

value: str
type: str
original_exception: Exception

def __str__(self) -> str:
return (
"Unable to unmarshal value {value} for format {type}: {exception}"
).format(
value=self.value,
type=self.type,
exception=self.original_exception,
)
58 changes: 27 additions & 31 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from openapi_core.unmarshalling.schemas.datatypes import (
FormatUnmarshallersDict,
)
from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError
from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError
from openapi_core.validation.schemas.validators import SchemaValidator

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,34 +136,15 @@ def _unmarshal_properties(

class MultiTypeUnmarshaller(PrimitiveUnmarshaller):
def __call__(self, value: Any) -> Any:
unmarshaller = self._get_best_unmarshaller(value)
primitive_type = self.schema_validator.get_primitive_type(value)
unmarshaller = self.schema_unmarshaller.get_type_unmarshaller(
primitive_type
)
return unmarshaller(value)

@property
def type(self) -> List[str]:
types = self.schema.getkey("type", ["any"])
assert isinstance(types, list)
return types

def _get_best_unmarshaller(self, value: Any) -> "PrimitiveUnmarshaller":
for schema_type in self.type:
result = self.schema_validator.type_validator(
value, type_override=schema_type
)
if not result:
continue
result = self.schema_validator.format_validator(value)
if not result:
continue
return self.schema_unmarshaller.get_type_unmarshaller(schema_type)

raise UnmarshallerError("Unmarshaller not found for type(s)")


class AnyUnmarshaller(MultiTypeUnmarshaller):
@property
def type(self) -> List[str]:
return self.schema_unmarshaller.types_unmarshaller.get_types()
pass


class TypesUnmarshaller:
Expand All @@ -185,7 +164,7 @@ def __init__(
def get_types(self) -> List[str]:
return list(self.unmarshallers.keys())

def get_unmarshaller(
def get_unmarshaller_cls(
self,
schema_type: Optional[Union[Iterable[str], str]],
) -> Type["PrimitiveUnmarshaller"]:
Expand Down Expand Up @@ -220,8 +199,8 @@ def unmarshal(self, schema_format: str, value: Any) -> Any:
return value
try:
return format_unmarshaller(value)
except (ValueError, TypeError) as exc:
raise FormatUnmarshalError(value, schema_format, exc)
except (AttributeError, ValueError, TypeError):
return value

def get_unmarshaller(
self, schema_format: str
Expand Down Expand Up @@ -279,19 +258,32 @@ def unmarshal(self, value: Any) -> Any:
(isinstance(value, bytes) and schema_format in ["binary", "byte"])
):
return typed
return self.formats_unmarshaller.unmarshal(schema_format, typed)

format_unmarshaller = self.get_format_unmarshaller(schema_format)
if format_unmarshaller is None:
return typed
try:
return format_unmarshaller(typed)
except (AttributeError, ValueError, TypeError):
return typed

def get_type_unmarshaller(
self,
schema_type: Optional[Union[Iterable[str], str]],
) -> PrimitiveUnmarshaller:
klass = self.types_unmarshaller.get_unmarshaller(schema_type)
klass = self.types_unmarshaller.get_unmarshaller_cls(schema_type)
return klass(
self.schema,
self.schema_validator,
self,
)

def get_format_unmarshaller(
self,
schema_format: str,
) -> Optional[FormatUnmarshaller]:
return self.formats_unmarshaller.get_unmarshaller(schema_format)

def evolve(self, schema: SchemaPath) -> "SchemaUnmarshaller":
cls = self.__class__

Expand All @@ -304,6 +296,10 @@ def evolve(self, schema: SchemaPath) -> "SchemaUnmarshaller":

def find_format(self, value: Any) -> Optional[str]:
for schema in self.schema_validator.iter_valid_schemas(value):
schema_validator = self.schema_validator.evolve(schema)
primitive_type = schema_validator.get_primitive_type(value)
if primitive_type != "string":
continue
if "format" in schema:
return str(schema.getkey("format"))
return None
18 changes: 18 additions & 0 deletions openapi_core/validation/schemas/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ def format_validator_callable(self) -> FormatValidator:

return lambda x: True

def get_primitive_type(self, value: Any) -> Optional[str]:
schema_types = self.schema.getkey("type")
if isinstance(schema_types, str):
return schema_types
if schema_types is None:
schema_types = sorted(self.validator.TYPE_CHECKER._type_checkers)
assert isinstance(schema_types, list)
for schema_type in schema_types:
result = self.type_validator(value, type_override=schema_type)
if not result:
continue
result = self.format_validator(value)
if not result:
continue
assert isinstance(schema_type, (str, type(None)))
return schema_type
return None

def iter_valid_schemas(self, value: Any) -> Iterator[SchemaPath]:
yield self.schema

Expand Down
21 changes: 21 additions & 0 deletions tests/integration/unmarshalling/test_unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,27 @@ def test_nultiple_types_invalid(self, unmarshallers_factory, types, value):
assert len(exc_info.value.schema_errors) == 1
assert "is not of type" in exc_info.value.schema_errors[0].message

@pytest.mark.parametrize(
"types,format,value,expected",
[
(["string", "null"], "date", None, None),
(["string", "null"], "date", "2018-12-13", date(2018, 12, 13)),
],
)
def test_multiple_types_format_valid_or_ignored(
self, unmarshallers_factory, types, format, value, expected
):
schema = {
"type": types,
"format": format,
}
spec = SchemaPath.from_dict(schema)
unmarshaller = unmarshallers_factory.create(spec)

result = unmarshaller.unmarshal(value)

assert result == expected

def test_any_null(self, unmarshallers_factory):
schema = {}
spec = SchemaPath.from_dict(schema)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/unmarshalling/test_schema_unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from openapi_core.unmarshalling.schemas.exceptions import (
FormatterNotFoundError,
)
from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError
from openapi_core.unmarshalling.schemas.factories import (
SchemaUnmarshallersFactory,
)
Expand Down Expand Up @@ -102,8 +101,9 @@ def custom_format_unmarshaller(value):
extra_format_unmarshallers=extra_format_unmarshallers,
)

with pytest.raises(FormatUnmarshalError):
unmarshaller.unmarshal(value)
result = unmarshaller.unmarshal(value)

assert result == value

def test_schema_extra_format_unmarshaller_format_custom(
self, schema_unmarshaller_factory
Expand Down

0 comments on commit 97775cf

Please sign in to comment.