Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -21363,6 +21363,9 @@ convert_object_to_struct(
should_untrack = !MS_MAYBE_TRACKED(val);
}
}

if (Struct_decode_post_init(struct_type, out, path) < 0) goto error;

Py_LeaveRecursiveCall();
if (is_gc && !should_untrack)
PyObject_GC_Track(out);
Expand Down
35 changes: 16 additions & 19 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,18 +2239,18 @@ class Test2(Struct, Generic[T], tag=True, array_like=array_like):


class TestStructPostInit:
@pytest.mark.parametrize("array_like", [False, True])
@pytest.mark.parametrize("union", [False, True])
def test_struct_post_init(self, array_like, union):
count = 0
@mapcls_from_attributes_and_array_like
def test_struct_post_init(self, union, mapcls, from_attributes, array_like):
called = False
singleton = object()

class Ex(Struct, array_like=array_like, tag=union):
x: int

def __post_init__(self):
nonlocal count
count += 1
nonlocal called
called = True
return singleton

if union:
Expand All @@ -2262,25 +2262,23 @@ class Ex2(Struct, array_like=array_like, tag=True):
else:
typ = Ex

msg = Ex(1)
buf = to_builtins(msg)
res = convert(buf, type=typ)
assert res == msg
assert count == 2 # 1 for Ex(), 1 for decode
msg = mapcls(type="Ex", x=1) if union else mapcls(x=1)
res = convert(msg, type=typ, from_attributes=from_attributes)
assert type(res) is Ex
assert called
assert sys.getrefcount(singleton) == 2 # 1 for ref, 1 for call

@pytest.mark.parametrize("array_like", [False, True])
@pytest.mark.parametrize("union", [False, True])
@pytest.mark.parametrize("exc_class", [ValueError, TypeError, OSError])
def test_struct_post_init_errors(self, array_like, union, exc_class):
error = False

@mapcls_from_attributes_and_array_like
def test_struct_post_init_errors(
self, union, exc_class, mapcls, from_attributes, array_like
):
class Ex(Struct, array_like=array_like, tag=union):
x: int

def __post_init__(self):
if error:
raise exc_class("Oh no!")
raise exc_class("Oh no!")

if union:

Expand All @@ -2291,16 +2289,15 @@ class Ex2(Struct, array_like=array_like, tag=True):
else:
typ = Ex

msg = to_builtins([Ex(1)])
error = True
msg = [mapcls(type="Ex", x=1) if union else mapcls(x=1)]

if exc_class in (ValueError, TypeError):
expected = ValidationError
else:
expected = exc_class

with pytest.raises(expected, match="Oh no!") as rec:
convert(msg, type=List[typ])
convert(msg, type=List[typ], from_attributes=from_attributes)

if expected is ValidationError:
assert "- at `$[0]`" in str(rec.value)
Expand Down
Loading