Skip to content

Commit

Permalink
Merge pull request #35265 from dimagi/es/module-badges
Browse files Browse the repository at this point in the history
Module Badges Fixture
  • Loading branch information
esoergel authored Nov 6, 2024
2 parents 1dec87e + cf08c83 commit c730bf3
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 9 deletions.
13 changes: 7 additions & 6 deletions corehq/apps/app_manager/suite_xml/post_process/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,13 @@ def __call__(self, fn):


INSTANCE_KWARGS_BY_ID = {
'groups': dict(id='groups', src='jr://fixture/user-groups'),
'reports': dict(id='reports', src='jr://fixture/commcare:reports'),
'ledgerdb': dict(id='ledgerdb', src='jr://instance/ledgerdb'),
'casedb': dict(id='casedb', src='jr://instance/casedb'),
'commcaresession': dict(id='commcaresession', src='jr://instance/session'),
'registry': dict(id='registry', src='jr://instance/remote/registry'),
'groups': {'id': 'groups', 'src': 'jr://fixture/user-groups'},
'reports': {'id': 'reports', 'src': 'jr://fixture/commcare:reports'},
'ledgerdb': {'id': 'ledgerdb', 'src': 'jr://instance/ledgerdb'},
'casedb': {'id': 'casedb', 'src': 'jr://instance/casedb'},
'commcaresession': {'id': 'commcaresession', 'src': 'jr://instance/session'},
'registry': {'id': 'registry', 'src': 'jr://instance/remote/registry'},
'case-search-fixture': {'id': 'case-search-fixture', 'src': 'jr://fixture/case-search-fixture'},
}


Expand Down
4 changes: 4 additions & 0 deletions corehq/apps/case_search/filter_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def build_filter_from_xpath(xpath, *, domain=None, context=None):
raise CaseFilterError(error_message.format(bad_part, ", ".join(ALL_OPERATORS)), bad_part)
raise CaseFilterError(_("Malformed search query"), None)
except RuntimeError as e:
# eulxml doesn't appear to clean up after this type of failure
# properly, so throw in an extra 'parse' to reset it
parse_xpath("thisisdumb")

# eulxml passes us string errors from YACC
lex_token_error = re.search(r"LexToken\((\w+),\w?'(.+)'", str(e))
if lex_token_error:
Expand Down
76 changes: 76 additions & 0 deletions corehq/apps/case_search/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from lxml.builder import E

from casexml.apps.phone.fixtures import FixtureProvider

from corehq import extensions
from corehq.apps.case_search.exceptions import CaseFilterError
from corehq.apps.case_search.filter_dsl import build_filter_from_xpath
from corehq.apps.es.case_search import CaseSearchES
from corehq.messaging.templating import (
MessagingTemplateRenderer,
NestedDictTemplateParam,
)
from corehq.toggles import MODULE_BADGES

from .models import CSQLFixtureExpression


def _get_user_template_info(restore_user):
return {
"username": restore_user.username,
"uuid": restore_user.user_id,
"user_data": restore_user.user_session_data,
"location_ids": restore_user.get_location_ids(restore_user.domain),
}


def _get_template_renderer(restore_user):
renderer = MessagingTemplateRenderer()
renderer.set_context_param('user', NestedDictTemplateParam(_get_user_template_info(restore_user)))
for name, param in custom_csql_fixture_context(restore_user.domain, restore_user):
renderer.set_context_param(name, param)
return renderer


@extensions.extension_point
def custom_csql_fixture_context(domain, restore_user):
'''Register custom template params to be available in CSQL templates'''


def _run_query(domain, csql):
try:
filter_ = build_filter_from_xpath(csql, domain=domain)
except CaseFilterError:
return "ERROR"
return str(CaseSearchES()
.domain(domain)
.filter(filter_)
.count())


def _get_indicator_nodes(restore_state, indicators):
with restore_state.timing_context('_get_template_renderer'):
renderer = _get_template_renderer(restore_state.restore_user)
for name, csql_template in indicators:
with restore_state.timing_context(name):
value = _run_query(restore_state.domain, renderer.render(csql_template))
yield E.value(value, name=name)


class CaseSearchFixtureProvider(FixtureProvider):
id = 'case-search-fixture'

def __call__(self, restore_state):
if not MODULE_BADGES.enabled(restore_state.domain):
return
indicators = _get_indicators(restore_state.domain)
if indicators:
nodes = _get_indicator_nodes(restore_state, indicators)
yield E.fixture(E.values(*nodes), id=self.id)


def _get_indicators(domain):
return list(CSQLFixtureExpression.by_domain(domain).values_list('name', 'csql'))


case_search_fixture_generator = CaseSearchFixtureProvider()
113 changes: 113 additions & 0 deletions corehq/apps/case_search/tests/test_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import uuid
from unittest.mock import patch

