Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/structs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ annotations:
- ``__init__``
- ``__repr__``
- ``__copy__``
- ``__replace__``
- ``__eq__`` & ``__ne__``
- ``__match_args__`` (for Python 3.10+'s `pattern matching`_)
- ``__rich_repr__`` (for pretty printing support with rich_)
Expand Down
140 changes: 77 additions & 63 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -7640,6 +7640,80 @@ Struct_copy(PyObject *self, PyObject *args)
return NULL;
}

static PyObject *
Struct_replace(
PyObject *self,
PyObject *const *args,
Py_ssize_t nargs,
PyObject *kwnames
) {
Py_ssize_t nkwargs = (kwnames == NULL) ? 0 : PyTuple_GET_SIZE(kwnames);

if (!check_positional_nargs(nargs, 0, 0)) return NULL;

StructMetaObject *struct_type = (StructMetaObject *)Py_TYPE(self);
PyObject *fields = struct_type->struct_fields;
Py_ssize_t nfields = PyTuple_GET_SIZE(fields);
bool is_gc = MS_TYPE_IS_GC(struct_type);
bool should_untrack = is_gc;

PyObject *out = Struct_alloc((PyTypeObject *)struct_type);
if (out == NULL) return NULL;

for (Py_ssize_t i = 0; i < nkwargs; i++) {
PyObject *val;
Py_ssize_t field_index;
PyObject *kwname = PyTuple_GET_ITEM(kwnames, i);

/* Since keyword names are interned, first loop with pointer
* comparisons only. */
for (field_index = 0; field_index < nfields; field_index++) {
PyObject *field = PyTuple_GET_ITEM(fields, field_index);
if (MS_LIKELY(kwname == field)) goto kw_found;
}
for (field_index = 0; field_index < nfields; field_index++) {
PyObject *field = PyTuple_GET_ITEM(fields, field_index);
if (MS_UNICODE_EQ(kwname, field)) goto kw_found;
}

/* Unknown keyword */
PyErr_Format(
PyExc_TypeError, "`%.200s` has no field '%U'",
((PyTypeObject *)struct_type)->tp_name, kwname
);
goto error;

kw_found:
val = args[i];
Py_INCREF(val);
Struct_set_index(out, field_index, val);
if (should_untrack) {
should_untrack = !MS_MAYBE_TRACKED(val);
}
}

for (Py_ssize_t i = 0; i < nfields; i++) {
if (Struct_get_index_noerror(out, i) == NULL) {
PyObject *val = Struct_get_index(self, i);
if (val == NULL) goto error;
if (should_untrack) {
should_untrack = !MS_MAYBE_TRACKED(val);
}
Py_INCREF(val);
Struct_set_index(out, i, val);
}
}

if (is_gc && !should_untrack) {
PyObject_GC_Track(out);
}
return out;

error:
Py_DECREF(out);
return NULL;
}

