1
+ from contextlib import contextmanager
1
2
from datetime import date
2
3
3
4
from django import forms
4
5
from django .core .exceptions import ValidationError
5
6
from django .core .validators import MaxValueValidator , StepValueValidator
7
+ from django .db import transaction
6
8
from django .utils .translation import gettext as _
7
9
8
10
from evap .rewards .models import RewardPointRedemption , RewardPointRedemptionEvent
@@ -48,9 +50,40 @@ def clean_event(self):
48
50
class BaseRewardPointRedemptionFormSet (forms .BaseFormSet ):
49
51
def __init__ (self , * args , ** kwargs ):
50
52
self .user = kwargs .pop ("user" )
53
+ self .total_points_available = reward_points_of_user (self .user )
51
54
super ().__init__ (* args , ** kwargs )
55
+ self .locked = False
56
+
57
+ def get_form_kwargs (self , index ):
58
+ """
59
+ Return additional keyword arguments for each individual formset form.
60
+
61
+ index will be None if the form being constructed is a new empty
62
+ form.
63
+ """
64
+ kwargs = self .form_kwargs .copy ()
65
+ if not self .initial :
66
+ return kwargs
67
+ kwargs ["initial" ] = self .initial [index ]
68
+ kwargs ["initial" ]["total_points_available" ] = self .total_points_available
69
+ return kwargs
70
+
71
+ @contextmanager
72
+ def lock (self ):
73
+ with transaction .atomic ():
74
+ # lock these rows to prevent race conditions
75
+ list (self .user .reward_point_grantings .select_for_update ())
76
+ list (self .user .reward_point_redemptions .select_for_update ())
77
+
78
+ self .locked = True
79
+
80
+ yield
81
+
82
+ self .locked = False
52
83
53
84
def clean (self ):
85
+ assert self .locked
86
+
54
87
if any (self .errors ):
55
88
return
56
89
@@ -64,9 +97,7 @@ def clean(self):
64
97
raise ValidationError (_ ("You don't have enough reward points." ))
65
98
66
99
def save (self ) -> list [RewardPointRedemption ]:
67
- # lock these rows to prevent race conditions
68
- list (self .user .reward_point_grantings .select_for_update ())
69
- list (self .user .reward_point_redemptions .select_for_update ())
100
+ assert self .locked
70
101
71
102
created = []
72
103
for form in self .forms :
0 commit comments