From 2ac69105ceef39d411f3b7d57ab1dc224975ced0 Mon Sep 17 00:00:00 2001 From: Andy Babic Date: Sun, 27 Oct 2024 08:28:05 +0000 Subject: [PATCH] Update ClusterableModel.serializable_data() to support the exclude_fields argument and support '__' in exclude_fields to control the representation of child items --- modelcluster/models.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/modelcluster/models.py b/modelcluster/models.py index 0692f95..6af41b6 100644 --- a/modelcluster/models.py +++ b/modelcluster/models.py @@ -232,23 +232,30 @@ def save(self, **kwargs): for field in m2m_fields_to_commit: getattr(self, field).commit() - def serializable_data(self): - obj = get_serializable_data_for_fields(self) + def serializable_data(self, exclude_fields=None): + obj = get_serializable_data_for_fields(self, exclude_fields=exclude_fields) + + # normalize exclude_fields to a set + exclude = set(exclude_fields or ()) for rel in get_all_child_relations(self): rel_name = rel.get_accessor_name() - children = getattr(self, rel_name).all() + if rel_name in exclude: + continue + # define a subset of exclude_fields for this relationship + rel_exclude = {f[len(rel_name) + 2:] for f in exclude if f.startswith(rel_name + '__')} + + # serialize children to a list, using only the fields we need + children = getattr(self, rel_name).all().defer(*rel_exclude).iterator() if hasattr(rel.related_model, 'serializable_data'): - obj[rel_name] = [child.serializable_data() for child in children] + obj[rel_name] = [child.serializable_data(exclude_fields=rel_exclude) for child in children] else: - obj[rel_name] = [get_serializable_data_for_fields(child) for child in children] + obj[rel_name] = [get_serializable_data_for_fields(child, exclude_fields=rel_exclude) for child in children] for field in get_all_child_m2m_relations(self): - if field.serialize: - children = getattr(self, field.name).all() - obj[field.name] = [child.pk for child in children] - + if field.serialize and field.name not in exclude: + obj[field.name] = list(getattr(self, field.name).all().values_list('pk', flat=True)) return obj def to_json(self):