diff --git a/corehq/apps/accounting/models.py b/corehq/apps/accounting/models.py index c93739521c22..b3b22e20fda6 100644 --- a/corehq/apps/accounting/models.py +++ b/corehq/apps/accounting/models.py @@ -529,7 +529,7 @@ def autopay_card(self): def get_domains(self): return list(Subscription.visible_objects.filter(account_id=self.id, is_active=True).values_list( - 'subscriber__domain', flat=True)) + 'subscriber__domain', flat=True).order_by('subscriber__domain')) def has_enterprise_admin(self, email): lower_emails = [e.lower() for e in self.enterprise_admin_emails] diff --git a/corehq/apps/api/keyset_paginator.py b/corehq/apps/api/keyset_paginator.py new file mode 100644 index 000000000000..be3310686cf7 --- /dev/null +++ b/corehq/apps/api/keyset_paginator.py @@ -0,0 +1,110 @@ +from itertools import islice +from django.http.request import QueryDict +from urllib.parse import urlencode +from tastypie.paginator import Paginator + + +class KeysetPaginator(Paginator): + ''' + An alternate paginator meant to support paginating by keyset rather than by index/offset. + `objects` is expected to represent a query object that exposes an `.execute(limit)` + method that returns an iterable, and a `get_query_params(object)` method to retrieve the parameters + for the next query + Because keyset pagination does not efficiently handle slicing or offset operations, + these methods have been intentionally disabled + ''' + def __init__(self, request_data, objects, + resource_uri=None, limit=None, max_limit=1000, collection_name='objects'): + super().__init__( + request_data, + objects, + resource_uri=resource_uri, + limit=limit, + max_limit=max_limit, + collection_name=collection_name + ) + + def get_offset(self): + raise NotImplementedError() + + def get_slice(self, limit, offset): + raise NotImplementedError() + + def get_count(self): + raise NotImplementedError() + + def get_previous(self, limit, offset): + raise NotImplementedError() + + def get_next(self, limit, **next_params): + return self._generate_uri(limit, **next_params) + + def _generate_uri(self, limit, **next_params): + if self.resource_uri is None: + return None + + if isinstance(self.request_data, QueryDict): + # Because QueryDict allows multiple values for the same key, we need to remove existing values + # prior to updating + request_params = self.request_data.copy() + if 'limit' in request_params: + del request_params['limit'] + for key in next_params: + if key in request_params: + del request_params[key] + + request_params.update({'limit': str(limit), **next_params}) + encoded_params = request_params.urlencode() + else: + request_params = {} + for k, v in self.request_data.items(): + if isinstance(v, str): + request_params[k] = v.encode('utf-8') + else: + request_params[k] = v + + request_params.update({'limit': limit, **next_params}) + encoded_params = urlencode(request_params) + + return '%s?%s' % ( + self.resource_uri, + encoded_params + ) + + def page(self): + """ + Generates all pertinent data about the requested page. + """ + limit = self.get_limit() + padded_limit = limit + 1 if limit else limit + # Fetch 1 more record than requested to allow us to determine if further queries will be needed + it = iter(self.objects.execute(limit=padded_limit)) + objects = list(islice(it, limit)) + + try: + next(it) + has_more = True + except StopIteration: + has_more = False + + meta = { + 'limit': limit, + } + + if limit and has_more: + last_fetched = objects[-1] + next_page_params = self.objects.get_query_params(last_fetched) + meta['next'] = self.get_next(limit, **next_page_params) + + return { + self.collection_name: objects, + 'meta': meta, + } + + +class PageableQueryInterface: + def execute(limit=None): + ''' + Should return an iterable that exposes a `.get_query_params()` method + ''' + raise NotImplementedError() diff --git a/corehq/apps/api/tests/keyset_paginator_tests.py b/corehq/apps/api/tests/keyset_paginator_tests.py new file mode 100644 index 000000000000..22b35a6ca155 --- /dev/null +++ b/corehq/apps/api/tests/keyset_paginator_tests.py @@ -0,0 +1,77 @@ +from django.test import SimpleTestCase +from django.http import QueryDict +from corehq.apps.api.keyset_paginator import KeysetPaginator + + +class SequenceQuery: + def __init__(self, seq): + self.seq = seq + + def execute(self, limit=None): + return self.seq + + @classmethod + def get_query_params(cls, form): + return {'next': form} + + +class KeysetPaginatorTests(SimpleTestCase): + def test_page_fetches_all_results_below_limit(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects, limit=10) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2, 3, 4]) + self.assertEqual(page['meta'], {'limit': 10}) + + def test_page_includes_next_information_when_more_results_are_available(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects, resource_uri='http://test.com/', limit=3) + page = paginator.page() + self.assertEqual(page['objects'], [0, 1, 2]) + self.assertEqual(page['meta'], {'limit': 3, 'next': 'http://test.com/?limit=3&next=2'}) + + def test_does_not_include_duplicate_limits(self): + request_data = QueryDict(mutable=True) + request_data['limit'] = 3 + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&next=2') + + def test_supports_dict_request_data(self): + request_data = { + 'limit': 3, + 'some_param': 'yes' + } + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(request_data, objects, resource_uri='http://test.com/') + page = paginator.page() + self.assertEqual(page['meta']['next'], 'http://test.com/?limit=3&some_param=yes&next=2') + + def test_get_offset_not_implemented(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_offset() + + def test_get_slice_not_implemented(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_slice(limit=10, offset=20) + + def test_get_count_not_implemented(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_count() + + def test_get_previous_not_implemented(self): + objects = SequenceQuery(range(5)) + paginator = KeysetPaginator(QueryDict(), objects) + + with self.assertRaises(NotImplementedError): + paginator.get_previous(limit=10, offset=20) diff --git a/corehq/apps/enterprise/api/resources.py b/corehq/apps/enterprise/api/resources.py index 09d5b85d120d..8adfebbf5bff 100644 --- a/corehq/apps/enterprise/api/resources.py +++ b/corehq/apps/enterprise/api/resources.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from urllib.parse import urljoin from django.http import HttpResponse, HttpResponseForbidden, HttpResponseNotFound @@ -19,9 +19,9 @@ from corehq.apps.api.resources import HqBaseResource from corehq.apps.api.resources.auth import ODataAuthentication from corehq.apps.api.resources.meta import get_hq_throttle -from corehq.apps.enterprise.enterprise import ( - EnterpriseReport, -) +from corehq.apps.api.keyset_paginator import KeysetPaginator +from corehq.apps.enterprise.enterprise import EnterpriseReport +from corehq.apps.enterprise.iterators import IterableEnterpriseFormQuery, EnterpriseFormReportConverter from corehq.apps.enterprise.tasks import generate_enterprise_report, ReportTaskProgress @@ -60,7 +60,8 @@ def alter_list_data_to_serialize(self, request, data): result['@odata.context'] = request.build_absolute_uri(path) meta = result['meta'] - result['@odata.count'] = meta['total_count'] + if 'total_count' in meta: + result['@odata.count'] = meta['total_count'] if 'next' in meta and meta['next']: result['@odata.nextLink'] = request.build_absolute_uri(meta['next']) @@ -139,7 +140,10 @@ def convert_datetime(cls, datetime_string): if not datetime_string: return None - time = datetime.strptime(datetime_string, EnterpriseReport.DATE_ROW_FORMAT) + if isinstance(datetime_string, str): + time = datetime.strptime(datetime_string, EnterpriseReport.DATE_ROW_FORMAT) + else: + time = datetime_string time = time.astimezone(tz.gettz('UTC')) return time.isoformat() @@ -351,6 +355,7 @@ def get_primary_keys(self): class FormSubmissionResource(ODataEnterpriseReportResource): class Meta(ODataEnterpriseReportResource.Meta): + paginator_class = KeysetPaginator limit = 10000 max_limit = 20000 @@ -358,31 +363,33 @@ class Meta(ODataEnterpriseReportResource.Meta): form_name = fields.CharField() submitted = fields.DateTimeField() app_name = fields.CharField() - mobile_user = fields.CharField() + username = fields.CharField() domain = fields.CharField() REPORT_SLUG = EnterpriseReport.FORM_SUBMISSIONS - def get_report_task(self, request): - enddate = datetime.strptime(request.GET['enddate'], '%Y-%m-%d') if 'enddate' in request.GET else None - startdate = datetime.strptime(request.GET['startdate'], '%Y-%m-%d') if 'startdate' in request.GET else None + def get_object_list(self, request): + start_date = request.GET.get('startdate', None) + if start_date: + start_date = datetime.fromisoformat(start_date).astimezone(timezone.utc) + + end_date = request.GET.get('enddate', None) + if end_date: + end_date = datetime.fromisoformat(end_date).astimezone(timezone.utc) + account = BillingAccount.get_account_by_domain(request.domain) - return generate_enterprise_report.s( - self.REPORT_SLUG, - account.id, - request.couch_user.username, - start_date=startdate, - end_date=enddate, - include_form_id=True, - ) + + converter = EnterpriseFormReportConverter() + query_kwargs = converter.get_kwargs_from_map(request.GET) + return IterableEnterpriseFormQuery(account, converter, start_date, end_date, **query_kwargs) def dehydrate(self, bundle): - bundle.data['form_id'] = bundle.obj[0] - bundle.data['form_name'] = bundle.obj[1] - bundle.data['submitted'] = self.convert_datetime(bundle.obj[2]) - bundle.data['app_name'] = bundle.obj[3] - bundle.data['mobile_user'] = bundle.obj[4] - bundle.data['domain'] = bundle.obj[6] + bundle.data['form_id'] = bundle.obj['form_id'] + bundle.data['form_name'] = bundle.obj['form_name'] + bundle.data['submitted'] = self.convert_datetime(bundle.obj['submitted']) + bundle.data['app_name'] = bundle.obj['app_name'] + bundle.data['username'] = bundle.obj['username'] + bundle.data['domain'] = bundle.obj['domain'] return bundle diff --git a/corehq/apps/enterprise/iterators.py b/corehq/apps/enterprise/iterators.py index 1fa96c54c132..0706ac55d2a3 100644 --- a/corehq/apps/enterprise/iterators.py +++ b/corehq/apps/enterprise/iterators.py @@ -1,3 +1,13 @@ +from itertools import dropwhile, chain, islice +from datetime import datetime, timedelta, timezone +from django.utils.translation import gettext as _ +from corehq.apps.es import filters +from corehq.apps.es.utils import es_format_datetime +from corehq.apps.es.forms import FormES +from corehq.apps.enterprise.exceptions import TooMuchRequestedDataError +from corehq.apps.app_manager.dbaccessors import get_brief_apps_in_domain + + def raise_after_max_elements(it, max_elements, exception=None): for total_yielded, ele in enumerate(it): if total_yielded >= max_elements: @@ -5,3 +15,228 @@ def raise_after_max_elements(it, max_elements, exception=None): raise exception yield ele + + +class IterableEnterpriseFormQuery: + ''' + A class representing a query that returns its results as an iterator + The intended use case is to support queries that cross pagination boundaries + - form_converter: an instance of a class that knows how to translate the iteration results and interpret + the previous progress arguments + - start_date: a date to start the date range. Can be None + - end_date: the inclusive date to finish the date range. Can be None + last_domain, last_time, and last_id are intended to represent the last result from a previous query + that should now be resumed. + - last_domain: the domain to resume the query on + - last_time: the timestamp from the last result of the previous query + - last_id: the id from the last result of the previous query + ''' + def __init__(self, account, form_converter, start_date, end_date, last_domain, last_time, last_id): + MAX_DATE_RANGE_DAYS = 100 + (self.start_date, self.end_date) = resolve_start_and_end_date(start_date, end_date, MAX_DATE_RANGE_DAYS) + self.account = account + self.form_converter = form_converter + self.last_domain = last_domain + self.last_time = last_time + self.last_id = last_id + + def execute(self, limit=None): + domains = self.account.get_domains() + + it = run_query_over_domains( + MobileFormSubmissionsQueryFactory(), + domains, + limit=limit, + last_domain=self.last_domain, + start_date=self.start_date, + end_date=self.end_date, + last_time=self.last_time, + last_id=self.last_id, + ) + + return (self.form_converter.convert(form) for form in it) + + def get_query_params(self, fetched_object): + return self.form_converter.get_query_params(fetched_object) + + +def resolve_start_and_end_date(start_date, end_date, maximum_date_range): + ''' + Provide start and end date values if not supplied. + ''' + if not end_date: + end_date = datetime.now(timezone.utc) + + if not start_date: + start_date = end_date - timedelta(days=30) + + if end_date - start_date > timedelta(days=maximum_date_range): + raise TooMuchRequestedDataError( + _('Date ranges with more than {} days are not supported').format(maximum_date_range) + ) + + return start_date, end_date + + +class EnterpriseFormReportConverter: + def __init__(self): + self.app_lookup = AppIdToNameResolver() + + def convert(self, form): + domain = form['domain'] + submitted_date = datetime.strptime(form['received_on'][:19], '%Y-%m-%dT%H:%M:%S') + inserted_at = datetime.strptime(form['inserted_at'][:19], '%Y-%m-%dT%H:%M:%S') + + return { + 'form_id': form['form']['meta']['instanceID'], + 'form_name': form['form']['@name'] or _('Unnamed'), + 'submitted': submitted_date, + 'inserted_at': inserted_at, + 'app_name': self.app_lookup.resolve_app_id_to_name(domain, form['app_id']) or _('App not found'), + 'username': form['form']['meta']['username'], + 'domain': domain + } + + @classmethod + def get_query_paraams(cls, fetched_object): + ''' + Takes a fetched, converted object and returns the values from this object that will be necessary + to continue where this query left off. + ''' + return { + 'domain': fetched_object['domain'], + 'inserted_at': fetched_object['inserted_at'], + 'id': fetched_object['form_id'] + } + + @classmethod + def get_kwargs_from_map(cls, map): + ''' + Takes a map-like object from a continuation request (generally GET/POST) and extracts + the parameters necessary for initializing an IterableEnterpriseFormQuery. + ''' + last_domain = map.get('domain', None) + last_time = map.get('inserted_at', None) + if last_time: + last_time = datetime.fromisoformat(last_time).astimezone(timezone.utc) + last_id = map.get('id', None) + return { + 'last_domain': last_domain, + 'last_time': last_time, + 'last_id': last_id + } + + +class AppIdToNameResolver: + def __init__(self): + self.domain_lookup_tables = {} + + def resolve_app_id_to_name(self, domain, app_id): + if 'domain' not in self.domain_lookup_tables: + domain_apps = get_brief_apps_in_domain(domain) + self.domain_lookup_tables[domain] = {a.id: a.name for a in domain_apps} + + return self.domain_lookup_tables[domain].get(app_id, None) + + +def run_query_over_domains(query_factory, domains, limit=None, last_domain=None, **kwargs): + iterators = [] + if last_domain: + remaining_domains = dropwhile(lambda d: d != last_domain, domains) + in_progress_domain = next(remaining_domains) + iterators.append(run_query_over_domain(query_factory, in_progress_domain, limit=limit, **kwargs)) + else: + remaining_domains = domains + + fresh_domain_query_args = query_factory.get_next_query_args(kwargs, last_hit=None) + for domain in remaining_domains: + iterators.append(run_query_over_domain(query_factory, domain, limit=limit, **fresh_domain_query_args)) + + yield from islice(chain.from_iterable(iterators), limit) + + +def run_query_over_domain(query_factory, domain, limit=None, **kwargs): + remaining = limit + + next_query_args = kwargs + + while True: + query = query_factory.get_query(domain, limit=limit, **next_query_args) + results = query.run() + for hit in results.hits: + last_hit = hit + yield last_hit + + num_fetched = len(results.hits) + + if num_fetched >= results.total or (remaining and num_fetched >= remaining): + break + else: + if remaining: + remaining -= num_fetched + + next_query_args = query_factory.get_next_query_args(next_query_args, last_hit) + + +class ReportQueryFactoryInterface: + ''' + A generic interface for any report queries. + ''' + def get_query(self, **kwargs): + ''' + Returns an ElasticSearch query, configured by **kwargs + ''' + raise NotImplementedError() + + def get_next_query_args(self, previous_args, last_hit): + ''' + Modifies the `previous_args` dictionary with information from `last_hit` to create + a new set of kwargs suitable to pass back to `get_query` to retrieve results beyond `last_hit` + ''' + raise NotImplementedError() + + +class MobileFormSubmissionsQueryFactory(ReportQueryFactoryInterface): + def get_query(self, domain, start_date, end_date, last_time=None, last_id=None, limit=None): + query = ( + FormES() + .domain(domain) + .user_type('mobile') + .submitted(gte=start_date, lte=end_date) + ) + + if limit: + query = query.size(limit) + + query.es_query['sort'] = [ + {'inserted_at': {'order': 'desc'}}, + {'doc_id': 'asc'} + ] + + if last_time and last_id: + # The results are first sorted by 'inserted_at', so if the previous query wasn't + # the only form submitted for its timestamp, return the others with a greater doc_id + submitted_same_time_as_previous_result = filters.AND( + filters.term('inserted_at', es_format_datetime(last_time)), + filters.range_filter('doc_id', gt=last_id) + ), + submitted_prior_to_previous_result = filters.date_range('inserted_at', lt=last_time) + + query = query.filter(filters.OR( + submitted_same_time_as_previous_result, + submitted_prior_to_previous_result + )) + + return query + + def get_next_query_args(self, previous_args, last_hit): + if last_hit: + return previous_args | { + 'last_time': last_hit['inserted_at'], + 'last_id': last_hit['doc_id'] + } + else: + new_args = previous_args.copy() + new_args.pop('last_time', None) + new_args.pop('last_id', None) + return new_args diff --git a/corehq/apps/enterprise/tests/api/__init__.py b/corehq/apps/enterprise/tests/api/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/corehq/apps/enterprise/tests/api/test_resources.py b/corehq/apps/enterprise/tests/api/test_resources.py new file mode 100644 index 000000000000..999c9352f366 --- /dev/null +++ b/corehq/apps/enterprise/tests/api/test_resources.py @@ -0,0 +1,102 @@ +from datetime import datetime, timezone +from unittest.mock import patch + +from django.http import Http404 +from django.test import RequestFactory, TestCase + +from tastypie.exceptions import ImmediateHttpResponse + +from corehq.apps.accounting.models import ( + DefaultProductPlan, + SoftwarePlanEdition, +) +from corehq.apps.accounting.tests.utils import generator +from corehq.apps.domain.models import Domain +from corehq.apps.enterprise.api.resources import ( + EnterpriseODataAuthentication, + ODataAuthentication, +) +from corehq.apps.users.models import WebUser + + +class EnterpriseODataAuthenticationTests(TestCase): + def setUp(self): + super().setUp() + patcher = patch.object(ODataAuthentication, 'is_authenticated', return_value=True) + self.mock_is_authentication = patcher.start() + self.addCleanup(patcher.stop) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.user = cls._create_user('admin@testing-domain.com') + cls.account = cls._create_enterprise_account_covering_domains(['testing-domain']) + cls.account.enterprise_admin_emails = [cls.user.username] + cls.account.save() + + def test_successful_authentication(self): + request = self._create_request(self.user, 'testing-domain') + + auth = EnterpriseODataAuthentication() + self.assertTrue(auth.is_authenticated(request)) + + def test_parent_failure_returns_parent_results(self): + self.mock_is_authentication.return_value = False + + request = self._create_request(self.user, 'testing-domain') + + auth = EnterpriseODataAuthentication() + self.assertFalse(auth.is_authenticated(request)) + + def test_raises_exception_when_billing_account_does_not_exist(self): + request = self._create_request(self.user, 'not-testing-domain') + + auth = EnterpriseODataAuthentication() + with self.assertRaises(Http404): + auth.is_authenticated(request) + + def test_raises_exception_when_not_an_enterprise_admin(self): + self.account.enterprise_admin_emails = ['not-this-user@testing-domain.com'] + self.account.save() + + request = self._create_request(self.user, 'testing-domain') + + auth = EnterpriseODataAuthentication() + with self.assertRaises(ImmediateHttpResponse): + auth.is_authenticated(request) + + @classmethod + def _create_enterprise_account_covering_domains(cls, domains): + billing_account = generator.billing_account( + 'test-admin@dimagi.com', + 'test-admin@dimagi.com', + is_customer_account=True + ) + + enterprise_plan = DefaultProductPlan.get_default_plan_version(SoftwarePlanEdition.ENTERPRISE) + + for domain in domains: + domain_obj = Domain(name=domain, is_active=True) + domain_obj.save() + cls.addClassCleanup(domain_obj.delete) + + generator.generate_domain_subscription( + billing_account, + domain_obj, + datetime.now(timezone.utc), + None, + plan_version=enterprise_plan, + is_active=True + ) + + return billing_account + + @classmethod + def _create_user(cls, username): + return WebUser(username=username) + + def _create_request(self, user, domain): + request = RequestFactory().get('/') + request.couch_user = user + request.domain = domain + return request diff --git a/corehq/apps/enterprise/tests/test_apis.py b/corehq/apps/enterprise/tests/test_apis.py new file mode 100644 index 000000000000..969f98ef0287 --- /dev/null +++ b/corehq/apps/enterprise/tests/test_apis.py @@ -0,0 +1,139 @@ +import base64 +import json +from datetime import datetime, timezone + +from django.test import RequestFactory, TestCase + +from corehq.apps.accounting.models import ( + DefaultProductPlan, + SoftwarePlanEdition, +) +from corehq.apps.accounting.tests import generator +from corehq.apps.domain.models import Domain +from corehq.apps.enterprise.api.resources import FormSubmissionResource +from corehq.apps.es.forms import form_adapter +from corehq.apps.es.tests.utils import es_test +from corehq.apps.users.models import ( + CommCareUser, + HQApiKey, + HqPermissions, + WebUser, +) +from corehq.apps.users.models_role import UserRole +from corehq.form_processor.tests.utils import create_form_for_test + + +@es_test(requires=[form_adapter]) +class FormSubmissionResourceTests(TestCase): + def test_resource_is_accessible(self): + enterprise_account = self._create_enterprise_account_covering_domains(['test-domain-1', 'test-domain-2']) + + enterprise_admin = self._create_enterprise_admin( + 'test-admin@somedomain.com', + 'test-domain-1', + enterprise_account + ) + django_user = enterprise_admin.get_django_user() + api_key = HQApiKey.objects.create(user=django_user, key='1234', name='TestKey') + + mobile_user_1 = self._create_mobile_user('mobile1@test.com', 'test-domain-1') + mobile_user_2 = self._create_mobile_user('mobile2@test.com', 'test-domain-2') + + form1 = self._create_form(mobile_user_1, form_id='1234', received_on=datetime(year=2004, month=10, day=11)) + form2 = self._create_form(mobile_user_2, form_id='2345', received_on=datetime(year=2004, month=10, day=12)) + form_adapter.bulk_index([form1, form2], refresh=True) + + request = self._create_api_request( + enterprise_admin, + 'test-domain-1', + api_key, + {'startdate': '2004-10-10', 'enddate': '2004-11-10'} + ) + + resource = FormSubmissionResource() + response = resource.dispatch_list(request) + result = json.loads(response.content) + resulting_forms = result['value'] + self.assertEqual([form['form_id'] for form in resulting_forms], ['1234', '2345']) + + def _create_form(self, user, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + '@name': 'TestForm', + 'meta': { + 'userID': user._id, + 'username': user.name, + 'instanceID': form_id, + }, + } + return create_form_for_test( + user.domain, + user_id=user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on, + ) + + def _create_enterprise_account_covering_domains(self, domains): + billing_account = generator.billing_account( + 'test-admin@dimagi.com', + 'test-admin@dimagi.com', + is_customer_account=True + ) + + # Enterprise is needed to grant API and OData permissions + enterprise_plan = DefaultProductPlan.get_default_plan_version(SoftwarePlanEdition.ENTERPRISE) + + for domain in domains: + domain_obj = Domain(name=domain, is_active=True) + domain_obj.save() + self.addCleanup(domain_obj.delete) + + generator.generate_domain_subscription( + billing_account, + domain_obj, + datetime.now(timezone.utc), + None, + plan_version=enterprise_plan, + is_active=True + ) + + return billing_account + + def _create_enterprise_admin(self, email, domain, enterprise_account): + user = WebUser.create( + domain, email, 'test123', None, None, email) + + # Users need to have permission to view OData reports on the specified domain to access these APIs + permissions = HqPermissions(view_reports=True) + role = UserRole.create(domain, email, permissions=permissions) + membership = user.get_domain_membership(domain) + membership.role_id = role.couch_id + user.save() + + enterprise_account.enterprise_admin_emails.append(email) + enterprise_account.save() + + self.addCleanup(user.delete, None, deleted_by=None) + + return user + + def _create_mobile_user(self, username, domain): + user = CommCareUser.create(domain, username, 'test123', None, None) + self.addCleanup(user.delete, None, deleted_by=None) + return user + + def _create_api_request(self, user, domain, api_key, data=None): + auth_string = f'{user.username}:{api_key.key}' + factory = RequestFactory() + encoded_auth = base64.b64encode(auth_string.encode()).decode() + request = factory.get( + '/', + data, + HTTP_AUTHORIZATION=f'basic {encoded_auth}' + ) + request.couch_user = user + request.user = user.get_django_user() + request.domain = domain + + return request diff --git a/corehq/apps/enterprise/tests/test_iterators.py b/corehq/apps/enterprise/tests/test_iterators.py index 3ffc80c12a7c..fdfe3a4a5f00 100644 --- a/corehq/apps/enterprise/tests/test_iterators.py +++ b/corehq/apps/enterprise/tests/test_iterators.py @@ -1,6 +1,17 @@ -from django.test import SimpleTestCase +from unittest.mock import patch +from django.test import SimpleTestCase, TestCase +from datetime import datetime -from corehq.apps.enterprise.iterators import raise_after_max_elements +from corehq.apps.es.forms import form_adapter, ElasticForm +from corehq.apps.es.tests.utils import es_test +from corehq.apps.users.models import CommCareUser +from corehq.form_processor.tests.utils import create_form_for_test +from corehq.apps.enterprise.iterators import ( + raise_after_max_elements, + run_query_over_domains, + run_query_over_domain, + MobileFormSubmissionsQueryFactory +) class TestRaiseAfterMaxElements(SimpleTestCase): @@ -17,3 +28,198 @@ def test_iterating_beyond_max_items_will_raise_provided_exception(self): def test_can_iterate_through_all_elements_with_no_exception(self): it = raise_after_max_elements([1, 2, 3], 3) self.assertEqual(list(it), [1, 2, 3]) + + +@es_test(requires=[form_adapter]) +class TestLoopOverDomains(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + inserted_at_mapping_patcher = map_received_on_to_inserted_at() + inserted_at_mapping_patcher.start() + self.addCleanup(inserted_at_mapping_patcher.stop) + + def test_iterates_through_multiple_domains(self): + forms = [ + self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)), + self._create_form('domain2', form_id='2', received_on=datetime(year=2024, month=7, day=2)), + self._create_form('domain3', form_id='3', received_on=datetime(year=2024, month=7, day=3)), + ] + form_adapter.bulk_index(forms, refresh=True) + + it = run_query_over_domains( + MobileFormSubmissionsQueryFactory(), + ['domain1', 'domain2', 'domain3'], + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2', '3']) + + def test_respects_limit_across_multiple_domains(self): + forms = [ + self._create_form('domain1', form_id='1', received_on=datetime(year=2024, month=7, day=1)), + self._create_form('domain1', form_id='2', received_on=datetime(year=2024, month=7, day=2)), + self._create_form('domain2', form_id='3', received_on=datetime(year=2024, month=7, day=3)), + self._create_form('domain2', form_id='4', received_on=datetime(year=2024, month=7, day=4)), + ] + form_adapter.bulk_index(forms, refresh=True) + + it = run_query_over_domains( + MobileFormSubmissionsQueryFactory(), + ['domain1', 'domain2'], + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + limit=3 + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['2', '1', '4']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on + ) + + +@es_test(requires=[form_adapter]) +class TestLoopOverDomain(TestCase): + def setUp(self): + self.user = CommCareUser.create('test-domain', 'test-user', 'password', None, None) + self.addCleanup(self.user.delete, None, None) + + inserted_at_mapping_patcher = map_received_on_to_inserted_at() + inserted_at_mapping_patcher.start() + self.addCleanup(inserted_at_mapping_patcher.stop) + + def test_iterates_through_all_forms_in_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=3)) + form3 = self._create_form('test-domain', form_id='3', received_on=datetime(year=2024, month=7, day=4)) + form_adapter.bulk_index([form1, form2, form3], refresh=True) + + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['3', '2', '1']) + + def test_handles_empty_domain(self): + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'empty-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + self.assertEqual(list(it), []) + + def test_includes_inclusive_boundaries(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2) + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['2', '1']) + + def test_ignores_form_in_another_domain(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=2)) + form2 = self._create_form('not-test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=2)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=15), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1']) + + def test_sorts_by_date_then_id(self): + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form_adapter.bulk_index([form2, form1], refresh=True) + + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2), + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1', '2']) + + def test_does_not_return_forms_beyond_limit(self): + form1 = self._create_form('test-domain', form_id='1', received_on=datetime(year=2024, month=7, day=1)) + form2 = self._create_form('test-domain', form_id='2', received_on=datetime(year=2024, month=7, day=1)) + form_adapter.bulk_index([form1, form2], refresh=True) + + it = run_query_over_domain( + MobileFormSubmissionsQueryFactory(), + 'test-domain', + start_date=datetime(year=2024, month=7, day=1), + end_date=datetime(year=2024, month=7, day=2), + limit=1 + ) + + form_ids = [form['_id'] for form in list(it)] + self.assertEqual(form_ids, ['1']) + + def _create_form(self, domain, form_id=None, received_on=None): + form_data = { + '#type': 'fake-type', + 'meta': { + 'userID': self.user._id, + 'instanceID': form_id, + }, + } + return create_form_for_test( + domain, + user_id=self.user._id, + form_data=form_data, + form_id=form_id, + received_on=received_on, + ) + + +def map_received_on_to_inserted_at(): + ''' + A patcher to use the date value found in a form's 'received_on' field for the 'inserted_at' value. + Without this patch, 'inserted_at' will be `utcnow()`, which would require knowing the order and number of times + that `utcnow()` would be called to manipulate dates + ''' + original = ElasticForm._from_dict + + def from_dict(cls, xform_dict): + id, result = original(cls, xform_dict) + result['inserted_at'] = result['received_on'] + return (id, result) + + return patch.object(ElasticForm, '_from_dict', new=from_dict) diff --git a/corehq/apps/es/forms.py b/corehq/apps/es/forms.py index 60aed8ec325e..6354a439b0dc 100644 --- a/corehq/apps/es/forms.py +++ b/corehq/apps/es/forms.py @@ -43,6 +43,7 @@ def builtin_filters(self): form_ids, xmlns, app, + inserted, submitted, completed, user_id, @@ -186,6 +187,10 @@ def submitted(gt=None, gte=None, lt=None, lte=None): return filters.date_range('received_on', gt, gte, lt, lte) +def inserted(gt=None, gte=None, lt=None, lte=None): + return filters.date_range('inserted_at', gt, gte, lt, lte) + + def completed(gt=None, gte=None, lt=None, lte=None): return filters.date_range('form.meta.timeEnd', gt, gte, lt, lte)