From b9d3b6dd6b0df13d3ebd3b8e75fbd952e0c59d9a Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 23 Oct 2024 11:00:01 -0400 Subject: [PATCH 01/23] Apply sorting to enterprise domain list --- corehq/apps/accounting/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/corehq/apps/accounting/models.py b/corehq/apps/accounting/models.py index c93739521c22..b3b22e20fda6 100644 --- a/corehq/apps/accounting/models.py +++ b/corehq/apps/accounting/models.py @@ -529,7 +529,7 @@ def autopay_card(self): def get_domains(self): return list(Subscription.visible_objects.filter(account_id=self.id, is_active=True).values_list( - 'subscriber__domain', flat=True)) + 'subscriber__domain', flat=True).order_by('subscriber__domain')) def has_enterprise_admin(self, email): lower_emails = [e.lower() for e in self.enterprise_admin_emails] From 463d97b8624da7a01c932177baeefa951af97c12 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 23 Oct 2024 11:06:15 -0400 Subject: [PATCH 02/23] Add resumable iterator wrapper --- .../enterprise/resumable_iterator_wrapper.py | 48 ++++++++++++++ .../tests/test_resumable_iterator_wrapper.py | 66 +++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 corehq/apps/enterprise/resumable_iterator_wrapper.py create mode 100644 corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py diff --git a/corehq/apps/enterprise/resumable_iterator_wrapper.py b/corehq/apps/enterprise/resumable_iterator_wrapper.py new file mode 100644 index 000000000000..04ef51d28c08 --- /dev/null +++ b/corehq/apps/enterprise/resumable_iterator_wrapper.py @@ -0,0 +1,48 @@ +from itertools import islice + + +class ResumableIteratorWrapper: + def __init__(self, sequence_factory_fn, get_element_properties_fn=None, limit=None): + self.limit = limit + + # if a limit exists, increase it by 1 to allow us to check whether additional items remain at the end + padded_limit = limit + 1 if limit else None + self.original_it = 
iter(sequence_factory_fn(padded_limit)) + self.it = islice(self.original_it, self.limit) + self.prev_element = None + self.iteration_started = False + self.is_complete = False + + self.get_element_properties_fn = get_element_properties_fn + if not self.get_element_properties_fn: + self.get_element_properties_fn = lambda ele: {'value': ele} + + def __iter__(self): + return self + + def __next__(self): + self.iteration_started = True + + try: + self.prev_element = next(self.it) + except StopIteration: + if self.limit and not self.is_complete: + # the end of the limited sequence was reached, check if items beyond the limit remain + try: + next(self.original_it) + except StopIteration: + # the iteration is fully complete -- no additional items can be fetched + self.is_complete = True + else: + self.is_complete = True + raise + + return self.prev_element + + def get_next_query_params(self): + if self.is_complete: + return None + if not self.iteration_started: + return {} + + return self.get_element_properties_fn(self.prev_element) diff --git a/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py b/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py new file mode 100644 index 000000000000..d7b0c3a000fd --- /dev/null +++ b/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py @@ -0,0 +1,66 @@ +from django.test import SimpleTestCase +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper + + +class ResumableIteratorWrapperTests(SimpleTestCase): + def test_can_iterate_through_a_wrapped_iterator(self): + initial_it = iter(range(5)) + it = ResumableIteratorWrapper(lambda _: initial_it) + self.assertEqual(list(it), [0, 1, 2, 3, 4]) + + def test_can_iterate_through_a_sequence(self): + sequence = [0, 1, 2, 3, 4] + it = ResumableIteratorWrapper(lambda _: sequence) + self.assertEqual(list(it), [0, 1, 2, 3, 4]) + + def test_can_limit_a_sequence(self): + sequence = [0, 1, 2, 3, 4] + it = ResumableIteratorWrapper(lambda 
_: sequence, limit=4) + self.assertEqual(list(it), [0, 1, 2, 3]) + + def test_when_limit_is_less_than_sequence_length_is_incomplete(self): + sequence = [0, 1, 2, 3, 4] + it = ResumableIteratorWrapper(lambda _: sequence, limit=4) + list(it) + self.assertFalse(it.is_complete) + + def test_when_limit_matches_sequence_size_iterator_is_complete(self): + sequence = [0, 1, 2, 3, 4] + it = ResumableIteratorWrapper(lambda _: sequence, limit=5) + list(it) + self.assertTrue(it.is_complete) + + def test_get_next_query_params_returns_empty_object_prior_to_iteration(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + it = ResumableIteratorWrapper(lambda _: seq) + self.assertEqual(it.get_next_query_params(), {}) + + def test_default_get_next_query_params_returns_identity_object(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + it = ResumableIteratorWrapper(lambda _: seq, ) + next(it) + self.assertEqual(it.get_next_query_params(), {'value': {'key': 'one', 'val': 'val1'}}) + + def test_custom_get_next_query_params_fn(self): + seq = [ + {'key': 'one', 'val': 'val1'}, + {'key': 'two', 'val': 'val2'}, + ] + + def custom_element_properties_fn(ele): + return (ele['key'], ele['val']) + + it = ResumableIteratorWrapper(lambda _: seq, custom_element_properties_fn) + next(it) + self.assertEqual(it.get_next_query_params(), ('one', 'val1')) + + def test_get_next_query_params_returns_none_when_fully_iterated(self): + it = ResumableIteratorWrapper(lambda _: range(5)) + list(it) + self.assertIsNone(it.get_next_query_params()) From 3eccde7f45da97b517944e4cd385501c4f77c402 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 23 Oct 2024 11:10:56 -0400 Subject: [PATCH 03/23] Add KeysetPaginator --- .../apps/enterprise/api/keyset_paginator.py | 94 +++++++++++++++++++ corehq/apps/enterprise/tests/api/__init__.py | 0 .../tests/api/keyset_paginator_tests.py | 75 +++++++++++++++ 3 files changed, 169 insertions(+) create mode 
100644 corehq/apps/enterprise/api/keyset_paginator.py create mode 100644 corehq/apps/enterprise/tests/api/__init__.py create mode 100644 corehq/apps/enterprise/tests/api/keyset_paginator_tests.py diff --git a/corehq/apps/enterprise/api/keyset_paginator.py b/corehq/apps/enterprise/api/keyset_paginator.py new file mode 100644 index 000000000000..17896e3c3a4f --- /dev/null +++ b/corehq/apps/enterprise/api/keyset_paginator.py @@ -0,0 +1,94 @@ +from django.http.request import QueryDict +from urllib.parse import urlencode +from tastypie.paginator import Paginator + + +class KeysetPaginator(Paginator): + ''' + An alternate paginator meant to support paginating by keyset rather than by index/offset. + `objects` is expected to represent a query object that exposes an `.execute(limit)` + method that returns an iterable. + The above returned iterable must expose a `.get_next_query_params()` method that will return + parameters to allow the user to fetch the next page of data. + Because keyset pagination does not efficiently handle slicing or offset operations, + these methods have been intentionally disabled + ''' + def __init__(self, request_data, objects, + resource_uri=None, limit=None, max_limit=1000, collection_name='objects'): + super().__init__( + request_data, + objects, + resource_uri=resource_uri, + limit=limit, + max_limit=max_limit, + collection_name=collection_name + ) + + def get_offset(self): + raise NotImplementedError() + + def get_slice(self, limit, offset): + raise NotImplementedError() + + def get_count(self): + raise NotImplementedError() + + def get_previous(self, limit, offset): + raise NotImplementedError() + + def get_next(self, limit, **next_params): + return self._generate_uri(limit, **next_params) + + def _generate_uri(self, limit, **next_params): + if self.resource_uri is None: + return None + + if isinstance(self.request_data, QueryDict): + # Because QueryDict allows multiple values for the same key, we need to remove existing values + # prior to 
updating + request_params = self.request_data.copy() + if 'limit' in request_params: + del request_params['limit'] + for key in next_params: + if key in request_params: + del request_params[key] + + request_params.update({'limit': str(limit), **next_params}) + encoded_params = request_params.urlencode() + else: + request_params = {} + for k, v in self.request_data.items(): + if isinstance(v, str): + request_params[k] = v.encode('utf-8') + else: + request_params[k] = v + + request_params.update({'limit': limit, **next_params}) + encoded_params = urlencode(request_params) + + return '%s?%s' % ( + self.resource_uri, + encoded_params + ) + + def page(self): + """ + Generates all pertinent data about the requested page. + """ + limit = self.get_limit() + it = self.objects.execute(limit=limit) + objects = list(it) + + meta = { + 'limit': limit, + } + + if limit: + next_params = it.get_next_query_params() + if next_params: + meta['next'] = self.get_next(limit, **next_params) + + return { + self.collection_name: objects, + 'meta': meta, + } diff --git a/corehq/apps/enterprise/tests/api/__init__.py b/corehq/apps/enterprise/tests/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py b/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py new file mode 100644 index 000000000000..56e69b6f0415 --- /dev/null +++ b/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py @@ -0,0 +1,75 @@ +from django.test import SimpleTestCase +from django.http import QueryDict +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator + + +class SequenceWrapper: + def __init__(self, seq, get_next_fn=None): + self.seq = seq + self.get_next_fn = get_next_fn + + def execute(self, limit=None): + return ResumableIteratorWrapper(lambda _: self.seq, self.get_next_fn, limit=limit) + + +class 
KeysetPaginatorTests(SimpleTestCase): + def test_page_fetches_all_results_below_limit(self): + objects = SequenceWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects, limit=10) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2, 3, 4]) + self.assertEqual(page['meta'], {'limit': 10}) + + def test_page_includes_next_information_when_more_results_are_available(self): + objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(QueryDict(), objects, resource_uri='http://test.com/', limit=3) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2]) + self.assertEqual(page['meta'], {'limit': 3, 'next': 'http://test.com/?limit=3&next=2'}) + + def test_does_not_include_duplicate_limits(self): + request_data = QueryDict(mutable=True) + request_data['limit'] = 3 + objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&next=2') + + def test_supports_dict_request_data(self): + request_data = { + 'limit': 3, + 'some_param': 'yes' + } + objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&some_param=yes&next=2') + + def test_get_offset_not_implemented(self): + objects = SequenceWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_offset() + + def test_get_slice_not_implemented(self): + objects = SequenceWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_slice(limit=10, offset=20) + + def test_get_count_not_implemented(self): + objects = SequenceWrapper(range(5)) + 
paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_count() + + def test_get_previous_not_implemented(self): + objects = SequenceWrapper(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_previous(limit=10, offset=20) From 51543600455eb3b89c57e84682b14091f54902b9 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 23 Oct 2024 11:18:47 -0400 Subject: [PATCH 04/23] Added enterprise form iterators --- corehq/apps/enterprise/iterators.py | 175 +++++++++++++++++ .../apps/enterprise/tests/test_iterators.py | 176 +++++++++++++++++- 2 files changed, 349 insertions(+), 2 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 1fa96c54c132..667d087f2d08 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,3 +1,12 @@ +from datetime import datetime, timedelta +from django.utils.translation import gettext as _ +from corehq.apps.es import filters +from corehq.apps.es.forms import FormES +from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.enterprise.exceptions import TooMuchRequestedDataError +from corehq.apps.app_manager.dbaccessors import get_brief_apps_in_domain + + def raise_after_max_elements(it, max_elements, exception=None): for total_yielded, ele in enumerate(it): if total_yielded >= max_elements: @@ -5,3 +14,169 @@ def raise_after_max_elements(it, max_elements, exception=None): raise exception yield ele + + +class IterableEnterpriseFormQuery: + ''' + A class representing a query that returns its results as an iterator + The intended use case is to support queries that cross pagination boundaries + ''' + def __init__(self, account, start_date, end_date, last_domain, last_time, last_id): + MAX_DATE_RANGE_DAYS = 100 + (self.start_date, self.end_date) = resolve_start_and_end_date(start_date, 
end_date, MAX_DATE_RANGE_DAYS) + self.account = account + self.last_domain = last_domain + self.last_time = last_time + self.last_id = last_id + + def execute(self, limit=None): + domains = self.account.get_domains() + + def create_multi_domain_form_generator(limit): + it = multi_domain_form_generator( + domains, + self.start_date, + self.end_date, + self.last_domain, + self.last_time, + self.last_id, + limit=limit + ) + xform_converter = RawFormConverter() + return (xform_converter.convert(form) for form in it) + + return ResumableIteratorWrapper(create_multi_domain_form_generator, lambda form: { + 'domain': form['domain'], + 'received_on': form['submitted'], + 'id': form['form_id'] + }, limit=limit) + + +def resolve_start_and_end_date(start_date, end_date, maximum_date_range): + ''' + Provide start and end date values if not supplied. + ''' + if not end_date: + end_date = datetime.utcnow() + + if not start_date: + start_date = end_date - timedelta(days=30) + + if end_date - start_date > timedelta(days=maximum_date_range): + raise TooMuchRequestedDataError( + _('Date ranges with more than {} days are not supported').format(maximum_date_range) + ) + + return start_date, end_date + + +class RawFormConverter: + def __init__(self): + self.app_lookup = AppIdToNameResolver() + + def convert(self, form): + domain = form['domain'] + submitted_date = datetime.strptime(form['received_on'][:19], '%Y-%m-%dT%H:%M:%S') + + return { + 'form_id': form['form']['meta']['instanceID'], + 'form_name': form['form']['@name'] or _('Unnamed'), + 'submitted': submitted_date, + 'app_name': self.app_lookup.resolve_app_id_to_name(domain, form['app_id']) or _('App not found'), + 'username': form['form']['meta']['username'], + 'domain': domain + } + + +class AppIdToNameResolver: + def __init__(self): + self.domain_lookup_tables = {} + + def resolve_app_id_to_name(self, domain, app_id): + if 'domain' not in self.domain_lookup_tables: + domain_apps = get_brief_apps_in_domain(domain) + 
self.domain_lookup_tables[domain] = {a.id: a.name for a in domain_apps} + + return self.domain_lookup_tables[domain].get(app_id, None) + + +def multi_domain_form_generator( + domains, start_date, end_date, last_domain=None, last_time=None, last_id=None, limit=None): + domain_index = domains.index(last_domain) if last_domain else 0 + + remaining = limit + + def _get_domain_iterator(last_time=None, last_id=None): + if domain_index >= len(domains): + return None + domain = domains[domain_index] + return domain_form_generator(domain, start_date, end_date, last_time, last_id, limit=remaining) + + current_iterator = _get_domain_iterator(last_time, last_id) + + while current_iterator: + for form in current_iterator: + yield form + if remaining: + remaining -= 1 + if remaining == 0: + return + domain_index += 1 + if domain_index >= len(domains): + return + current_iterator = _get_domain_iterator() + + +def domain_form_generator(domain, start_date, end_date, last_time=None, last_id=None, limit=None): + if not last_time: + last_time = datetime.now() + + remaining = limit + + while True: + query = create_domain_query(domain, start_date, end_date, last_time, last_id, limit=remaining) + results = query.run() + for form in results.hits: + last_form_fetched = form + yield last_form_fetched + + num_fetched = len(results.hits) + + if num_fetched >= results.total or (remaining and num_fetched >= remaining): + break + else: + if remaining: + remaining -= num_fetched + assert remaining > 0 + last_time = last_form_fetched['received_on'] + last_id = last_form_fetched['_id'] + + +def create_domain_query(domain, start_date, end_date, last_time, last_id, limit=None): + query = ( + FormES() + .domain(domain) + .user_type('mobile') + .submitted(gte=start_date, lte=end_date) + ) + + if limit: + query = query.size(limit) + + query.es_query['sort'] = [ + {'received_on': {'order': 'desc'}}, + {'form.meta.instanceID': 'asc'} + ] + + if last_id: + query = query.filter(filters.OR( + filters.AND( + 
filters.term('received_on', last_time), + filters.range_filter('form.meta.instanceID', gt=last_id) + ), + filters.range_filter('received_on', lt=last_time) + )) + else: + query = query.submitted(lte=last_time) + + return query diff --git a/corehq/apps/enterprise/tests/test_iterators.py b/corehq/apps/enterprise/tests/test_iterators.py index 3ffc80c12a7c..3343bc14089f 100644 --- a/corehq/apps/enterprise/tests/test_iterators.py +++ b/corehq/apps/enterprise/tests/test_iterators.py @@ -1,6 +1,15 @@ -from django.test import SimpleTestCase +from django.test import SimpleTestCase, TestCase +from datetime import datetime -from corehq.apps.enterprise.iterators import raise_after_max_elements +from corehq.apps.es.forms import form_adapter +from corehq.apps.es.tests.utils import es_test +from corehq.apps.users.models import CommCareUser +from corehq.form_processor.tests.utils import create_form_for_test +from corehq.apps.enterprise.iterators import ( + raise_after_max_elements, + domain_form_generator, + multi_domain_form_generator, +) class TestRaiseAfterMaxElements(SimpleTestCase): @@ -17,3 +26,166 @@ def test_iterating_beyond_max_items_will_raise_provided_exception(self): def test_can_iterate_through_all_elements_with_no_exception(self): it = raise_after_max_elements([1, 2, 3], 3) self.assertEqual(list(it), [1, 2, 3]) + + +@es_test(requires=[form_adapter]) +class TestMultiDomainFormGenerator(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + def test_iterates_through_multiple_domains(self): + forms = [ + self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)), + self._create_form('domain2', form_id='2', received_on=datetime(year=2024, month=7, day=2)), + self._create_form('domain3', form_id='3', received_on=datetime(year=2024, month=7, day=3)), + ] + form_adapter.bulk_index(forms, refresh=True) + + it = 
multi_domain_form_generator( + ['domain1', 'domain2', 'domain3'], + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2', '3']) + + def test_respects_limit_across_multiple_domains(self): + forms = [ + self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)), + self._create_form('domain1', form_id='2', received_on=datetime(year=2024, month=7, day=2)), + self._create_form('domain2', form_id='3', received_on=datetime(year=2024, month=7, day=3)), + self._create_form('domain2', form_id='4', received_on=datetime(year=2024, month=7, day=4)), + ] + form_adapter.bulk_index(forms, refresh=True) + + it = multi_domain_form_generator( + ['domain1', 'domain2'], + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + limit=3 + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['2', '1', '4']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on + ) + + +@es_test(requires=[form_adapter]) +class TestDomainFormGenerator(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + def test_iterates_through_all_forms_in_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=3)) + form3 = self._create_form('test-domain', form_id='3', received_on=datetime(year=2024, month=7, day=4)) + form_adapter.bulk_index([form1, 
form2, form3], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['3', '2', '1']) + + def test_handles_empty_domain(self): + it = domain_form_generator( + 'empty-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + self.assertEqual(list(it), []) + + def test_includes_inclusive_boundaries(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['2', '1']) + + def test_ignores_form_in_another_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('not-test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1']) + + def test_sorts_by_date_then_id(self): + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form_adapter.bulk_index([form2, form1], refresh=True) + + it = domain_form_generator( + 'test-domain', + 
start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2']) + + def test_does_not_return_forms_beyond_limit(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = domain_form_generator( + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2), + limit=1 + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on + ) From 4112456c31c1f6e0c2b245e10907d23577a72910 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 23 Oct 2024 11:20:27 -0400 Subject: [PATCH 05/23] Rewire FormSubmissionResource to use iterators --- corehq/apps/enterprise/api/resources.py | 56 +++++++++++++++---------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index 09d5b85d120d..a67f16b3574a 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -19,9 +19,9 @@ from corehq.apps.api.resources import HqBaseResource from corehq.apps.api.resources.auth import ODataAuthentication from corehq.apps.api.resources.meta import get_hq_throttle -from corehq.apps.enterprise.enterprise import ( - EnterpriseReport, -) +from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator +from 
corehq.apps.enterprise.enterprise import EnterpriseReport +from corehq.apps.enterprise.iterators import IterableEnterpriseFormQuery from corehq.apps.enterprise.tasks import generate_enterprise_report, ReportTaskProgress @@ -60,7 +60,8 @@ def alter_list_data_to_serialize(self, request, data): result['@odata.context'] = request.build_absolute_uri(path) meta = result['meta'] - result['@odata.count'] = meta['total_count'] + if 'total_count' in meta: + result['@odata.count'] = meta['total_count'] if 'next' in meta and meta['next']: result['@odata.nextLink'] = request.build_absolute_uri(meta['next']) @@ -139,7 +140,10 @@ def convert_datetime(cls, datetime_string): if not datetime_string: return None - time = datetime.strptime(datetime_string, EnterpriseReport.DATE_ROW_FORMAT) + if isinstance(datetime_string, str): + time = datetime.strptime(datetime_string, EnterpriseReport.DATE_ROW_FORMAT) + else: + time = datetime_string time = time.astimezone(tz.gettz('UTC')) return time.isoformat() @@ -351,6 +355,7 @@ def get_primary_keys(self): class FormSubmissionResource(ODataEnterpriseReportResource): class Meta(ODataEnterpriseReportResource.Meta): + paginator_class = KeysetPaginator limit = 10000 max_limit = 20000 @@ -363,26 +368,33 @@ class Meta(ODataEnterpriseReportResource.Meta): REPORT_SLUG = EnterpriseReport.FORM_SUBMISSIONS - def get_report_task(self, request): - enddate = datetime.strptime(request.GET['enddate'], '%Y-%m-%d') if 'enddate' in request.GET else None - startdate = datetime.strptime(request.GET['startdate'], '%Y-%m-%d') if 'startdate' in request.GET else None + def get_object_list(self, request): + start_date = request.GET.get('startdate', None) + if start_date: + start_date = datetime.fromisoformat(start_date) + + end_date = request.GET.get('enddate', None) + if end_date: + end_date = datetime.fromisoformat(end_date) + + last_time = request.GET.get('received_on', None) + if last_time: + last_time = datetime.fromisoformat(last_time) + + last_domain = 
request.GET.get('domain', None) + last_id = request.GET.get('id', None) + account = BillingAccount.get_account_by_domain(request.domain) - return generate_enterprise_report.s( - self.REPORT_SLUG, - account.id, - request.couch_user.username, - start_date=startdate, - end_date=enddate, - include_form_id=True, - ) + + return IterableEnterpriseFormQuery(account, start_date, end_date, last_domain, last_time, last_id) def dehydrate(self, bundle): - bundle.data['form_id'] = bundle.obj[0] - bundle.data['form_name'] = bundle.obj[1] - bundle.data['submitted'] = self.convert_datetime(bundle.obj[2]) - bundle.data['app_name'] = bundle.obj[3] - bundle.data['mobile_user'] = bundle.obj[4] - bundle.data['domain'] = bundle.obj[6] + bundle.data['form_id'] = bundle.obj['form_id'] + bundle.data['form_name'] = bundle.obj['form_name'] + bundle.data['submitted'] = self.convert_datetime(bundle.obj['submitted']) + bundle.data['app_name'] = bundle.obj['app_name'] + bundle.data['mobile_user'] = bundle.obj['username'] + bundle.data['domain'] = bundle.obj['domain'] return bundle From 399b0136ad1379abe1b92629032c3b7717109064 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Tue, 29 Oct 2024 12:14:41 -0400 Subject: [PATCH 06/23] Moved generic API classes into the API application --- corehq/apps/{enterprise => }/api/keyset_paginator.py | 8 ++++++++ .../{enterprise => api}/resumable_iterator_wrapper.py | 0 .../tests/api => api/tests}/keyset_paginator_tests.py | 4 ++-- .../tests/test_resumable_iterator_wrapper.py | 2 +- corehq/apps/enterprise/api/resources.py | 2 +- corehq/apps/enterprise/iterators.py | 2 +- corehq/apps/enterprise/tests/api/__init__.py | 0 7 files changed, 13 insertions(+), 5 deletions(-) rename corehq/apps/{enterprise => }/api/keyset_paginator.py (93%) rename corehq/apps/{enterprise => api}/resumable_iterator_wrapper.py (100%) rename corehq/apps/{enterprise/tests/api => api/tests}/keyset_paginator_tests.py (94%) rename corehq/apps/{enterprise => 
api}/tests/test_resumable_iterator_wrapper.py (96%) delete mode 100644 corehq/apps/enterprise/tests/api/__init__.py diff --git a/corehq/apps/enterprise/api/keyset_paginator.py b/corehq/apps/api/keyset_paginator.py similarity index 93% rename from corehq/apps/enterprise/api/keyset_paginator.py rename to corehq/apps/api/keyset_paginator.py index 17896e3c3a4f..b1f644d073ab 100644 --- a/corehq/apps/enterprise/api/keyset_paginator.py +++ b/corehq/apps/api/keyset_paginator.py @@ -92,3 +92,11 @@ def page(self): self.collection_name: objects, 'meta': meta, } + + +class PageableQueryInterface: + def execute(limit=None): + ''' + Should return an iterable that exposes a `.get_next_query_params()` method + ''' + raise NotImplementedError() diff --git a/corehq/apps/enterprise/resumable_iterator_wrapper.py b/corehq/apps/api/resumable_iterator_wrapper.py similarity index 100% rename from corehq/apps/enterprise/resumable_iterator_wrapper.py rename to corehq/apps/api/resumable_iterator_wrapper.py diff --git a/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py b/corehq/apps/api/tests/keyset_paginator_tests.py similarity index 94% rename from corehq/apps/enterprise/tests/api/keyset_paginator_tests.py rename to corehq/apps/api/tests/keyset_paginator_tests.py index 56e69b6f0415..025f39b13a00 100644 --- a/corehq/apps/enterprise/tests/api/keyset_paginator_tests.py +++ b/corehq/apps/api/tests/keyset_paginator_tests.py @@ -1,7 +1,7 @@ from django.test import SimpleTestCase from django.http import QueryDict -from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper -from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator +from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.api.keyset_paginator import KeysetPaginator class SequenceWrapper: diff --git a/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py b/corehq/apps/api/tests/test_resumable_iterator_wrapper.py similarity index 96% 
rename from corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py rename to corehq/apps/api/tests/test_resumable_iterator_wrapper.py index d7b0c3a000fd..93fd8eedae98 100644 --- a/corehq/apps/enterprise/tests/test_resumable_iterator_wrapper.py +++ b/corehq/apps/api/tests/test_resumable_iterator_wrapper.py @@ -1,5 +1,5 @@ from django.test import SimpleTestCase -from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper class ResumableIteratorWrapperTests(SimpleTestCase): diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index a67f16b3574a..cf4478a9cb63 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -19,7 +19,7 @@ from corehq.apps.api.resources import HqBaseResource from corehq.apps.api.resources.auth import ODataAuthentication from corehq.apps.api.resources.meta import get_hq_throttle -from corehq.apps.enterprise.api.keyset_paginator import KeysetPaginator +from corehq.apps.api.keyset_paginator import KeysetPaginator from corehq.apps.enterprise.enterprise import EnterpriseReport from corehq.apps.enterprise.iterators import IterableEnterpriseFormQuery diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 667d087f2d08..5a55c365cd85 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -2,7 +2,7 @@ from django.utils.translation import gettext as _ from corehq.apps.es import filters from corehq.apps.es.forms import FormES -from corehq.apps.enterprise.resumable_iterator_wrapper import ResumableIteratorWrapper +from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper from corehq.apps.enterprise.exceptions import TooMuchRequestedDataError from corehq.apps.app_manager.dbaccessors import get_brief_apps_in_domain diff --git 
a/corehq/apps/enterprise/tests/api/__init__.py b/corehq/apps/enterprise/tests/api/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 From 185a143bb715ffff37583aae46e9f86463948199 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 30 Oct 2024 10:57:53 -0400 Subject: [PATCH 07/23] Removed ResumableIteratorWrapper --- corehq/apps/api/keyset_paginator.py | 28 +++++--- corehq/apps/api/resumable_iterator_wrapper.py | 48 -------------- .../apps/api/tests/keyset_paginator_tests.py | 28 ++++---- .../tests/test_resumable_iterator_wrapper.py | 66 ------------------- corehq/apps/enterprise/api/resources.py | 10 +-- corehq/apps/enterprise/iterators.py | 52 +++++++++------ 6 files changed, 68 insertions(+), 164 deletions(-) delete mode 100644 corehq/apps/api/resumable_iterator_wrapper.py delete mode 100644 corehq/apps/api/tests/test_resumable_iterator_wrapper.py diff --git a/corehq/apps/api/keyset_paginator.py b/corehq/apps/api/keyset_paginator.py index b1f644d073ab..be3310686cf7 100644 --- a/corehq/apps/api/keyset_paginator.py +++ b/corehq/apps/api/keyset_paginator.py @@ -1,3 +1,4 @@ +from itertools import islice from django.http.request import QueryDict from urllib.parse import urlencode from tastypie.paginator import Paginator @@ -7,9 +8,8 @@ class KeysetPaginator(Paginator): ''' An alternate paginator meant to support paginating by keyset rather than by index/offset. `objects` is expected to represent a query object that exposes an `.execute(limit)` - method that returns an iterable. - The above returned iterable must expose a `.get_next_query_params()` method that will return - parameters to allow the user to fetch the next page of data. 
+ method that returns an iterable, and a `get_query_params(object)` method to retrieve the parameters + for the next query Because keyset pagination does not efficiently handle slicing or offset operations, these methods have been intentionally disabled ''' @@ -76,17 +76,25 @@ def page(self): Generates all pertinent data about the requested page. """ limit = self.get_limit() - it = self.objects.execute(limit=limit) - objects = list(it) + padded_limit = limit + 1 if limit else limit + # Fetch 1 more record than requested to allow us to determine if further queries will be needed + it = iter(self.objects.execute(limit=padded_limit)) + objects = list(islice(it, limit)) + + try: + next(it) + has_more = True + except StopIteration: + has_more = False meta = { 'limit': limit, } - if limit: - next_params = it.get_next_query_params() - if next_params: - meta['next'] = self.get_next(limit, **next_params) + if limit and has_more: + last_fetched = objects[-1] + next_page_params = self.objects.get_query_params(last_fetched) + meta['next'] = self.get_next(limit, **next_page_params) return { self.collection_name: objects, @@ -97,6 +105,6 @@ def page(self): class PageableQueryInterface: def execute(limit=None): ''' - Should return an iterable that exposes a `.get_next_query_params()` method + Should return an iterable that exposes a `.get_query_params()` method ''' raise NotImplementedError() diff --git a/corehq/apps/api/resumable_iterator_wrapper.py b/corehq/apps/api/resumable_iterator_wrapper.py deleted file mode 100644 index 04ef51d28c08..000000000000 --- a/corehq/apps/api/resumable_iterator_wrapper.py +++ /dev/null @@ -1,48 +0,0 @@ -from itertools import islice - - -class ResumableIteratorWrapper: - def __init__(self, sequence_factory_fn, get_element_properties_fn=None, limit=None): - self.limit = limit - - # if a limit exists, increase it by 1 to allow us to check whether additional items remain at the end - padded_limit = limit + 1 if limit else None - self.original_it = 
iter(sequence_factory_fn(padded_limit)) - self.it = islice(self.original_it, self.limit) - self.prev_element = None - self.iteration_started = False - self.is_complete = False - - self.get_element_properties_fn = get_element_properties_fn - if not self.get_element_properties_fn: - self.get_element_properties_fn = lambda ele: {'value': ele} - - def __iter__(self): - return self - - def __next__(self): - self.iteration_started = True - - try: - self.prev_element = next(self.it) - except StopIteration: - if self.limit and not self.is_complete: - # the end of the limited sequence was reached, check if items beyond the limit remain - try: - next(self.original_it) - except StopIteration: - # the iteration is fully complete -- no additional items can be fetched - self.is_complete = True - else: - self.is_complete = True - raise - - return self.prev_element - - def get_next_query_params(self): - if self.is_complete: - return None - if not self.iteration_started: - return {} - - return self.get_element_properties_fn(self.prev_element) diff --git a/corehq/apps/api/tests/keyset_paginator_tests.py b/corehq/apps/api/tests/keyset_paginator_tests.py index 025f39b13a00..22b35a6ca155 100644 --- a/corehq/apps/api/tests/keyset_paginator_tests.py +++ b/corehq/apps/api/tests/keyset_paginator_tests.py @@ -1,28 +1,30 @@ from django.test import SimpleTestCase from django.http import QueryDict -from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper from corehq.apps.api.keyset_paginator import KeysetPaginator -class SequenceWrapper: - def __init__(self, seq, get_next_fn=None): +class SequenceQuery: + def __init__(self, seq): self.seq = seq - self.get_next_fn = get_next_fn def execute(self, limit=None): - return ResumableIteratorWrapper(lambda _: self.seq, self.get_next_fn, limit=limit) + return self.seq + + @classmethod + def get_query_params(cls, form): + return {'next': form} class KeysetPaginatorTests(SimpleTestCase): def 
test_page_fetches_all_results_below_limit(self): - objects = SequenceWrapper(range(5)) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects, limit=10) page = paginator.page() self.assertEqual(page['objects'], [0, 1, 2, 3, 4]) self.assertEqual(page['meta'], {'limit': 10}) def test_page_includes_next_information_when_more_results_are_available(self): - objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects, resource_uri='http://test.com/', limit=3) page = paginator.page() self.assertEqual(page['objects'], [0, 1, 2]) @@ -31,7 +33,7 @@ def test_page_includes_next_information_when_more_results_are_available(self): def test_does_not_include_duplicate_limits(self): request_data = QueryDict(mutable=True) request_data['limit'] = 3 - objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') page = paginator.page() self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&next=2') @@ -41,34 +43,34 @@ def test_supports_dict_request_data(self): 'limit': 3, 'some_param': 'yes' } - objects = SequenceWrapper(range(5), lambda ele: {'next': ele}) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') page = paginator.page() self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&some_param=yes&next=2') def test_get_offset_not_implemented(self): - objects = SequenceWrapper(range(5)) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects) with self.assertRaises(NotImplementedError): paginator.get_offset() def test_get_slice_not_implemented(self): - objects = SequenceWrapper(range(5)) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects) with self.assertRaises(NotImplementedError): 
paginator.get_slice(limit=10, offset=20) def test_get_count_not_implemented(self): - objects = SequenceWrapper(range(5)) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects) with self.assertRaises(NotImplementedError): paginator.get_count() def test_get_previous_not_implemented(self): - objects = SequenceWrapper(range(5)) + objects = SequenceQuery(range(5)) paginator = KeysetPaginator(QueryDict(), objects) with self.assertRaises(NotImplementedError): diff --git a/corehq/apps/api/tests/test_resumable_iterator_wrapper.py b/corehq/apps/api/tests/test_resumable_iterator_wrapper.py deleted file mode 100644 index 93fd8eedae98..000000000000 --- a/corehq/apps/api/tests/test_resumable_iterator_wrapper.py +++ /dev/null @@ -1,66 +0,0 @@ -from django.test import SimpleTestCase -from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper - - -class ResumableIteratorWrapperTests(SimpleTestCase): - def test_can_iterate_through_a_wrapped_iterator(self): - initial_it = iter(range(5)) - it = ResumableIteratorWrapper(lambda _: initial_it) - self.assertEqual(list(it), [0, 1, 2, 3, 4]) - - def test_can_iterate_through_a_sequence(self): - sequence = [0, 1, 2, 3, 4] - it = ResumableIteratorWrapper(lambda _: sequence) - self.assertEqual(list(it), [0, 1, 2, 3, 4]) - - def test_can_limit_a_sequence(self): - sequence = [0, 1, 2, 3, 4] - it = ResumableIteratorWrapper(lambda _: sequence, limit=4) - self.assertEqual(list(it), [0, 1, 2, 3]) - - def test_when_limit_is_less_than_sequence_length_is_incomplete(self): - sequence = [0, 1, 2, 3, 4] - it = ResumableIteratorWrapper(lambda _: sequence, limit=4) - list(it) - self.assertFalse(it.is_complete) - - def test_when_limit_matches_sequence_size_iterator_is_complete(self): - sequence = [0, 1, 2, 3, 4] - it = ResumableIteratorWrapper(lambda _: sequence, limit=5) - list(it) - self.assertTrue(it.is_complete) - - def test_get_next_query_params_returns_empty_object_prior_to_iteration(self): - seq = [ - 
{'key': 'one', 'val': 'val1'}, - {'key': 'two', 'val': 'val2'}, - ] - it = ResumableIteratorWrapper(lambda _: seq) - self.assertEqual(it.get_next_query_params(), {}) - - def test_default_get_next_query_params_returns_identity_object(self): - seq = [ - {'key': 'one', 'val': 'val1'}, - {'key': 'two', 'val': 'val2'}, - ] - it = ResumableIteratorWrapper(lambda _: seq, ) - next(it) - self.assertEqual(it.get_next_query_params(), {'value': {'key': 'one', 'val': 'val1'}}) - - def test_custom_get_next_query_params_fn(self): - seq = [ - {'key': 'one', 'val': 'val1'}, - {'key': 'two', 'val': 'val2'}, - ] - - def custom_element_properties_fn(ele): - return (ele['key'], ele['val']) - - it = ResumableIteratorWrapper(lambda _: seq, custom_element_properties_fn) - next(it) - self.assertEqual(it.get_next_query_params(), ('one', 'val1')) - - def test_get_next_query_params_returns_none_when_fully_iterated(self): - it = ResumableIteratorWrapper(lambda _: range(5)) - list(it) - self.assertIsNone(it.get_next_query_params()) diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index cf4478a9cb63..0ed9a42b28f1 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -377,16 +377,10 @@ def get_object_list(self, request): if end_date: end_date = datetime.fromisoformat(end_date) - last_time = request.GET.get('received_on', None) - if last_time: - last_time = datetime.fromisoformat(last_time) - - last_domain = request.GET.get('domain', None) - last_id = request.GET.get('id', None) - account = BillingAccount.get_account_by_domain(request.domain) - return IterableEnterpriseFormQuery(account, start_date, end_date, last_domain, last_time, last_id) + query_kwargs = IterableEnterpriseFormQuery.get_kwargs_from_map(request.GET) + return IterableEnterpriseFormQuery(account, start_date, end_date, **query_kwargs) def dehydrate(self, bundle): bundle.data['form_id'] = bundle.obj['form_id'] diff --git 
a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 5a55c365cd85..a6705720cc83 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -2,7 +2,6 @@ from django.utils.translation import gettext as _ from corehq.apps.es import filters from corehq.apps.es.forms import FormES -from corehq.apps.api.resumable_iterator_wrapper import ResumableIteratorWrapper from corehq.apps.enterprise.exceptions import TooMuchRequestedDataError from corehq.apps.app_manager.dbaccessors import get_brief_apps_in_domain @@ -32,24 +31,39 @@ def __init__(self, account, start_date, end_date, last_domain, last_time, last_i def execute(self, limit=None): domains = self.account.get_domains() - def create_multi_domain_form_generator(limit): - it = multi_domain_form_generator( - domains, - self.start_date, - self.end_date, - self.last_domain, - self.last_time, - self.last_id, - limit=limit - ) - xform_converter = RawFormConverter() - return (xform_converter.convert(form) for form in it) - - return ResumableIteratorWrapper(create_multi_domain_form_generator, lambda form: { - 'domain': form['domain'], - 'received_on': form['submitted'], - 'id': form['form_id'] - }, limit=limit) + it = multi_domain_form_generator( + domains, + self.start_date, + self.end_date, + self.last_domain, + self.last_time, + self.last_id, + limit=limit + ) + + xform_converter = RawFormConverter() + return (xform_converter.convert(form) for form in it) + + @classmethod + def get_kwargs_from_map(cls, map): + last_domain = map.get('domain', None) + last_time = map.get('received_on', None) + if last_time: + last_time = datetime.fromisoformat(last_time) + last_id = map.get('id', None) + return { + 'last_domain': last_domain, + 'last_time': last_time, + 'last_id': last_id + } + + @classmethod + def get_query_params(cls, fetched_object): + return { + 'domain': fetched_object['domain'], + 'received_on': fetched_object['submitted'], + 'id': fetched_object['form_id'] + } 
def resolve_start_and_end_date(start_date, end_date, maximum_date_range): From 05eaa9a3a79bf4a0240a02bd92bb799dc7f6104c Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 30 Oct 2024 12:01:27 -0400 Subject: [PATCH 08/23] Switched received filter to inserted --- corehq/apps/enterprise/iterators.py | 13 +++++++------ corehq/apps/es/forms.py | 5 +++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index a6705720cc83..c941b9fc0bc3 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from django.utils.translation import gettext as _ from corehq.apps.es import filters +from corehq.apps.es.utils import es_format_datetime from corehq.apps.es.forms import FormES from corehq.apps.enterprise.exceptions import TooMuchRequestedDataError from corehq.apps.app_manager.dbaccessors import get_brief_apps_in_domain @@ -178,19 +179,19 @@ def create_domain_query(domain, start_date, end_date, last_time, last_id, limit= query = query.size(limit) query.es_query['sort'] = [ - {'received_on': {'order': 'desc'}}, - {'form.meta.instanceID': 'asc'} + {'inserted_at': {'order': 'desc'}}, + {'doc_id': 'asc'} ] if last_id: query = query.filter(filters.OR( filters.AND( - filters.term('received_on', last_time), - filters.range_filter('form.meta.instanceID', gt=last_id) + filters.term('inserted_at', es_format_datetime(last_time)), + filters.range_filter('doc_id', gt=last_id) ), - filters.range_filter('received_on', lt=last_time) + filters.date_range('inserted_at', lt=last_time) )) else: - query = query.submitted(lte=last_time) + query = query.inserted(lte=last_time) return query diff --git a/corehq/apps/es/forms.py b/corehq/apps/es/forms.py index 60aed8ec325e..6354a439b0dc 100644 --- a/corehq/apps/es/forms.py +++ b/corehq/apps/es/forms.py @@ -43,6 +43,7 @@ def builtin_filters(self): form_ids, xmlns, app, + inserted, 
submitted, completed, user_id, @@ -186,6 +187,10 @@ def submitted(gt=None, gte=None, lt=None, lte=None): return filters.date_range('received_on', gt, gte, lt, lte) +def inserted(gt=None, gte=None, lt=None, lte=None): + return filters.date_range('inserted_at', gt, gte, lt, lte) + + def completed(gt=None, gte=None, lt=None, lte=None): return filters.date_range('form.meta.timeEnd', gt, gte, lt, lte) From 2504668727cf9882dd9fbc6e21de29f918e460a0 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 30 Oct 2024 12:04:06 -0400 Subject: [PATCH 09/23] Rename domain forms generator --- corehq/apps/enterprise/iterators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index c941b9fc0bc3..c3cd5a9885df 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -149,7 +149,7 @@ def domain_form_generator(domain, start_date, end_date, last_time=None, last_id= remaining = limit while True: - query = create_domain_query(domain, start_date, end_date, last_time, last_id, limit=remaining) + query = create_domain_forms_query(domain, start_date, end_date, last_time, last_id, limit=remaining) results = query.run() for form in results.hits: last_form_fetched = form @@ -167,7 +167,7 @@ def domain_form_generator(domain, start_date, end_date, last_time=None, last_id= last_id = last_form_fetched['_id'] -def create_domain_query(domain, start_date, end_date, last_time, last_id, limit=None): +def create_domain_forms_query(domain, start_date, end_date, last_time, last_id, limit=None): query = ( FormES() .domain(domain) From dd334de028120e16ceb4d3f88d0b3812dbe26fed Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 30 Oct 2024 14:52:55 -0400 Subject: [PATCH 10/23] Make enterprise form api timezone aware --- corehq/apps/enterprise/api/resources.py | 6 +++--- corehq/apps/enterprise/iterators.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git 
a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index 0ed9a42b28f1..e29bbb8c1efe 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from urllib.parse import urljoin from django.http import HttpResponse, HttpResponseForbidden, HttpResponseNotFound @@ -371,11 +371,11 @@ class Meta(ODataEnterpriseReportResource.Meta): def get_object_list(self, request): start_date = request.GET.get('startdate', None) if start_date: - start_date = datetime.fromisoformat(start_date) + start_date = datetime.fromisoformat(start_date).astimezone(timezone.utc) end_date = request.GET.get('enddate', None) if end_date: - end_date = datetime.fromisoformat(end_date) + end_date = datetime.fromisoformat(end_date).astimezone(timezone.utc) account = BillingAccount.get_account_by_domain(request.domain) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index c3cd5a9885df..74ec1a2c005f 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from django.utils.translation import gettext as _ from corehq.apps.es import filters from corehq.apps.es.utils import es_format_datetime @@ -50,7 +50,7 @@ def get_kwargs_from_map(cls, map): last_domain = map.get('domain', None) last_time = map.get('received_on', None) if last_time: - last_time = datetime.fromisoformat(last_time) + last_time = datetime.fromisoformat(last_time).astimezone(timezone.utc) last_id = map.get('id', None) return { 'last_domain': last_domain, @@ -72,7 +72,7 @@ def resolve_start_and_end_date(start_date, end_date, maximum_date_range): Provide start and end date values if not supplied. 
''' if not end_date: - end_date = datetime.utcnow() + end_date = datetime.now(timezone.utc) if not start_date: start_date = end_date - timedelta(days=30) @@ -144,7 +144,7 @@ def _get_domain_iterator(last_time=None, last_id=None): def domain_form_generator(domain, start_date, end_date, last_time=None, last_id=None, limit=None): if not last_time: - last_time = datetime.now() + last_time = datetime.now(timezone.utc) remaining = limit From 080d837c646b60abb79a2cf975f26f6291301104 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Thu, 31 Oct 2024 15:21:49 -0400 Subject: [PATCH 11/23] Rename mobile_user field to username --- corehq/apps/enterprise/api/resources.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index e29bbb8c1efe..3b586ba3faab 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -363,7 +363,7 @@ class Meta(ODataEnterpriseReportResource.Meta): form_name = fields.CharField() submitted = fields.DateTimeField() app_name = fields.CharField() - mobile_user = fields.CharField() + username = fields.CharField() domain = fields.CharField() REPORT_SLUG = EnterpriseReport.FORM_SUBMISSIONS @@ -387,7 +387,7 @@ def dehydrate(self, bundle): bundle.data['form_name'] = bundle.obj['form_name'] bundle.data['submitted'] = self.convert_datetime(bundle.obj['submitted']) bundle.data['app_name'] = bundle.obj['app_name'] - bundle.data['mobile_user'] = bundle.obj['username'] + bundle.data['username'] = bundle.obj['username'] bundle.data['domain'] = bundle.obj['domain'] return bundle From 09c104b9bc322fe8ddfa90f2bf91d7f00c07a5bc Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Fri, 1 Nov 2024 16:03:11 -0400 Subject: [PATCH 12/23] Made enterprise form submission report iteration generic Moved from `received_on` to `inserted_at` --- corehq/apps/enterprise/iterators.py | 146 +++++++++++------- .../apps/enterprise/tests/test_iterators.py 
| 62 ++++++-- 2 files changed, 136 insertions(+), 72 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 74ec1a2c005f..dbb878c68802 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -32,14 +32,15 @@ def __init__(self, account, start_date, end_date, last_domain, last_time, last_i def execute(self, limit=None): domains = self.account.get_domains() - it = multi_domain_form_generator( + it = loop_over_domains( domains, - self.start_date, - self.end_date, - self.last_domain, - self.last_time, - self.last_id, - limit=limit + MobileFormSubmissionsQueryFactory(), + limit=limit, + last_domain=self.last_domain, + start_date=self.start_date, + end_date=self.end_date, + last_time=self.last_time, + last_id=self.last_id, ) xform_converter = RawFormConverter() @@ -48,7 +49,7 @@ def execute(self, limit=None): @classmethod def get_kwargs_from_map(cls, map): last_domain = map.get('domain', None) - last_time = map.get('received_on', None) + last_time = map.get('inserted_at', None) if last_time: last_time = datetime.fromisoformat(last_time).astimezone(timezone.utc) last_id = map.get('id', None) @@ -62,7 +63,7 @@ def get_kwargs_from_map(cls, map): def get_query_params(cls, fetched_object): return { 'domain': fetched_object['domain'], - 'received_on': fetched_object['submitted'], + 'inserted_at': fetched_object['inserted_at'], 'id': fetched_object['form_id'] } @@ -92,11 +93,13 @@ def __init__(self): def convert(self, form): domain = form['domain'] submitted_date = datetime.strptime(form['received_on'][:19], '%Y-%m-%dT%H:%M:%S') + inserted_at = datetime.strptime(form['inserted_at'][:19], '%Y-%m-%dT%H:%M:%S') return { 'form_id': form['form']['meta']['instanceID'], 'form_name': form['form']['@name'] or _('Unnamed'), 'submitted': submitted_date, + 'inserted_at': inserted_at, 'app_name': self.app_lookup.resolve_app_id_to_name(domain, form['app_id']) or _('App not found'), 'username': 
form['form']['meta']['username'], 'domain': domain @@ -115,23 +118,22 @@ def resolve_app_id_to_name(self, domain, app_id): return self.domain_lookup_tables[domain].get(app_id, None) -def multi_domain_form_generator( - domains, start_date, end_date, last_domain=None, last_time=None, last_id=None, limit=None): +def loop_over_domains(domains, query_factory, limit=None, last_domain=None, **kwargs): domain_index = domains.index(last_domain) if last_domain else 0 remaining = limit - def _get_domain_iterator(last_time=None, last_id=None): + def _get_domain_iterator(**kwargs): if domain_index >= len(domains): return None domain = domains[domain_index] - return domain_form_generator(domain, start_date, end_date, last_time, last_id, limit=remaining) + return loop_over_domain(domain, query_factory, limit=remaining, **kwargs) - current_iterator = _get_domain_iterator(last_time, last_id) + current_iterator = _get_domain_iterator(**kwargs) while current_iterator: - for form in current_iterator: - yield form + for hit in current_iterator: + yield hit if remaining: remaining -= 1 if remaining == 0: @@ -139,21 +141,21 @@ def _get_domain_iterator(last_time=None, last_id=None): domain_index += 1 if domain_index >= len(domains): return - current_iterator = _get_domain_iterator() + next_args = query_factory.get_next_query_args(kwargs, last_hit=None) + current_iterator = _get_domain_iterator(**next_args) -def domain_form_generator(domain, start_date, end_date, last_time=None, last_id=None, limit=None): - if not last_time: - last_time = datetime.now(timezone.utc) - +def loop_over_domain(domain, query_factory, limit=None, **kwargs): remaining = limit + next_query_args = kwargs + while True: - query = create_domain_forms_query(domain, start_date, end_date, last_time, last_id, limit=remaining) + query = query_factory.get_query(domain, limit=limit, **next_query_args) results = query.run() - for form in results.hits: - last_form_fetched = form - yield last_form_fetched + for hit in 
results.hits: + last_hit = hit + yield last_hit num_fetched = len(results.hits) @@ -162,36 +164,64 @@ def domain_form_generator(domain, start_date, end_date, last_time=None, last_id= else: if remaining: remaining -= num_fetched - assert remaining > 0 - last_time = last_form_fetched['received_on'] - last_id = last_form_fetched['_id'] - - -def create_domain_forms_query(domain, start_date, end_date, last_time, last_id, limit=None): - query = ( - FormES() - .domain(domain) - .user_type('mobile') - .submitted(gte=start_date, lte=end_date) - ) - - if limit: - query = query.size(limit) - - query.es_query['sort'] = [ - {'inserted_at': {'order': 'desc'}}, - {'doc_id': 'asc'} - ] - - if last_id: - query = query.filter(filters.OR( - filters.AND( - filters.term('inserted_at', es_format_datetime(last_time)), - filters.range_filter('doc_id', gt=last_id) - ), - filters.date_range('inserted_at', lt=last_time) - )) - else: - query = query.inserted(lte=last_time) - - return query + + next_query_args = query_factory.get_next_query_args(next_query_args, last_hit) + + +class ReportQueryFactoryInterface: + ''' + A generic interface for any report queries. 
+ ''' + def get_query(self, **kwargs): + ''' + Returns an ElasticSearch query, configured by **kwargs + ''' + raise NotImplementedError() + + def get_next_query_args(self, previous_args, last_hit): + ''' + Modifies the `previous_args` dictionary with information from `last_hit` to create + a new set of kwargs suitable to pass back to `get_query` to retrieve results beyond `last_hit` + ''' + raise NotImplementedError() + + +class MobileFormSubmissionsQueryFactory(ReportQueryFactoryInterface): + def get_query(self, domain, start_date, end_date, last_time=None, last_id=None, limit=None): + query = ( + FormES() + .domain(domain) + .user_type('mobile') + .submitted(gte=start_date, lte=end_date) + ) + + if limit: + query = query.size(limit) + + query.es_query['sort'] = [ + {'inserted_at': {'order': 'desc'}}, + {'doc_id': 'asc'} + ] + + if last_time and last_id: + query = query.filter(filters.OR( + filters.AND( + filters.term('inserted_at', es_format_datetime(last_time)), + filters.range_filter('doc_id', gt=last_id) + ), + filters.date_range('inserted_at', lt=last_time) + )) + + return query + + def get_next_query_args(self, previous_args, last_hit): + if last_hit: + return previous_args | { + 'last_time': last_hit['inserted_at'], + 'last_id': last_hit['doc_id'] + } + else: + new_args = previous_args.copy() + new_args.pop('last_time', None) + new_args.pop('last_id', None) + return new_args diff --git a/corehq/apps/enterprise/tests/test_iterators.py b/corehq/apps/enterprise/tests/test_iterators.py index 3343bc14089f..d781eabfc88d 100644 --- a/corehq/apps/enterprise/tests/test_iterators.py +++ b/corehq/apps/enterprise/tests/test_iterators.py @@ -1,14 +1,16 @@ +from unittest.mock import patch from django.test import SimpleTestCase, TestCase from datetime import datetime -from corehq.apps.es.forms import form_adapter +from corehq.apps.es.forms import form_adapter, ElasticForm from corehq.apps.es.tests.utils import es_test from corehq.apps.users.models import CommCareUser from 
corehq.form_processor.tests.utils import create_form_for_test from corehq.apps.enterprise.iterators import ( raise_after_max_elements, - domain_form_generator, - multi_domain_form_generator, + loop_over_domains, + loop_over_domain, + MobileFormSubmissionsQueryFactory ) @@ -29,11 +31,15 @@ def test_can_iterate_through_all_elements_with_no_exception(self): @es_test(requires=[form_adapter]) -class TestMultiDomainFormGenerator(TestCase): +class TestLoopOverDomains(TestCase): def setUp(self): self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) self.addCleanup(self.user.delete, None, None) + inserted_at_mapping_patcher = map_received_on_to_inserted_at() + inserted_at_mapping_patcher.start() + self.addCleanup(inserted_at_mapping_patcher.stop) + def test_iterates_through_multiple_domains(self): forms = [ self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)), @@ -42,8 +48,9 @@ def test_iterates_through_multiple_domains(self): ] form_adapter.bulk_index(forms, refresh=True) - it = multi_domain_form_generator( + it = loop_over_domains( ['domain1', 'domain2', 'domain3'], + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15) ) @@ -60,8 +67,9 @@ def test_respects_limit_across_multiple_domains(self): ] form_adapter.bulk_index(forms, refresh=True) - it = multi_domain_form_generator( + it = loop_over_domains( ['domain1', 'domain2'], + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), limit=3 @@ -88,19 +96,24 @@ def _create_form(self, domain, form_id=None, received_on=None): @es_test(requires=[form_adapter]) -class TestDomainFormGenerator(TestCase): +class TestLoopOverDomain(TestCase): def setUp(self): self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) self.addCleanup(self.user.delete, None, None) + 
inserted_at_mapping_patcher = map_received_on_to_inserted_at() + inserted_at_mapping_patcher.start() + self.addCleanup(inserted_at_mapping_patcher.stop) + def test_iterates_through_all_forms_in_domain(self): form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=3)) form3 = self._create_form('test-domain', form_id='3', received_on=datetime(year=2024, month=7, day=4)) form_adapter.bulk_index([form1, form2, form3], refresh=True) - it = domain_form_generator( + it = loop_over_domain( 'test-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), ) @@ -109,8 +122,9 @@ def test_iterates_through_all_forms_in_domain(self): self.assertEqual(form_ids, ['3', '2', '1']) def test_handles_empty_domain(self): - it = domain_form_generator( + it = loop_over_domain( 'empty-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), ) @@ -122,8 +136,9 @@ def test_includes_inclusive_boundaries(self): form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) form_adapter.bulk_index([form1, form2], refresh=True) - it = domain_form_generator( + it = loop_over_domain( 'test-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2) ) @@ -136,8 +151,9 @@ def test_ignores_form_in_another_domain(self): form2 = self._create_form('not-test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) form_adapter.bulk_index([form1, form2], refresh=True) - it = domain_form_generator( + it = loop_over_domain( 'test-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, 
day=15), ) @@ -150,8 +166,9 @@ def test_sorts_by_date_then_id(self): form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) form_adapter.bulk_index([form2, form1], refresh=True) - it = domain_form_generator( + it = loop_over_domain( 'test-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2), ) @@ -164,8 +181,9 @@ def test_does_not_return_forms_beyond_limit(self): form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) form_adapter.bulk_index([form1, form2], refresh=True) - it = domain_form_generator( + it = loop_over_domain( 'test-domain', + MobileFormSubmissionsQueryFactory(), start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2), limit=1 @@ -187,5 +205,21 @@ def _create_form(self, domain, form_id=None, received_on=None): user_id=self.user._id, form_data=form_data, form_id=form_id, - received_on=received_on + received_on=received_on, ) + + +def map_received_on_to_inserted_at(): + ''' + A patcher to use the date value found in a form's 'received_on' field for the 'inserted_at' value. 
+ Without this patch, 'inserted_at' will be `utcnow()`, which would require knowing the order and number of times + that `utcnow()` would be called to manipulate dates + ''' + original = ElasticForm._from_dict + + def from_dict(cls, xform_dict): + id, result = original(cls, xform_dict) + result['inserted_at'] = result['received_on'] + return (id, result) + + return patch.object(ElasticForm, '_from_dict', new=from_dict) From 409f725102c5a24f85b072ec1513fa3636c21e9b Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 6 Nov 2024 11:10:47 -0500 Subject: [PATCH 13/23] Added happy path test for form resource api --- corehq/apps/enterprise/tests/test_apis.py | 111 ++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 corehq/apps/enterprise/tests/test_apis.py diff --git a/corehq/apps/enterprise/tests/test_apis.py b/corehq/apps/enterprise/tests/test_apis.py new file mode 100644 index 000000000000..2fdb60e93341 --- /dev/null +++ b/corehq/apps/enterprise/tests/test_apis.py @@ -0,0 +1,111 @@ +import json +from django.test import TestCase, RequestFactory +from django_prbac.models import Role, Grant, UserRole +from corehq import privileges +import base64 +from datetime import datetime, timezone +from corehq.apps.users.models import WebUser, CommCareUser, HQApiKey +from corehq.apps.accounting.tests import generator +from corehq.apps.domain.models import Domain +from corehq.apps.es.tests.utils import es_test +from corehq.apps.es.forms import form_adapter +from corehq.apps.enterprise.api.resources import FormSubmissionResource +from corehq.form_processor.tests.utils import create_form_for_test + + +@es_test(requires=[form_adapter]) +class FormSubmissionResourceTests(TestCase): + def test_happy_path(self): + + self._create_enterprise_account_covering_domains(['test-domain-1', 'test-domain-2']) + + enterprise_admin = self._create_enterprise_admin('test-admin@dimagi.com', 'test-domain-1') + django_user = enterprise_admin.get_django_user() + api_key = 
HQApiKey.objects.create(user=django_user, key='1234', name='TestKey') + + mobile_user_1 = self._create_mobile_user('mobile1@test.com', 'test-domain-1') + mobile_user_2 = self._create_mobile_user('mobile2@test.com', 'test-domain-2') + + form1 = self._create_form(mobile_user_1, form_id='1234', received_on=datetime(year=2004, month=10, day=11)) + form2 = self._create_form(mobile_user_2, form_id='2345', received_on=datetime(year=2004, month=10, day=12)) + form_adapter.bulk_index([form1, form2], refresh=True) + + request = self._create_api_request(enterprise_admin, 'test-domain-1', api_key) + + resource = FormSubmissionResource() + response = resource.dispatch_list(request) + result = json.loads(response.content) + resulting_forms = result['value'] + self.assertEqual([form['form_id'] for form in resulting_forms], ['1234', '2345']) + + def _create_form(self, user, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + '@name': 'TestForm', + 'meta': { + 'userID': user._id, + 'username': user.name, + 'instanceID': form_id, + }, + } + return create_form_for_test( + user.domain, + user_id=user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on, + ) + + def _create_enterprise_account_covering_domains(self, domains): + billing_account = generator.billing_account('test-admin@dimagi.com', 'test-admin@dimagi.com') + billing_account.enterprise_admin_emails = ['test-admin@dimagi.com'] + billing_account.save() + + for domain in domains: + domain_obj = Domain(name=domain, is_active=True) + domain_obj.save() + self.addCleanup(domain_obj.delete) + + generator.generate_domain_subscription( + billing_account, domain_obj, datetime.now(timezone.utc), None, is_active=True + ) + + return billing_account + + def _create_enterprise_admin(self, email, domain): + user = WebUser.create( + domain, email, 'test123', None, None, email) + user.is_superuser = True + user.save() + + self.addCleanup(user.delete, None, deleted_by=None) + + role = 
Role.objects.create(slug="test_role") + UserRole.objects.create(user=user.get_django_user(), role=role) + accounting_admin_role = Role.objects.get_or_create( + name="Accounting Admin", + slug=privileges.ACCOUNTING_ADMIN, + )[0] + Grant.objects.create(from_role=role, to_role=accounting_admin_role) + + return user + + def _create_mobile_user(self, username, domain): + user = CommCareUser.create(domain, username, 'test123', None, None) + self.addCleanup(user.delete, None, deleted_by=None) + return user + + def _create_api_request(self, user, domain, api_key): + auth_string = f'{user.username}:{api_key.key}' + factory = RequestFactory() + encoded_auth = base64.b64encode(auth_string.encode()).decode() + request = factory.get( + '/', + {'startdate': '2004-10-10', 'enddate': '2004-11-10'}, + HTTP_AUTHORIZATION=f'basic {encoded_auth}' + ) + request.couch_user = user + request.user = user.get_django_user() + request.domain = domain + + return request From bff5facbf3137ff0f7d757ab19538a65d1ce7b37 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 6 Nov 2024 16:17:18 -0500 Subject: [PATCH 14/23] Remove superuser permissions from Enterprise Forms API test --- corehq/apps/enterprise/tests/test_apis.py | 64 +++++++++++++++-------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/corehq/apps/enterprise/tests/test_apis.py b/corehq/apps/enterprise/tests/test_apis.py index 2fdb60e93341..dd32aa068ac7 100644 --- a/corehq/apps/enterprise/tests/test_apis.py +++ b/corehq/apps/enterprise/tests/test_apis.py @@ -1,25 +1,28 @@ import json from django.test import TestCase, RequestFactory -from django_prbac.models import Role, Grant, UserRole -from corehq import privileges +from corehq.apps.accounting.models import SoftwarePlanEdition, DefaultProductPlan import base64 from datetime import datetime, timezone -from corehq.apps.users.models import WebUser, CommCareUser, HQApiKey +from corehq.apps.users.models import WebUser, CommCareUser, HQApiKey, HqPermissions from 
corehq.apps.accounting.tests import generator from corehq.apps.domain.models import Domain from corehq.apps.es.tests.utils import es_test from corehq.apps.es.forms import form_adapter from corehq.apps.enterprise.api.resources import FormSubmissionResource from corehq.form_processor.tests.utils import create_form_for_test +from corehq.apps.users.models_role import UserRole @es_test(requires=[form_adapter]) class FormSubmissionResourceTests(TestCase): def test_happy_path(self): + enterprise_account = self._create_enterprise_account_covering_domains(['test-domain-1', 'test-domain-2']) - self._create_enterprise_account_covering_domains(['test-domain-1', 'test-domain-2']) - - enterprise_admin = self._create_enterprise_admin('test-admin@dimagi.com', 'test-domain-1') + enterprise_admin = self._create_enterprise_admin( + 'test-admin@somedomain.com', + 'test-domain-1', + enterprise_account + ) django_user = enterprise_admin.get_django_user() api_key = HQApiKey.objects.create(user=django_user, key='1234', name='TestKey') @@ -30,7 +33,12 @@ def test_happy_path(self): form2 = self._create_form(mobile_user_2, form_id='2345', received_on=datetime(year=2004, month=10, day=12)) form_adapter.bulk_index([form1, form2], refresh=True) - request = self._create_api_request(enterprise_admin, 'test-domain-1', api_key) + request = self._create_api_request( + enterprise_admin, + 'test-domain-1', + api_key, + {'startdate': '2004-10-10', 'enddate': '2004-11-10'} + ) resource = FormSubmissionResource() response = resource.dispatch_list(request) @@ -57,9 +65,14 @@ def _create_form(self, user, form_id=None, received_on=None): ) def _create_enterprise_account_covering_domains(self, domains): - billing_account = generator.billing_account('test-admin@dimagi.com', 'test-admin@dimagi.com') - billing_account.enterprise_admin_emails = ['test-admin@dimagi.com'] - billing_account.save() + billing_account = generator.billing_account( + 'test-admin@dimagi.com', + 'test-admin@dimagi.com', + 
is_customer_account=True + ) + + # Enterprise is needed to grant API and OData permissions + enterprise_plan = DefaultProductPlan.get_default_plan_version(SoftwarePlanEdition.ENTERPRISE) for domain in domains: domain_obj = Domain(name=domain, is_active=True) @@ -67,26 +80,31 @@ def _create_enterprise_account_covering_domains(self, domains): self.addCleanup(domain_obj.delete) generator.generate_domain_subscription( - billing_account, domain_obj, datetime.now(timezone.utc), None, is_active=True + billing_account, + domain_obj, + datetime.now(timezone.utc), + None, + plan_version=enterprise_plan, + is_active=True ) return billing_account - def _create_enterprise_admin(self, email, domain): + def _create_enterprise_admin(self, email, domain, enterprise_account): user = WebUser.create( domain, email, 'test123', None, None, email) - user.is_superuser = True + + # Users need to have permission to view OData reports on the specified domain to access these APIs + permissions = HqPermissions(view_reports=True) + role = UserRole.create(domain, email, permissions=permissions) + membership = user.get_domain_membership(domain) + membership.role_id = role.couch_id user.save() - self.addCleanup(user.delete, None, deleted_by=None) + enterprise_account.enterprise_admin_emails.append(email) + enterprise_account.save() - role = Role.objects.create(slug="test_role") - UserRole.objects.create(user=user.get_django_user(), role=role) - accounting_admin_role = Role.objects.get_or_create( - name="Accounting Admin", - slug=privileges.ACCOUNTING_ADMIN, - )[0] - Grant.objects.create(from_role=role, to_role=accounting_admin_role) + self.addCleanup(user.delete, None, deleted_by=None) return user @@ -95,13 +113,13 @@ def _create_mobile_user(self, username, domain): self.addCleanup(user.delete, None, deleted_by=None) return user - def _create_api_request(self, user, domain, api_key): + def _create_api_request(self, user, domain, api_key, data=None): auth_string = f'{user.username}:{api_key.key}' 
factory = RequestFactory() encoded_auth = base64.b64encode(auth_string.encode()).decode() request = factory.get( '/', - {'startdate': '2004-10-10', 'enddate': '2004-11-10'}, + data, HTTP_AUTHORIZATION=f'basic {encoded_auth}' ) request.couch_user = user From c65f0b67e58c338dce2462b5bf4c1bf914e0448f Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 6 Nov 2024 16:20:08 -0500 Subject: [PATCH 15/23] rename api test --- corehq/apps/enterprise/tests/test_apis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/corehq/apps/enterprise/tests/test_apis.py b/corehq/apps/enterprise/tests/test_apis.py index dd32aa068ac7..d67e693c881b 100644 --- a/corehq/apps/enterprise/tests/test_apis.py +++ b/corehq/apps/enterprise/tests/test_apis.py @@ -15,7 +15,7 @@ @es_test(requires=[form_adapter]) class FormSubmissionResourceTests(TestCase): - def test_happy_path(self): + def test_resource_is_accessible(self): enterprise_account = self._create_enterprise_account_covering_domains(['test-domain-1', 'test-domain-2']) enterprise_admin = self._create_enterprise_admin( From 593882ca072d4b45852d9f7eb97f75c0475075bd Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Wed, 6 Nov 2024 16:22:42 -0500 Subject: [PATCH 16/23] isort --- corehq/apps/enterprise/tests/test_apis.py | 24 ++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/corehq/apps/enterprise/tests/test_apis.py b/corehq/apps/enterprise/tests/test_apis.py index d67e693c881b..969f98ef0287 100644 --- a/corehq/apps/enterprise/tests/test_apis.py +++ b/corehq/apps/enterprise/tests/test_apis.py @@ -1,16 +1,26 @@ -import json -from django.test import TestCase, RequestFactory -from corehq.apps.accounting.models import SoftwarePlanEdition, DefaultProductPlan import base64 +import json from datetime import datetime, timezone -from corehq.apps.users.models import WebUser, CommCareUser, HQApiKey, HqPermissions + +from django.test import RequestFactory, TestCase + +from corehq.apps.accounting.models import 
( + DefaultProductPlan, + SoftwarePlanEdition, +) from corehq.apps.accounting.tests import generator from corehq.apps.domain.models import Domain -from corehq.apps.es.tests.utils import es_test -from corehq.apps.es.forms import form_adapter from corehq.apps.enterprise.api.resources import FormSubmissionResource -from corehq.form_processor.tests.utils import create_form_for_test +from corehq.apps.es.forms import form_adapter +from corehq.apps.es.tests.utils import es_test +from corehq.apps.users.models import ( + CommCareUser, + HQApiKey, + HqPermissions, + WebUser, +) from corehq.apps.users.models_role import UserRole +from corehq.form_processor.tests.utils import create_form_for_test @es_test(requires=[form_adapter]) From bee7055c7bd34087a011cdc2a87d1d074691660d Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Thu, 7 Nov 2024 14:15:06 -0500 Subject: [PATCH 17/23] Refactor domain iteration logic --- corehq/apps/enterprise/iterators.py | 38 +++++++++++------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index dbb878c68802..334d465d73f2 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,3 +1,4 @@ +from itertools import dropwhile, chain, islice from datetime import datetime, timedelta, timezone from django.utils.translation import gettext as _ from corehq.apps.es import filters @@ -119,30 +120,19 @@ def resolve_app_id_to_name(self, domain, app_id): def loop_over_domains(domains, query_factory, limit=None, last_domain=None, **kwargs): - domain_index = domains.index(last_domain) if last_domain else 0 - - remaining = limit - - def _get_domain_iterator(**kwargs): - if domain_index >= len(domains): - return None - domain = domains[domain_index] - return loop_over_domain(domain, query_factory, limit=remaining, **kwargs) - - current_iterator = _get_domain_iterator(**kwargs) - - while current_iterator: - for hit in 
current_iterator: - yield hit - if remaining: - remaining -= 1 - if remaining == 0: - return - domain_index += 1 - if domain_index >= len(domains): - return - next_args = query_factory.get_next_query_args(kwargs, last_hit=None) - current_iterator = _get_domain_iterator(**next_args) + iterators = [] + if last_domain: + remaining_domains = dropwhile(lambda d: d != last_domain, domains) + in_progress_domain = next(remaining_domains) + iterators.append(loop_over_domain(in_progress_domain, query_factory, limit=limit, **kwargs)) + else: + remaining_domains = domains + + fresh_domain_query_args = query_factory.get_next_query_args(kwargs, last_hit=None) + for domain in remaining_domains: + iterators.append(loop_over_domain(domain, query_factory, limit=limit, **fresh_domain_query_args)) + + yield from islice(chain.from_iterable(iterators), limit) def loop_over_domain(domain, query_factory, limit=None, **kwargs): From 7082597e5df3c194645e0e02b9212fe11727e778 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Thu, 7 Nov 2024 14:21:32 -0500 Subject: [PATCH 18/23] rename domain looping functions --- corehq/apps/enterprise/iterators.py | 12 +++---- .../apps/enterprise/tests/test_iterators.py | 36 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 334d465d73f2..c7e989aa59d0 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -33,9 +33,9 @@ def __init__(self, account, start_date, end_date, last_domain, last_time, last_i def execute(self, limit=None): domains = self.account.get_domains() - it = loop_over_domains( - domains, + it = run_query_over_domains( MobileFormSubmissionsQueryFactory(), + domains, limit=limit, last_domain=self.last_domain, start_date=self.start_date, @@ -119,23 +119,23 @@ def resolve_app_id_to_name(self, domain, app_id): return self.domain_lookup_tables[domain].get(app_id, None) -def loop_over_domains(domains, 
query_factory, limit=None, last_domain=None, **kwargs): +def run_query_over_domains(query_factory, domains, limit=None, last_domain=None, **kwargs): iterators = [] if last_domain: remaining_domains = dropwhile(lambda d: d != last_domain, domains) in_progress_domain = next(remaining_domains) - iterators.append(loop_over_domain(in_progress_domain, query_factory, limit=limit, **kwargs)) + iterators.append(run_query_over_domain(query_factory, in_progress_domain, limit=limit, **kwargs)) else: remaining_domains = domains fresh_domain_query_args = query_factory.get_next_query_args(kwargs, last_hit=None) for domain in remaining_domains: - iterators.append(loop_over_domain(domain, query_factory, limit=limit, **fresh_domain_query_args)) + iterators.append(run_query_over_domain(query_factory, domain, limit=limit, **fresh_domain_query_args)) yield from islice(chain.from_iterable(iterators), limit) -def loop_over_domain(domain, query_factory, limit=None, **kwargs): +def run_query_over_domain(query_factory, domain, limit=None, **kwargs): remaining = limit next_query_args = kwargs diff --git a/corehq/apps/enterprise/tests/test_iterators.py b/corehq/apps/enterprise/tests/test_iterators.py index d781eabfc88d..fdfe3a4a5f00 100644 --- a/corehq/apps/enterprise/tests/test_iterators.py +++ b/corehq/apps/enterprise/tests/test_iterators.py @@ -8,8 +8,8 @@ from corehq.form_processor.tests.utils import create_form_for_test from corehq.apps.enterprise.iterators import ( raise_after_max_elements, - loop_over_domains, - loop_over_domain, + run_query_over_domains, + run_query_over_domain, MobileFormSubmissionsQueryFactory ) @@ -48,9 +48,9 @@ def test_iterates_through_multiple_domains(self): ] form_adapter.bulk_index(forms, refresh=True) - it = loop_over_domains( - ['domain1', 'domain2', 'domain3'], + it = run_query_over_domains( MobileFormSubmissionsQueryFactory(), + ['domain1', 'domain2', 'domain3'], start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15) 
) @@ -67,9 +67,9 @@ def test_respects_limit_across_multiple_domains(self): ] form_adapter.bulk_index(forms, refresh=True) - it = loop_over_domains( - ['domain1', 'domain2'], + it = run_query_over_domains( MobileFormSubmissionsQueryFactory(), + ['domain1', 'domain2'], start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), limit=3 @@ -111,9 +111,9 @@ def test_iterates_through_all_forms_in_domain(self): form3 = self._create_form('test-domain', form_id='3', received_on=datetime(year=2024, month=7, day=4)) form_adapter.bulk_index([form1, form2, form3], refresh=True) - it = loop_over_domain( - 'test-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 'test-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), ) @@ -122,9 +122,9 @@ def test_iterates_through_all_forms_in_domain(self): self.assertEqual(form_ids, ['3', '2', '1']) def test_handles_empty_domain(self): - it = loop_over_domain( - 'empty-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 'empty-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), ) @@ -136,9 +136,9 @@ def test_includes_inclusive_boundaries(self): form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) form_adapter.bulk_index([form1, form2], refresh=True) - it = loop_over_domain( - 'test-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 'test-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2) ) @@ -151,9 +151,9 @@ def test_ignores_form_in_another_domain(self): form2 = self._create_form('not-test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) form_adapter.bulk_index([form1, form2], refresh=True) - it = loop_over_domain( - 'test-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 
'test-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=15), ) @@ -166,9 +166,9 @@ def test_sorts_by_date_then_id(self): form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) form_adapter.bulk_index([form2, form1], refresh=True) - it = loop_over_domain( - 'test-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 'test-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2), ) @@ -181,9 +181,9 @@ def test_does_not_return_forms_beyond_limit(self): form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) form_adapter.bulk_index([form1, form2], refresh=True) - it = loop_over_domain( - 'test-domain', + it = run_query_over_domain( MobileFormSubmissionsQueryFactory(), + 'test-domain', start_date=datetime(year=2024, month=7, day=1), end_date=datetime(year=2024, month=7, day=2), limit=1 From b85a5f54bdf76162b113f72c60ee6784c1f46be1 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Thu, 7 Nov 2024 15:51:27 -0500 Subject: [PATCH 19/23] Added authentication tests --- corehq/apps/enterprise/tests/api/__init__.py | 0 .../enterprise/tests/api/test_resources.py | 93 +++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 corehq/apps/enterprise/tests/api/__init__.py create mode 100644 corehq/apps/enterprise/tests/api/test_resources.py diff --git a/corehq/apps/enterprise/tests/api/__init__.py b/corehq/apps/enterprise/tests/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/corehq/apps/enterprise/tests/api/test_resources.py b/corehq/apps/enterprise/tests/api/test_resources.py new file mode 100644 index 000000000000..5d3dd7649561 --- /dev/null +++ b/corehq/apps/enterprise/tests/api/test_resources.py @@ -0,0 +1,93 @@ +from django.test import TestCase, RequestFactory +from datetime import datetime, timezone +from unittest.mock 
import patch +from corehq.apps.domain.models import Domain +from corehq.apps.enterprise.api.resources import EnterpriseODataAuthentication, ODataAuthentication +from corehq.apps.accounting.models import DefaultProductPlan, SoftwarePlanEdition +from corehq.apps.accounting.tests.utils import generator +from corehq.apps.users.models import WebUser +from django.http import Http404 +from tastypie.exceptions import ImmediateHttpResponse + + +class EnterpriseODataAuthenticationTests(TestCase): + def setUp(self): + super().setUp() + patcher = patch.object(ODataAuthentication, 'is_authenticated', return_value=True) + self.mock_is_authentication = patcher.start() + self.addCleanup(patcher.stop) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.user = cls._create_user('admin@test-domain.com') + cls.account = cls._create_enterprise_account_covering_domains(['test-domain']) + cls.account.enterprise_admin_emails = [cls.user.username] + cls.account.save() + + def test_successful_authentication(self): + request = self._create_request(self.user, 'test-domain') + + auth = EnterpriseODataAuthentication() + self.assertTrue(auth.is_authenticated(request)) + + def test_parent_failure_returns_parent_results(self): + self.mock_is_authentication.return_value = False + + request = self._create_request(self.user, 'test-domain') + + auth = EnterpriseODataAuthentication() + self.assertFalse(auth.is_authenticated(request)) + + def test_raises_exception_when_billing_account_does_not_exist(self): + request = self._create_request(self.user, 'not-test-domain') + + auth = EnterpriseODataAuthentication() + with self.assertRaises(Http404): + auth.is_authenticated(request) + + def test_raises_exception_when_not_an_enterprise_admin(self): + self.account.enterprise_admin_emails = ['not-this-user@test-domain.com'] + self.account.save() + + request = self._create_request(self.user, 'test-domain') + + auth = EnterpriseODataAuthentication() + with self.assertRaises(ImmediateHttpResponse): 
+ auth.is_authenticated(request) + + @classmethod + def _create_enterprise_account_covering_domains(cls, domains): + billing_account = generator.billing_account( + 'test-admin@dimagi.com', + 'test-admin@dimagi.com', + is_customer_account=True + ) + + enterprise_plan = DefaultProductPlan.get_default_plan_version(SoftwarePlanEdition.ENTERPRISE) + + for domain in domains: + domain_obj = Domain(name=domain, is_active=True) + domain_obj.save() + cls.addClassCleanup(domain_obj.delete) + + generator.generate_domain_subscription( + billing_account, + domain_obj, + datetime.now(timezone.utc), + None, + plan_version=enterprise_plan, + is_active=True + ) + + return billing_account + + @classmethod + def _create_user(cls, username): + return WebUser(username=username) + + def _create_request(self, user, domain): + request = RequestFactory().get('/') + request.couch_user = user + request.domain = domain + return request From 2d9d74b00d69b94cea2f49602134ad0e9a3a17a9 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Thu, 7 Nov 2024 15:52:05 -0500 Subject: [PATCH 20/23] isort --- .../enterprise/tests/api/test_resources.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/corehq/apps/enterprise/tests/api/test_resources.py b/corehq/apps/enterprise/tests/api/test_resources.py index 5d3dd7649561..4ac50c3094a7 100644 --- a/corehq/apps/enterprise/tests/api/test_resources.py +++ b/corehq/apps/enterprise/tests/api/test_resources.py @@ -1,14 +1,23 @@ -from django.test import TestCase, RequestFactory from datetime import datetime, timezone from unittest.mock import patch -from corehq.apps.domain.models import Domain -from corehq.apps.enterprise.api.resources import EnterpriseODataAuthentication, ODataAuthentication -from corehq.apps.accounting.models import DefaultProductPlan, SoftwarePlanEdition -from corehq.apps.accounting.tests.utils import generator -from corehq.apps.users.models import WebUser + from django.http import Http404 +from django.test import 
RequestFactory, TestCase + from tastypie.exceptions import ImmediateHttpResponse +from corehq.apps.accounting.models import ( + DefaultProductPlan, + SoftwarePlanEdition, +) +from corehq.apps.accounting.tests.utils import generator +from corehq.apps.domain.models import Domain +from corehq.apps.enterprise.api.resources import ( + EnterpriseODataAuthentication, + ODataAuthentication, +) +from corehq.apps.users.models import WebUser + class EnterpriseODataAuthenticationTests(TestCase): def setUp(self): From dedb429db9ba51cafe061d0de5dba4d68e02fe37 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Tue, 12 Nov 2024 08:55:33 -0500 Subject: [PATCH 21/23] Additional clarifying comments/structures --- corehq/apps/enterprise/iterators.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index c7e989aa59d0..fcc773740707 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -21,6 +21,13 @@ class IterableEnterpriseFormQuery: ''' A class representing a query that returns its results as an iterator The intended use case is to support queries that cross pagination boundaries + - start_date: a date to start the date range. Can be None + - end_date: the inclusive date to finish the date range. Can be None + last_domain, last_time, and last_id are intended to represent the last result from a previous query + that should now be resumed. 
+ - last_domain: the domain to resume the query on + - last_time: the timestamp from the last result of the previous query + - last_id: the id from the last result of the previous query ''' def __init__(self, account, start_date, end_date, last_domain, last_time, last_id): MAX_DATE_RANGE_DAYS = 100 @@ -194,12 +201,17 @@ def get_query(self, domain, start_date, end_date, last_time=None, last_id=None, ] if last_time and last_id: + # The results are first sorted by 'inserted_at', so if the previous query wasn't + # the only form submitted for its timestamp, return the others with a greater doc_id + submitted_same_time_as_previous_result = filters.AND( + filters.term('inserted_at', es_format_datetime(last_time)), + filters.range_filter('doc_id', gt=last_id) + ), + submitted_prior_to_previous_result = filters.date_range('inserted_at', lt=last_time) + query = query.filter(filters.OR( - filters.AND( - filters.term('inserted_at', es_format_datetime(last_time)), - filters.range_filter('doc_id', gt=last_id) - ), - filters.date_range('inserted_at', lt=last_time) + submitted_same_time_as_previous_result, + submitted_prior_to_previous_result )) return query From d489a367e560dc13525987f650f8e21fa7b5e860 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Tue, 12 Nov 2024 09:24:03 -0500 Subject: [PATCH 22/23] Allow the iterable query to use a generic converter --- corehq/apps/enterprise/api/resources.py | 7 +-- corehq/apps/enterprise/iterators.py | 61 +++++++++++++++---------- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index 3b586ba3faab..8adfebbf5bff 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -21,7 +21,7 @@ from corehq.apps.api.resources.meta import get_hq_throttle from corehq.apps.api.keyset_paginator import KeysetPaginator from corehq.apps.enterprise.enterprise import EnterpriseReport -from corehq.apps.enterprise.iterators 
import IterableEnterpriseFormQuery +from corehq.apps.enterprise.iterators import IterableEnterpriseFormQuery, EnterpriseFormReportConverter from corehq.apps.enterprise.tasks import generate_enterprise_report, ReportTaskProgress @@ -379,8 +379,9 @@ def get_object_list(self, request): account = BillingAccount.get_account_by_domain(request.domain) - query_kwargs = IterableEnterpriseFormQuery.get_kwargs_from_map(request.GET) - return IterableEnterpriseFormQuery(account, start_date, end_date, **query_kwargs) + converter = EnterpriseFormReportConverter() + query_kwargs = converter.get_kwargs_from_map(request.GET) + return IterableEnterpriseFormQuery(account, converter, start_date, end_date, **query_kwargs) def dehydrate(self, bundle): bundle.data['form_id'] = bundle.obj['form_id'] diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index fcc773740707..0706ac55d2a3 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -21,6 +21,8 @@ class IterableEnterpriseFormQuery: ''' A class representing a query that returns its results as an iterator The intended use case is to support queries that cross pagination boundaries + - form_converter: an instance of a class that knows how to translate the iteration results and interpret + the previous progress arguments - start_date: a date to start the date range. Can be None - end_date: the inclusive date to finish the date range. 
Can be None last_domain, last_time, and last_id are intended to represent the last result from a previous query @@ -29,10 +31,11 @@ class IterableEnterpriseFormQuery: - last_time: the timestamp from the last result of the previous query - last_id: the id from the last result of the previous query ''' - def __init__(self, account, start_date, end_date, last_domain, last_time, last_id): + def __init__(self, account, form_converter, start_date, end_date, last_domain, last_time, last_id): MAX_DATE_RANGE_DAYS = 100 (self.start_date, self.end_date) = resolve_start_and_end_date(start_date, end_date, MAX_DATE_RANGE_DAYS) self.account = account + self.form_converter = form_converter self.last_domain = last_domain self.last_time = last_time self.last_id = last_id @@ -51,29 +54,10 @@ def execute(self, limit=None): last_id=self.last_id, ) - xform_converter = RawFormConverter() - return (xform_converter.convert(form) for form in it) + return (self.form_converter.convert(form) for form in it) - @classmethod - def get_kwargs_from_map(cls, map): - last_domain = map.get('domain', None) - last_time = map.get('inserted_at', None) - if last_time: - last_time = datetime.fromisoformat(last_time).astimezone(timezone.utc) - last_id = map.get('id', None) - return { - 'last_domain': last_domain, - 'last_time': last_time, - 'last_id': last_id - } - - @classmethod - def get_query_params(cls, fetched_object): - return { - 'domain': fetched_object['domain'], - 'inserted_at': fetched_object['inserted_at'], - 'id': fetched_object['form_id'] - } + def get_query_params(self, fetched_object): + return self.form_converter.get_query_params(fetched_object) def resolve_start_and_end_date(start_date, end_date, maximum_date_range): @@ -94,7 +78,7 @@ def resolve_start_and_end_date(start_date, end_date, maximum_date_range): return start_date, end_date -class RawFormConverter: +class EnterpriseFormReportConverter: def __init__(self): self.app_lookup = AppIdToNameResolver() @@ -113,6 +97,35 @@ def 
convert(self, form): 'domain': domain } + @classmethod + def get_query_params(cls, fetched_object): + ''' + Takes a fetched, converted object and returns the values from this object that will be necessary + to continue where this query left off. + ''' + return { + 'domain': fetched_object['domain'], + 'inserted_at': fetched_object['inserted_at'], + 'id': fetched_object['form_id'] + } + + @classmethod + def get_kwargs_from_map(cls, map): + ''' + Takes a map-like object from a continuation request (generally GET/POST) and extracts + the parameters necessary for initializing an IterableEnterpriseFormQuery. + ''' + last_domain = map.get('domain', None) + last_time = map.get('inserted_at', None) + if last_time: + last_time = datetime.fromisoformat(last_time).astimezone(timezone.utc) + last_id = map.get('id', None) + return { + 'last_domain': last_domain, + 'last_time': last_time, + 'last_id': last_id + } + class AppIdToNameResolver: def __init__(self): From c5d96fb8dd0dcd906b825f362c597ced01d39509 Mon Sep 17 00:00:00 2001 From: Matt Riley Date: Tue, 12 Nov 2024 10:37:12 -0500 Subject: [PATCH 23/23] Changed "test-domain" to "testing-domain" to try to isolate a testing failure --- corehq/apps/enterprise/tests/api/test_resources.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/corehq/apps/enterprise/tests/api/test_resources.py b/corehq/apps/enterprise/tests/api/test_resources.py index 4ac50c3094a7..999c9352f366 100644 --- a/corehq/apps/enterprise/tests/api/test_resources.py +++ b/corehq/apps/enterprise/tests/api/test_resources.py @@ -29,13 +29,13 @@ def setUp(self): @classmethod def setUpClass(cls): super().setUpClass() - cls.user = cls._create_user('admin@test-domain.com') - cls.account = cls._create_enterprise_account_covering_domains(['test-domain']) + cls.user = cls._create_user('admin@testing-domain.com') + cls.account = cls._create_enterprise_account_covering_domains(['testing-domain']) cls.account.enterprise_admin_emails =
[cls.user.username] cls.account.save() def test_successful_authentication(self): - request = self._create_request(self.user, 'test-domain') + request = self._create_request(self.user, 'testing-domain') auth = EnterpriseODataAuthentication() self.assertTrue(auth.is_authenticated(request)) @@ -43,23 +43,23 @@ def test_successful_authentication(self): def test_parent_failure_returns_parent_results(self): self.mock_is_authentication.return_value = False - request = self._create_request(self.user, 'test-domain') + request = self._create_request(self.user, 'testing-domain') auth = EnterpriseODataAuthentication() self.assertFalse(auth.is_authenticated(request)) def test_raises_exception_when_billing_account_does_not_exist(self): - request = self._create_request(self.user, 'not-test-domain') + request = self._create_request(self.user, 'not-testing-domain') auth = EnterpriseODataAuthentication() with self.assertRaises(Http404): auth.is_authenticated(request) def test_raises_exception_when_not_an_enterprise_admin(self): - self.account.enterprise_admin_emails = ['not-this-user@test-domain.com'] + self.account.enterprise_admin_emails = ['not-this-user@testing-domain.com'] self.account.save() - request = self._create_request(self.user, 'test-domain') + request = self._create_request(self.user, 'testing-domain') auth = EnterpriseODataAuthentication() with self.assertRaises(ImmediateHttpResponse):