static AssocList *
AssocList_FromStruct(PyObject *obj) {
if (Py_EnterRecursiveCall(" while serializing an object")) return NULL;
Expand Down Expand Up @@ -7713,81 +7787,20 @@ PyDoc_STRVAR(struct_replace__doc__,
"\n"
"See Also\n"
"--------\n"
"copy.replace\n"
"dataclasses.replace"
);
static PyObject*
struct_replace(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
{
Py_ssize_t nkwargs = (kwnames == NULL) ? 0 : PyTuple_GET_SIZE(kwnames);

if (!check_positional_nargs(nargs, 1, 1)) return NULL;
PyObject *obj = args[0];
if (Py_TYPE(Py_TYPE(obj)) != &StructMetaType) {
PyErr_SetString(PyExc_TypeError, "`struct` must be a `msgspec.Struct`");
return NULL;
}

StructMetaObject *struct_type = (StructMetaObject *)Py_TYPE(obj);
PyObject *fields = struct_type->struct_fields;
Py_ssize_t nfields = PyTuple_GET_SIZE(fields);
bool is_gc = MS_TYPE_IS_GC(struct_type);
bool should_untrack = is_gc;

PyObject *out = Struct_alloc((PyTypeObject *)struct_type);
if (out == NULL) return NULL;

for (Py_ssize_t i = 0; i < nkwargs; i++) {
PyObject *val;
Py_ssize_t field_index;
PyObject *kwname = PyTuple_GET_ITEM(kwnames, i);

/* Since keyword names are interned, first loop with pointer
* comparisons only. */
for (field_index = 0; field_index < nfields; field_index++) {
PyObject *field = PyTuple_GET_ITEM(fields, field_index);
if (MS_LIKELY(kwname == field)) goto kw_found;
}
for (field_index = 0; field_index < nfields; field_index++) {
PyObject *field = PyTuple_GET_ITEM(fields, field_index);
if (MS_UNICODE_EQ(kwname, field)) goto kw_found;
}

/* Unknown keyword */
PyErr_Format(
PyExc_TypeError, "`%.200s` has no field '%U'",
((PyTypeObject *)struct_type)->tp_name, kwname
);
goto error;

kw_found:
val = args[i + 1];
Py_INCREF(val);
Struct_set_index(out, field_index, val);
if (should_untrack) {
should_untrack = !MS_MAYBE_TRACKED(val);
}
}

for (Py_ssize_t i = 0; i < nfields; i++) {
if (Struct_get_index_noerror(out, i) == NULL) {
PyObject *val = Struct_get_index(obj, i);
if (val == NULL) goto error;
if (should_untrack) {
should_untrack = !MS_MAYBE_TRACKED(val);
}
Py_INCREF(val);
Struct_set_index(out, i, val);
}
}

if (is_gc && !should_untrack) {
PyObject_GC_Track(out);
}
return out;

error:
Py_DECREF(out);
return NULL;
return Struct_replace(obj, args + 1, 0, kwnames);
}

PyDoc_STRVAR(struct_asdict__doc__,
Expand Down Expand Up @@ -8047,6 +8060,7 @@ StructMixin_config(StructMetaObject *self, void *closure) {

static PyMethodDef Struct_methods[] = {
{"__copy__", Struct_copy, METH_NOARGS, "copy a struct"},
{"__replace__", (PyCFunction) Struct_replace, METH_FASTCALL | METH_KEYWORDS, "create a new struct with replacements" },
{"__reduce__", Struct_reduce, METH_NOARGS, "reduce a struct"},
{"__rich_repr__", Struct_rich_repr, METH_NOARGS, "rich repr"},
{NULL, NULL},
Expand Down
44 changes: 30 additions & 14 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@

import msgspec
from msgspec import NODEFAULT, UNSET, Struct, defstruct, field
from msgspec.structs import StructConfig, replace
from msgspec.structs import StructConfig

if hasattr(copy, "replace"):
# Added in Python 3.13
copy_replace = copy.replace
else:

def copy_replace(s, **changes):
return s.__replace__(**changes)


@contextmanager
Expand Down Expand Up @@ -2265,27 +2273,35 @@ def test_defstruct_rename(self):
assert Test.__struct_encode_fields__ == ("myField",)


@pytest.fixture(params=["structs.replace", "copy.replace"])
def replace(request):
if request.param == "structs.replace":
return msgspec.structs.replace
else:
return copy_replace


class TestReplace:
def test_replace_no_kwargs(self):
def test_replace_not_a_struct(self):
with pytest.raises(TypeError, match="`struct` must be a `msgspec.Struct`"):
msgspec.structs.replace(1, x=2)

def test_replace_no_kwargs(self, replace):
p = Point(1, 2)
assert replace(p) == p

def test_replace_kwargs(self):
def test_replace_kwargs(self, replace):
p = Point(1, 2)
assert replace(p, x=3) == Point(3, 2)
assert replace(p, y=4) == Point(1, 4)
assert replace(p, x=3, y=4) == Point(3, 4)

def test_replace_unknown_field(self):
def test_replace_unknown_field(self, replace):
p = Point(1, 2)
with pytest.raises(TypeError, match="`Point` has no field 'oops'"):
replace(p, oops=3)

def test_replace_not_a_struct(self):
with pytest.raises(TypeError, match="`struct` must be a `msgspec.Struct`"):
replace(1, x=2)

def test_replace_errors_unset_fields(self):
def test_replace_errors_unset_fields(self, replace):
p = Point(1, 2)
del p.x

Expand All @@ -2297,14 +2313,14 @@ def test_replace_errors_unset_fields(self):

assert replace(p, x=3) == Point(3, 2)

def test_replace_frozen(self):
def test_replace_frozen(self, replace):
class Test(msgspec.Struct, frozen=True):
x: int
y: int

assert replace(Test(1, 2), x=3) == Test(3, 2)

def test_replace_gc_delayed_tracking(self):
def test_replace_gc_delayed_tracking(self, replace):
class Test(msgspec.Struct):
x: int
y: Optional[List[int]]
Expand All @@ -2320,7 +2336,7 @@ class Test(msgspec.Struct):
assert gc.is_tracked(replace(obj, x=1))
assert not gc.is_tracked(replace(obj, y=None))

def test_replace_gc_false(self):
def test_replace_gc_false(self, replace):
class Test(msgspec.Struct, gc=False):
x: int
y: List[int]
Expand All @@ -2329,7 +2345,7 @@ class Test(msgspec.Struct, gc=False):
assert res == Test(3, [1, 2])
assert not gc.is_tracked(res)

def test_replace_reference_counts(self):
def test_replace_reference_counts(self, replace):
class Test(msgspec.Struct):
x: Any
y: int
Expand Down Expand Up @@ -2629,6 +2645,6 @@ def __post_init__(self):

x1 = Ex()
assert count == 1
x2 = replace(x1)
x2 = msgspec.structs.replace(x1)
assert x1 == x2
assert count == 1
Loading