from django.test import TestCase

from lxml import etree

from casexml.apps.case.mock import CaseBlock
from casexml.apps.phone.tests.utils import call_fixture_generator

from corehq.apps.domain.shortcuts import create_domain
from corehq.apps.es.case_search import case_search_adapter
from corehq.apps.es.tests.utils import case_search_es_setup, es_test
from corehq.apps.users.models import WebUser
from corehq.tests.util.xml import assert_xml_equal, assert_xml_partial_equal
from corehq.util.test_utils import flag_enabled

from ..fixtures import _get_template_renderer, case_search_fixture_generator


@flag_enabled('MODULE_BADGES')
@es_test(requires=[case_search_adapter], setup_class=True)
class TestCaseSearchFixtures(TestCase):
domain_name = 'test-case-search-fixtures'

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.domain_obj = create_domain(cls.domain_name)
cls.user = WebUser.create(cls.domain_name, '[email protected]', 'secret', None, None)
cls.restore_user = cls.user.to_ota_restore_user(cls.domain_name)
case_search_es_setup(cls.domain_name, cls._get_case_blocks(cls.user.user_id))
cls.addClassCleanup(cls.domain_obj.delete)

@staticmethod
def _get_case_blocks(owner_id):
def case_block(case_type, name, owner_id):
return CaseBlock(
case_id=str(uuid.uuid4()),
case_type=case_type,
case_name=name,
owner_id=owner_id,
create=True,
)

return [
case_block('client', 'Kleo', owner_id),
case_block('client', 'Sven', owner_id),
case_block('client', 'Thilo', '---'),
case_block('place', 'Berlin', owner_id),
case_block('place', 'Sirius B', '---'),
]

def render(self, template_string):
return _get_template_renderer(self.restore_user).render(template_string)

def generate_fixture(self):
res = call_fixture_generator(case_search_fixture_generator, self.restore_user, self.domain_obj)
return etree.tostring(next(res), encoding='utf-8')

def test_no_interpolation(self):
res = self.render("dob < '2020-01-01'")
self.assertEqual(res, "dob < '2020-01-01'")

def test_user_id(self):
res = self.render("@owner_id = '{user.uuid}'")
self.assertEqual(res, f"@owner_id = '{self.user.user_id}'")

@patch('custom.bha.commcare_extensions.get_user_clinic_ids')
def test_bha_custom_csql_fixture_context(self, get_user_clinic_ids):
self.restore_user.domain = 'co-carecoordination'

def reset_domain():
self.restore_user.domain = self.domain_name
self.addCleanup(reset_domain)

get_user_clinic_ids.return_value = "abc123 def456"
res = self.render("selected(@owner_id, '{bha.user_clinic_ids}')")
self.assertEqual(res, "selected(@owner_id, 'abc123 def456')")

@patch('corehq.apps.case_search.fixtures._get_indicators')
@patch('corehq.apps.case_search.fixtures._run_query')
def test_fixture_generator(self, run_query, get_indicators):
run_query.return_value = "42"
get_indicators.return_value = [
('pre_pandemic_births', "dob < '2020-01-01'"),
('owned_by_user', "@owner_id = '{user.uuid}'"),
]

expected = """
<fixture id="case-search-fixture">
<values>
<value name="pre_pandemic_births">42</value>
<value name="owned_by_user">42</value>
</values>
</fixture>"""
assert_xml_equal(expected, self.generate_fixture())

@patch('corehq.apps.case_search.fixtures._get_indicators')
def test_full_query(self, get_indicators):
indicators = [
# (name, csql_template, expected_count)
('owned_by_user', "@owner_id = '{user.uuid}'", 3),
('total_clients', "@case_type = 'client'", 3),
('own_clients', "@case_type = 'client' and @owner_id = '{user.uuid}'", 2),
('bad_query', "this is not a valid query", "ERROR"),
]
get_indicators.return_value = [(name, csql_template) for name, csql_template, _ in indicators]

