diff --git a/evap/rewards/forms.py b/evap/rewards/forms.py index 71993b2ac..1a06d0ca1 100644 --- a/evap/rewards/forms.py +++ b/evap/rewards/forms.py @@ -1,8 +1,10 @@ +from contextlib import contextmanager from datetime import date from django import forms from django.core.exceptions import ValidationError from django.core.validators import MaxValueValidator, StepValueValidator +from django.db import transaction from django.utils.translation import gettext as _ from evap.rewards.models import RewardPointRedemption, RewardPointRedemptionEvent @@ -48,9 +50,40 @@ def clean_event(self): class BaseRewardPointRedemptionFormSet(forms.BaseFormSet): def __init__(self, *args, **kwargs): self.user = kwargs.pop("user") + self.total_points_available = reward_points_of_user(self.user) super().__init__(*args, **kwargs) + self.locked = False + + def get_form_kwargs(self, index): + """ + Return additional keyword arguments for each individual formset form. + + index will be None if the form being constructed is a new empty + form. + """ + kwargs = self.form_kwargs.copy() + if not self.initial: + return kwargs + kwargs["initial"] = self.initial[index] + kwargs["initial"]["total_points_available"] = self.total_points_available + return kwargs + + @contextmanager + def lock(self): + with transaction.atomic(): + # lock these rows to prevent race conditions + list(self.user.reward_point_grantings.select_for_update()) + list(self.user.reward_point_redemptions.select_for_update()) + + self.locked = True + + yield + + self.locked = False def clean(self): + assert self.locked + if any(self.errors): return @@ -64,9 +97,7 @@ def clean(self): raise ValidationError(_("You don't have enough reward points.")) def save(self) -> list[RewardPointRedemption]: - # lock these rows to prevent race conditions - list(self.user.reward_point_grantings.select_for_update()) - list(self.user.reward_point_redemptions.select_for_update()) + assert self.locked created = [] for form in self.forms: diff --git a/evap/rewards/views.py b/evap/rewards/views.py index 17589d6bb..68e8c4e22 100644 --- a/evap/rewards/views.py +++ b/evap/rewards/views.py @@ -4,7 +4,6 @@ from django.contrib import messages from django.contrib.messages.views import SuccessMessageMixin from django.core.exceptions import BadRequest, SuspiciousOperation -from django.db import transaction from django.db.models import Sum from django.http import HttpResponse from django.shortcuts import get_object_or_404, redirect, render @@ -33,16 +32,15 @@ def index(request): status = 200 - with transaction.atomic(): - total_points_available = reward_points_of_user(request.user) - events = RewardPointRedemptionEvent.objects.filter(redeem_end_date__gte=datetime.now().date()).order_by("date") + events = RewardPointRedemptionEvent.objects.filter(redeem_end_date__gte=datetime.now().date()).order_by("date") - # pylint: disable=unexpected-keyword-arg - formset = RewardPointRedemptionFormSet( - request.POST or None, - initial=[{"event": e, "points": 0, "total_points_available": total_points_available} for e in events], - user=request.user, - ) + # pylint: disable=unexpected-keyword-arg + formset = RewardPointRedemptionFormSet( + request.POST or None, + initial=[{"event": e, "points": 0} for e in events], + user=request.user, + ) + with formset.lock(): if request.method == "POST": try: previous_redeemed_points = int(request.POST["previous_redeemed_points"])