diff --git a/corehq/apps/accounting/models.py b/corehq/apps/accounting/models.py index c93739521c22..b3b22e20fda6 100644 --- a/corehq/apps/accounting/models.py +++ b/corehq/apps/accounting/models.py @@ -529,7 +529,7 @@ def autopay_card(self): def get_domains(self): return list(Subscription.visible_objects.filter(account_id=self.id, is_active=True).values_list( - 'subscriber__domain', flat=True)) + 'subscriber__domain', flat=True).order_by('subscriber__domain')) def has_enterprise_admin(self, email): lower_emails = [e.lower() for e in self.enterprise_admin_emails] diff --git a/corehq/apps/enterprise/api/keyset_paginator.py b/corehq/apps/enterprise/api/keyset_paginator.py new file mode 100644 index 000000000000..d6f3f847280f --- /dev/null +++ b/corehq/apps/enterprise/api/keyset_paginator.py @@ -0,0 +1,86 @@ +from itertools import islice +from django.http.request import QueryDict +from urllib.parse import urlencode +from tastypie.paginator import Paginator + + +class KeysetPaginator(Paginator): + def __init__(self, request_data, objects, + resource_uri=None, limit=None, max_limit=1000, collection_name='objects'): + 'objects is expected to be an iterator' + super().__init__( + request_data, + objects, + resource_uri=resource_uri, + limit=limit, + max_limit=max_limit, + collection_name=collection_name + ) + + def get_offset(self): + raise NotImplementedError() + + def get_slice(self, limit, offset): + raise NotImplementedError() + + def get_count(self): + raise NotImplementedError() + + def get_previous(self, limit, offset): + raise NotImplementedError() + + def get_next(self, limit, **next_params): + return self._generate_uri(limit, **next_params) + + def _generate_uri(self, limit, **next_params): + if self.resource_uri is None: + return None + + if isinstance(self.request_data, QueryDict): + # Because QueryDict allows multiple values for the same key, we need to remove existing values + # prior to updating + request_params = self.request_data.copy() + if 'limit' in request_params: + del request_params['limit'] + for key in next_params: + if key in request_params: + del request_params[key] + + request_params.update({'limit': str(limit), **next_params}) + encoded_params = request_params.urlencode() + else: + request_params = {} + for k, v in self.request_data.items(): + if isinstance(v, str): + request_params[k] = v.encode('utf-8') + else: + request_params[k] = v + + request_params.update({'limit': limit, **next_params}) + encoded_params = urlencode(request_params) + + return '%s?%s' % ( + self.resource_uri, + encoded_params + ) + + def page(self): + """ + Generates all pertinent data about the requested page. + """ + limit = self.get_limit() + + objects = list(islice(self.objects, limit if limit else None)) + meta = { + 'limit': limit, + } + + if limit: + next_params = self.objects.get_next_query_params() + if next_params: + meta['next'] = self.get_next(limit, **next_params) + + return { + self.collection_name: objects, + 'meta': meta, + } diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index 09d5b85d120d..a23b58674c86 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -19,9 +19,9 @@ from corehq.apps.api.resources import HqBaseResource from corehq.apps.api.resources.auth import ODataAuthentication from corehq.apps.api.resources.meta import get_hq_throttle -from corehq.apps.enterprise.enterprise import ( - EnterpriseReport, -) +from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator +from corehq.apps.enterprise.enterprise import EnterpriseReport +from corehq.apps.enterprise.iterators import get_enterprise_form_iterator from corehq.apps.enterprise.tasks import generate_enterprise_report, ReportTaskProgress @@ -351,6 +351,7 @@ def get_primary_keys(self): class FormSubmissionResource(ODataEnterpriseReportResource): class Meta(ODataEnterpriseReportResource.Meta): + paginator_class = KeysetPaginator limit = 10000 max_limit = 20000 @@ -363,26 +364,25 @@ class Meta(ODataEnterpriseReportResource.Meta): REPORT_SLUG = EnterpriseReport.FORM_SUBMISSIONS - def get_report_task(self, request): - enddate = datetime.strptime(request.GET['enddate'], '%Y-%m-%d') if 'enddate' in request.GET else None - startdate = datetime.strptime(request.GET['startdate'], '%Y-%m-%d') if 'startdate' in request.GET else None + def get_object_list(self, request): + # TODO: logic to handle when the start/end date are null or last 30 days + start_date = request.GET.get('start_date', None) + end_date = request.GET.get('end_date', None) + last_domain = request.GET.get('last_domain', None) + last_time = request.GET.get('last_time', None) + last_id = request.GET.get('last_id', None) + account = BillingAccount.get_account_by_domain(request.domain) - return generate_enterprise_report.s( - self.REPORT_SLUG, - account.id, - request.couch_user.username, - start_date=startdate, - end_date=enddate, - include_form_id=True, - ) + + return get_enterprise_form_iterator(account, start_date, end_date, last_domain, last_time, last_id) def dehydrate(self, bundle): - bundle.data['form_id'] = bundle.obj[0] - bundle.data['form_name'] = bundle.obj[1] - bundle.data['submitted'] = self.convert_datetime(bundle.obj[2]) - bundle.data['app_name'] = bundle.obj[3] - bundle.data['mobile_user'] = bundle.obj[4] - bundle.data['domain'] = bundle.obj[6] + bundle.data['form_id'] = bundle.obj['form']['meta']['instanceID'] + bundle.data['form_name'] = bundle.obj['@name'] + bundle.data['submitted'] = self.convert_datetime(bundle.obj['received_on']) + bundle.data['app_name'] = 'App Name' # TODO bundle.obj[3] + bundle.data['mobile_user'] = bundle.obj['form']['meta']['username'] + bundle.data['domain'] = bundle.obj['domain'] return bundle diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 1fa96c54c132..1378234009f0 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,3 +1,9 @@ +from datetime import datetime +from corehq.apps.es import filters +from corehq.apps.es.forms import FormES +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper + + def raise_after_max_elements(it, max_elements, exception=None): for total_yielded, ele in enumerate(it): if total_yielded >= max_elements: @@ -5,3 +11,81 @@ def raise_after_max_elements(it, max_elements, exception=None): raise exception yield ele + + +def get_enterprise_form_iterator(account, start_date, end_date, last_domain=None, last_time=None, last_id=None): + domains = account.get_domains() + + it = multi_domain_form_generator(domains, start_date, end_date, last_domain, last_time, last_id) + return ResumableIteratorWrapper(it, lambda ele: { + 'domain': ele['domain'], + 'received_on': ele['received_on'], + 'id': ele['form']['meta']['instanceID'] + }) + + +def multi_domain_form_generator(domains, start_date, end_date, last_domain=None, last_time=None, last_id=None): + domain_index = domains.index(last_domain) if last_domain else 0 + + def _get_domain_iterator(last_time=None, last_id=None): + if domain_index >= len(domains): + return None + domain = domains[domain_index] + return domain_form_generator(domain, start_date, end_date, last_time, last_id) + + current_iterator = _get_domain_iterator(last_time, last_id) + + while current_iterator: + yield from current_iterator + domain_index += 1 + if domain_index >= len(domains): + break + current_iterator = _get_domain_iterator() + + +def domain_form_generator(domain, start_date, end_date, last_time=None, last_id=None): + if not last_time: + last_time = datetime.now() + + while True: + query = create_domain_query(domain, start_date, end_date, last_time, last_id) + results = query.run() + for form in results.hits: + last_form_fetched = form + yield last_form_fetched + + if len(results.hits) >= results.total: + break + else: + last_time = last_form_fetched['received_on'] + last_id = last_form_fetched['_id'] + + +def create_domain_query(domain, start_date, end_date, last_time, last_id): + CURSOR_SIZE = 5 + + query = ( + FormES() + .domain(domain) + .user_type('mobile') + .submitted(gte=start_date, lte=end_date) + .size(CURSOR_SIZE) + ) + + query.es_query['sort'] = [ + {'received_on': {'order': 'desc'}}, + {'form.meta.instanceID': 'asc'} + ] + + if last_id: + query = query.filter(filters.OR( + filters.AND( + filters.term('received_on', last_time), + filters.range_filter('form.meta.instanceID', gt=last_id) + ), + filters.range_filter('received_on', lt=last_time) + )) + else: + query = query.submitted(lte=last_time) + + return query diff --git a/corehq/apps/enterprise/resumable_iterator_wrapper.py b/corehq/apps/enterprise/resumable_iterator_wrapper.py new file mode 100644 index 000000000000..d9756781ef2a --- /dev/null +++ b/corehq/apps/enterprise/resumable_iterator_wrapper.py @@ -0,0 +1,31 @@ +class ResumableIteratorWrapper: + def __init__(self, sequence, get_element_properties_fn=None): + self.it = iter(sequence) + self.prev_element = None + self.iteration_started = False + self.is_complete = False + + self.get_element_properties_fn = get_element_properties_fn + if not self.get_element_properties_fn: + self.get_element_properties_fn = lambda ele: {'value': ele} + + def __iter__(self): + return self + + def __next__(self): + self.iteration_started = True + try: + self.prev_element = next(self.it) + except StopIteration: + self.is_complete = True + raise + + return self.prev_element + + def get_next_query_params(self): + if self.is_complete: + return None + if not self.iteration_started: + return {} + + return self.get_element_properties_fn(self.prev_element) diff --git a/corehq/apps/enterprise/tests/api/__init__.py b/corehq/apps/enterprise/tests/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py b/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py new file mode 100644 index 000000000000..deceeef6abb9 --- /dev/null +++ b/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py @@ -0,0 +1,66 @@ +from django.test import SimpleTestCase +from django.http import QueryDict +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator + + +class KeysetPaginatorTests(SimpleTestCase): + def test_page_fetches_all_results_below_limit(self): + objects = ResumableIteratorWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects, limit=10) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2, 3, 4]) + self.assertEqual(page['meta'], {'limit': 10}) + + def test_page_includes_next_information_when_more_results_are_available(self): + objects = ResumableIteratorWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(QueryDict(), objects, resource_uri='http://test.com/', limit=3) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2]) + self.assertEqual(page['meta'], {'limit': 3, 'next': 'http://test.com/?limit=3&next=2'}) + + def test_does_not_include_duplicate_limits(self): + request_data = QueryDict(mutable=True) + request_data['limit'] = 3 + objects = ResumableIteratorWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&next=2') + + def test_supports_dict_request_data(self): + request_data = { + 'limit': 3, + 'some_param': 'yes' + } + objects = ResumableIteratorWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&some_param=yes&next=2') + + def test_get_offset_not_implemented(self): + objects = ResumableIteratorWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_offset() + + def test_get_slice_not_implemented(self): + objects = ResumableIteratorWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_slice(limit=10, offset=20) + + def test_get_count_not_implemented(self): + objects = ResumableIteratorWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_count() + + def test_get_previous_not_implemented(self): + objects = ResumableIteratorWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_previous(limit=10, offset=20) diff --git a/corehq/apps/enterprise/tests/test_iterators.py b/corehq/apps/enterprise/tests/test_iterators.py index 3ffc80c12a7c..7119be628eb4 100644 --- a/corehq/apps/enterprise/tests/test_iterators.py +++ b/corehq/apps/enterprise/tests/test_iterators.py @@ -1,6 +1,15 @@ -from django.test import SimpleTestCase +from django.test import SimpleTestCase, TestCase +from datetime import datetime -from corehq.apps.enterprise.iterators import raise_after_max_elements +from corehq.apps.es.forms import form_adapter +from corehq.apps.es.tests.utils import es_test +from corehq.apps.users.models import CommCareUser +from corehq.form_processor.tests.utils import create_form_for_test +from corehq.apps.enterprise.iterators import ( + raise_after_max_elements, + domain_form_generator, + multi_domain_form_generator, +) class TestRaiseAfterMaxElements(SimpleTestCase): @@ -17,3 +26,121 @@ def test_iterating_beyond_max_items_will_raise_provided_exception(self): def test_can_iterate_through_all_elements_with_no_exception(self): it = raise_after_max_elements([1, 2, 3], 3) self.assertEqual(list(it), [1, 2, 3]) + + +@es_test(requires=[form_adapter]) +class TestMultiDomainFormGenerator(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + def test_iterates_through_multiple_domains(self): + form1 = self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('domain2', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form3 = self._create_form('domain3', form_id='3', received_on=datetime(year=2024, month=7, day=3)) + form_adapter.bulk_index([form1, form2, form3], refresh=True) + + it = multi_domain_form_generator( + ['domain1', 'domain2', 'domain3'], + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2', '3']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on + ) + + +@es_test(requires=[form_adapter]) +class TestDomainFormGenerator(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + def test_iterates_through_all_forms_in_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=3)) + form3 = self._create_form('test-domain', form_id='3', received_on=datetime(year=2024, month=7, day=4)) + form_adapter.bulk_index([form1, form2, form3], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['3', '2', '1']) + + def test_includes_inclusive_boundaries(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['2', '1']) + + def test_ignores_form_in_another_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('not-test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1']) + + def test_sorts_by_date_then_id(self): + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form_adapter.bulk_index([form2, form1], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on + ) diff --git a/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py b/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py new file mode 100644 index 000000000000..d8c6afc60f8d --- /dev/null +++ b/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py @@ -0,0 +1,49 @@ +from django.test import SimpleTestCase +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper + + +class ResumableIteratorWrapperTests(SimpleTestCase): + def test_can_iterate_through_a_wrapped_iterator(self): + initial_it = iter(range(5)) + it = ResumableIteratorWrapper(initial_it) + self.assertEqual(list(it), [0, 1, 2, 3, 4]) + + def test_can_iterate_through_a_sequence(self): + sequence = [0, 1, 2, 3, 4] + it = ResumableIteratorWrapper(sequence) + self.assertEqual(list(it), [0, 1, 2, 3, 4]) + + def test_get_next_query_params_returns_empty_object_prior_to_iteration(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + it = ResumableIteratorWrapper(seq, ) + self.assertEqual(it.get_next_query_params(), {}) + + def test_default_get_next_query_params_returns_identity_object(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + it = ResumableIteratorWrapper(seq, ) + next(it) + self.assertEqual(it.get_next_query_params(), {'value': {'key': 'one', 'val': 'val1'}}) + + def test_custom_get_next_query_params_fn(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + + def custom_element_properties_fn(ele): + return (ele['key'], ele['val']) + + it = ResumableIteratorWrapper(seq, custom_element_properties_fn) + next(it) + self.assertEqual(it.get_next_query_params(), ('one', 'val1')) + + def test_get_next_query_params_returns_none_when_fully_iterated(self): + it = ResumableIteratorWrapper(range(5)) + list(it) + self.assertIsNone(it.get_next_query_params())