diff --git a/djongo/models/fields.py b/djongo/models/fields.py index d34d8aa2..98314e50 100644 --- a/djongo/models/fields.py +++ b/djongo/models/fields.py @@ -18,6 +18,7 @@ import functools import json import typing +from typing import Any from bson import ObjectId from django import forms @@ -211,7 +212,6 @@ def value_from_object(self, obj): value = super().value_from_object(obj) if value is None: return None - container_obj = self.model_container(**value) processed_value = self._obj_thru_fields('value_from_object', container_obj) return processed_value @@ -233,7 +233,7 @@ def get_db_prep_save(self, value, connection): value, connection) return processed_value - + def get_prep_value(self, value): if (value is None or not isinstance(value, self.base_type)): @@ -256,7 +256,25 @@ def to_python(self, value): if isinstance(value, str): value = json.loads(value) - + + if isinstance(value, self.model_container): + value = {field.attname: getattr(value, field.attname) + for field in value._meta.fields} + + new_value = [] + + if isinstance(value, list): + for val in value: + if isinstance(val, self.model_container): + new_value.append({field.attname: getattr(val, field.attname) + for field in val._meta.fields}) + if isinstance(val, dict): + new_value.append(val) + value = new_value + + if type(value) == dict and self.base_type == list: + value = [value] + if not isinstance(value, self.base_type): raise ValidationError( f'Value: {value} must be an instance of {self.base_type}') @@ -296,7 +314,6 @@ def formfield(self, **kwargs): 'model_form_kw': self.model_form_kwargs, 'name': self.attname } - defaults.update(kwargs) return super().formfield(**defaults) @@ -373,6 +390,16 @@ def validate(self, value, model_instance, validate_parent=True): for _dict in value: super().validate(_dict, model_instance, validate_parent=False) + def formfield(self, **kwargs): + defaults = { + 'form_class': ArrayFormField, + 'model_container': self.model_container, + 'model_form_class': self.model_form_class, + 'model_form_kw': self.model_form_kwargs, + 'name': self.attname + } + return super().formfield(**defaults) + def _get_model_form_class(model_form_class, model_container, admin, request): if not model_form_class: @@ -399,14 +426,14 @@ def add_fields(self, form, index): class ArrayFormField(forms.Field): - def __init__(self, name, model_form_class, model_container, mdl_form_kw_l, + def __init__(self, name, model_form_class, model_container, model_form_kw, widget=None, admin=None, request=None, *args, **kwargs): self.name = name self.model_container = model_container self.model_form_class = _get_model_form_class( model_form_class, model_container, admin, request) - self.mdl_form_kw_l = mdl_form_kw_l + self.mdl_form_kw_l = model_form_kw self.admin = admin self.request = request @@ -416,7 +443,6 @@ def __init__(self, name, model_form_class, model_container, mdl_form_kw_l, error_messages = { 'incomplete': 'Enter all required fields.', } - self.ArrayFormSet = forms.formset_factory( self.model_form_class, formset=NestedFormSet, can_delete=True) super().__init__(error_messages=error_messages, @@ -442,6 +468,10 @@ def clean(self, value): def has_changed(self, initial, data): form_set_initial = [] for init in initial or []: + empty_model = self.model_container + for key, val in init.items(): + setattr(empty_model, key, val) + init = empty_model form_set_initial.append( forms.model_to_dict( init, @@ -459,9 +489,8 @@ def get_bound_field(self, form, field_name): class ArrayFormBoundField(forms.BoundField): def __init__(self, form, field, name): super().__init__(form, field, name) - data = self.data if form.is_bound else None - initial = [] + initial = self.value() if self.initial is not None: for ini in self.initial: if isinstance(ini, Model): @@ -471,7 +500,6 @@ def __init__(self, form, field, name): fields=field.model_form_class._meta.fields, exclude=field.model_form_class._meta.exclude )) - self.form_set = field.ArrayFormSet(data, initial=initial, prefix=self.html_name) def __getitem__(self, idx): @@ -591,8 +619,21 @@ class EmbeddedFormBoundField(forms.BoundField): def __str__(self): instance = self.value() + empty_model = self.field.model_form._meta.model + if instance: + # The model_form_class expects a Model object. + if type(instance) == dict: + for key, val in instance.items(): + setattr(empty_model, key, val) + instance = empty_model + if(type(instance)== list): + if instance == []: + instance = None + else: + for key, val in instance[0].items(): + setattr(empty_model, key, val) + instance = empty_model model_form = self.field.model_form_class(instance=instance, **self.field.model_form_kwargs) - return mark_safe(f'\n{ model_form.as_table() }\n
') @@ -608,6 +649,9 @@ def decompress(self, value): return value elif isinstance(value, Model): return [getattr(value, f_n) for f_n in self.field_names] + elif isinstance(value, dict): + # On update, a dict was input into here. + return value else: raise forms.ValidationError('Expected model-form')