res = self.generate_fixture()
for name, _, expected in indicators:
expected_xml = f'<partial><value name="{name}">{expected}</value></partial>'
assert_xml_partial_equal(expected_xml, res, f'./values/value[@name="{name}"]')
9 changes: 6 additions & 3 deletions corehq/ex-submodules/casexml/apps/phone/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def __init__(
is_async: bool = False,
overwrite_cache: bool = False,
auth_type: Optional[str] = None,
timing_context: Optional[TimingContext] = None,
):
if not project or not project.name:
raise Exception('you are not allowed to make a RestoreState without a domain!')
Expand All @@ -380,6 +381,7 @@ def __init__(
self.overwrite_cache = overwrite_cache
self.auth_type = auth_type
self._last_sync_log = Ellipsis
self.timing_context = timing_context or TimingContext()

def validate_state(self):
check_version(self.params.version)
Expand Down Expand Up @@ -532,20 +534,21 @@ def __init__(self, project=None, restore_user=None, params=None,
self.is_async = is_async
self.skip_fixtures = skip_fixtures

self.timing_context = TimingContext('restore-{}-{}'.format(self.domain, self.restore_user.username))

self.restore_state = RestoreState(
self.project,
self.restore_user,
self.params, is_async,
self.cache_settings.overwrite_cache,
auth_type=auth_type
auth_type=auth_type,
timing_context=self.timing_context,
)

self.force_cache = self.cache_settings.force_cache
self.cache_timeout = self.cache_settings.cache_timeout
self.overwrite_cache = self.cache_settings.overwrite_cache

self.timing_context = TimingContext('restore-{}-{}'.format(self.domain, self.restore_user.username))

@property
@memoized
def sync_log(self):
Expand Down
16 changes: 16 additions & 0 deletions corehq/messaging/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ def __getattr__(self, item):
return SimpleMessagingTemplateParam(UNKNOWN_VALUE)


class NestedDictTemplateParam(SimpleDictTemplateParam):

def __init__(self, dict_of_values):
self.__dict_of_values = dict_of_values

def __getattr__(self, item):
"""Works just like SimpleDictTemplateParam but it can contain nested dicts"""
if val := self.__dict_of_values.get(item):
if isinstance(val, dict):
return NestedDictTemplateParam(val)

return SimpleMessagingTemplateParam(val)

return SimpleMessagingTemplateParam(UNKNOWN_VALUE)


class CaseMessagingTemplateParam(SimpleDictTemplateParam):

def __init__(self, case):
Expand Down
Empty file added custom/bha/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions custom/bha/commcare_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from django.conf import settings

from corehq.apps.case_search.fixtures import custom_csql_fixture_context
from corehq.messaging.templating import SimpleDictTemplateParam

from .util import get_user_clinic_ids, get_user_facility_ids

BHA_DOMAINS = settings.CUSTOM_DOMAINS_BY_MODULE['custom.bha']


@custom_csql_fixture_context.extend(domains=BHA_DOMAINS)
def bha_csql_fixture_context(domain, restore_user):
facility_ids = get_user_facility_ids(domain, restore_user)
return ('bha', SimpleDictTemplateParam({
'facility_ids': facility_ids,
'user_clinic_ids': get_user_clinic_ids(domain, restore_user, facility_ids),
}))
Empty file added custom/bha/tests/__init__.py
Empty file.
39 changes: 39 additions & 0 deletions custom/bha/tests/test_csql_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from corehq.apps.locations.tests.util import LocationHierarchyTestCase
from corehq.apps.users.models import WebUser

from ..util import get_user_facility_ids


class UserFacilityTests(LocationHierarchyTestCase):
domain = 'bha-user-facility-tests'
location_type_names = ['state', 'registry', 'organization', 'facility', 'facility_data']
location_structure = [
('state-1', [
('registry-1', [
('organization-1', [
('facility-1a', []),
('facility-1b', [
('facility_data-1b', []),
]),
]),
('organization-2', [
('facility-2a', []),
('facility-2b', []),
]),
]),
])
]

def test_get_user_facility_ids(self):
user = WebUser.create(self.domain, '[email protected]', 'secret', None, None)
user.add_to_assigned_locations(self.domain, self.locations['organization-1'])
user.add_to_assigned_locations(self.domain, self.locations['facility-2a'])
restore_user = user.to_ota_restore_user(self.domain)
res = get_user_facility_ids(self.domain, restore_user)
self.assertItemsEqual(res, [
self.locations[loc_name].location_id for loc_name in [
'facility-1a',
'facility-1b',
'facility-2a',
]
])
24 changes: 24 additions & 0 deletions custom/bha/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from corehq.apps.locations.models import SQLLocation
from corehq.apps.es.case_search import CaseSearchES


def get_user_facility_ids(domain, restore_user):
# Facility locations the user is either directly assigned to, or which are
# children of organizations the user is assigned to
owned_locs = (restore_user.get_sql_locations(domain)
.filter(location_type__code__in=['organization', 'facility']))
return list(SQLLocation.objects
.get_queryset_descendants(owned_locs, include_self=True)
.filter(location_type__code='facility')
.location_ids())


def get_user_clinic_ids(domain, restore_user, facility_ids):
return " ".join(
CaseSearchES()
.domain(domain)
.case_type('clinic')
.is_closed(False)
.owner(facility_ids)
.get_ids()
)
Loading

0 comments on commit c730bf3

Please sign in to comment.