diff --git a/corehq/apps/api/resources/v0_5.py b/corehq/apps/api/resources/v0_5.py index 04596a18f61a..5a3f5a52ccb9 100644 --- a/corehq/apps/api/resources/v0_5.py +++ b/corehq/apps/api/resources/v0_5.py @@ -230,6 +230,8 @@ def detail_uri_kwargs(self, bundle_or_obj): class CommCareUserResource(v0_1.CommCareUserResource): + primary_location = fields.CharField() + locations = fields.ListField() class Meta(v0_1.CommCareUserResource.Meta): detail_allowed_methods = ['get', 'put', 'delete'] @@ -320,10 +322,22 @@ def obj_delete(self, bundle, **kwargs): deleted_via=USER_CHANGE_VIA_API) return ImmediateHttpResponse(response=http.HttpAccepted()) + def dehydrate_primary_location(self, bundle): + return bundle.obj.get_location_id(bundle.obj.domain) + + def dehydrate_locations(self, bundle): + return bundle.obj.get_location_ids(bundle.obj.domain) + @classmethod def _update(cls, bundle, user_change_logger=None): errors = [] - for key, value in bundle.data.items(): + + location_object = {'primary_location': bundle.data.pop('primary_location', None), + 'locations': bundle.data.pop('locations', None)} + + items_to_update = list(bundle.data.items()) + [('location', location_object)] + + for key, value in items_to_update: try: update(bundle.obj, key, value, user_change_logger) except UpdateUserException as e: diff --git a/corehq/apps/api/tests/user_resources.py b/corehq/apps/api/tests/test_user_resources.py similarity index 91% rename from corehq/apps/api/tests/user_resources.py rename to corehq/apps/api/tests/test_user_resources.py index 643343dad9ce..fffe39e8a127 100644 --- a/corehq/apps/api/tests/user_resources.py +++ b/corehq/apps/api/tests/test_user_resources.py @@ -19,6 +19,7 @@ from corehq.apps.es.tests.utils import es_test from corehq.apps.es.users import user_adapter from corehq.apps.groups.models import Group +from corehq.apps.locations.models import LocationType, SQLLocation from corehq.apps.users.analytics import update_analytics_indexes from corehq.apps.users.audit.change_messages import UserChangeMessage from corehq.apps.users.model_log import UserModelAction @@ -70,6 +71,11 @@ def setUpClass(cls): definition=cls.definition, ) cls.profile.save() + cls.loc_type = LocationType.objects.create(domain=cls.domain.name, name='loc_type') + cls.loc1 = SQLLocation.objects.create( + location_id='loc1', location_type=cls.loc_type, domain=cls.domain.name) + cls.loc2 = SQLLocation.objects.create( + location_id='loc2', location_type=cls.loc_type, domain=cls.domain.name) @classmethod def tearDownClass(cls): @@ -81,6 +87,8 @@ def test_get_list(self): commcare_user = CommCareUser.create(domain=self.domain.name, username='fake_user', password='*****', created_by=None, created_via=None) self.addCleanup(commcare_user.delete, self.domain.name, deleted_by=None) + commcare_user.set_location(self.loc1) + commcare_user.set_location(self.loc2) backend_id = commcare_user.get_id update_analytics_indexes() @@ -100,8 +108,13 @@ def test_get_list(self): 'last_name': '', 'phone_numbers': [], 'resource_uri': '/a/qwerty/api/v0.5/user/{}/'.format(backend_id), - 'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': ''}, - 'username': 'fake_user' + 'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': '', + 'commcare_location_id': self.loc2.location_id, + 'commcare_primary_case_sharing_id': self.loc2.location_id, + 'commcare_location_ids': f'{self.loc1.location_id} {self.loc2.location_id}'}, + 'username': 'fake_user', + 'primary_location': self.loc2.location_id, + 'locations': [self.loc1.location_id, self.loc2.location_id] }) @flaky @@ -109,6 +122,8 @@ def test_get_single(self): commcare_user = CommCareUser.create(domain=self.domain.name, username='fake_user', password='*****', created_by=None, created_via=None) self.addCleanup(commcare_user.delete, self.domain.name, deleted_by=None) + commcare_user.set_location(self.loc1) + commcare_user.set_location(self.loc2) backend_id = commcare_user._id response = self._assert_auth_get_resource(self.single_endpoint(backend_id)) @@ -126,8 +141,15 @@ def test_get_single(self): 'last_name': '', 'phone_numbers': [], 'resource_uri': '/a/qwerty/api/v0.5/user/{}/'.format(backend_id), - 'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': ''}, + 'user_data': {'commcare_project': 'qwerty', + PROFILE_SLUG: '', + 'imaginary': '', + 'commcare_location_id': self.loc2.location_id, + 'commcare_primary_case_sharing_id': self.loc2.location_id, + 'commcare_location_ids': f'{self.loc1.location_id} {self.loc2.location_id}'}, 'username': 'fake_user', + 'primary_location': self.loc2.location_id, + 'locations': [self.loc1.location_id, self.loc2.location_id] }) def test_create(self): @@ -153,7 +175,10 @@ def test_create(self): ], "user_data": { "chw_id": "13/43/DFA" - } + }, + 'locations': [self.loc1.location_id, self.loc2.location_id], + 'primary_location': self.loc1.location_id + } response = self._assert_auth_post_resource(self.list_endpoint, json.dumps(user_json), @@ -171,6 +196,9 @@ def test_create(self): self.assertEqual(user_back.get_group_ids()[0], group._id) self.assertEqual(user_back.get_user_data(self.domain.name)["chw_id"], "13/43/DFA") self.assertEqual(user_back.default_phone_number, "50253311399") + self.assertEqual(user_back.get_location_ids(self.domain.name), + [self.loc1.location_id, self.loc2.location_id]) + self.assertEqual(user_back.get_location_id(self.domain.name), self.loc1.location_id) @flag_enabled('COMMCARE_CONNECT') def test_create_connect_user_no_password(self): @@ -266,7 +294,9 @@ def test_update(self): PROFILE_SLUG: self.profile.id, "chw_id": "13/43/DFA" }, - "password": "qwerty1234" + "password": "qwerty1234", + 'locations': [self.loc1.location_id, self.loc2.location_id], + 'primary_location': self.loc1.location_id } backend_id = user._id @@ -288,6 +318,9 @@ def test_update(self): self.assertEqual(user_data.profile_id, self.profile.id) self.assertEqual(user_data["imaginary"], "yes") self.assertEqual(modified.default_phone_number, "50253311399") + self.assertEqual(modified.get_location_ids(self.domain.name), + [self.loc1.location_id, self.loc2.location_id]) + self.assertEqual(modified.get_location_id(self.domain.name), self.loc1.location_id) # test user history audit user_history = UserHistory.objects.get(action=UserModelAction.UPDATE.value, @@ -300,6 +333,8 @@ def test_update(self): 'last_name': 'last', 'first_name': 'test', 'user_data': {'chw_id': '13/43/DFA'}, + 'location_id': 'loc1', + 'assigned_location_ids': ['loc1', 'loc2'] } ) self.assertTrue("50253311398" in diff --git a/corehq/apps/api/tests/test_user_updates.py b/corehq/apps/api/tests/test_user_updates.py index 955618af93e0..534987175479 100644 --- a/corehq/apps/api/tests/test_user_updates.py +++ b/corehq/apps/api/tests/test_user_updates.py @@ -11,9 +11,11 @@ ) from corehq.apps.domain.shortcuts import create_domain from corehq.apps.groups.models import Group +from corehq.apps.locations.models import LocationType, SQLLocation from corehq.apps.user_importer.helpers import UserChangeLogger from corehq.apps.users.audit.change_messages import ( GROUPS_FIELD, + LOCATION_FIELD, PASSWORD_FIELD, ROLE_FIELD, ) @@ -31,6 +33,11 @@ def setUpClass(cls): cls.domain = 'test-domain' cls.domain_obj = create_domain(cls.domain) cls.addClassCleanup(cls.domain_obj.delete) + cls.loc_type = LocationType.objects.create(domain=cls.domain, name='loc_type') + cls.loc1 = SQLLocation.objects.create( + location_id='loc1', location_type=cls.loc_type, domain=cls.domain) + cls.loc2 = SQLLocation.objects.create( + location_id='loc2', location_type=cls.loc_type, domain=cls.domain) def setUp(self) -> None: super().setUp() @@ -190,6 +197,48 @@ def _setup_profile(self): profile.save() return profile.id + def test_update_locations_raises_if_primary_location_not_in_location_list(self): + with self.assertRaises(UpdateUserException) as e: + update(self.user, 'location', + {'primary_location': self.loc2.location_id, 'locations': [self.loc1.location_id]}) + + self.assertEqual(str(e.exception), 'Primary location must be included in the list of locations.') + + def test_update_locations_raises_if_any_location_does_not_exist(self): + with self.assertRaises(UpdateUserException) as e: + update(self.user, 'location', + {'primary_location': 'fake_loc', 'locations': [self.loc1.location_id, 'fake_loc']}) + self.assertEqual(str(e.exception), "Could not find location ids: fake_loc.") + + def test_update_locations_raises_if_primary_location_not_provided(self): + with self.assertRaises(UpdateUserException) as e: + update(self.user, 'location', {'locations': [self.loc1.location_id]}) + self.assertEqual(str(e.exception), 'Both primary_location and locations must be provided together.') + + def test_update_locations_raises_if_locations_not_provided(self): + with self.assertRaises(UpdateUserException) as e: + update(self.user, 'location', {'primary_location': self.loc1.location_id}) + self.assertEqual(str(e.exception), 'Both primary_location and locations must be provided together.') + + def test_update_locations_do_nothing_if_nothing_provided(self): + self.user.set_location(self.loc1) + update(self.user, 'location', {'primary_location': None, 'locations': None}) + self.assertEqual(self.user.get_location_ids(self.domain), [self.loc1.location_id]) + self.assertEqual(self.user.get_location_id(self.domain), self.loc1.location_id) + + def test_update_locations_removes_locations_if_empty_string_provided(self): + self.user.set_location(self.loc1) + update(self.user, 'location', {'primary_location': '', 'locations': []}) + self.assertEqual(self.user.get_location_ids(self.domain), []) + self.assertEqual(self.user.get_location_id(self.domain), None) + + def test_update_locations_succeeds(self): + update(self.user, 'location', + {'primary_location': self.loc1.location_id, + 'locations': [self.loc1.location_id, self.loc2.location_id]}) + self.assertEqual(self.user.get_location_ids(self.domain), [self.loc1.location_id, self.loc2.location_id]) + self.assertEqual(self.user.get_location_id(self.domain), self.loc1.location_id) + class TestUpdateUserMethodsLogChanges(TestCase): @@ -199,6 +248,11 @@ def setUpClass(cls): cls.domain = 'test-domain' cls.domain_obj = create_domain(cls.domain) cls.addClassCleanup(cls.domain_obj.delete) + cls.loc_type = LocationType.objects.create(domain=cls.domain, name='loc_type') + cls.loc1 = SQLLocation.objects.create( + location_id='loc1', location_type=cls.loc_type, domain=cls.domain) + cls.loc2 = SQLLocation.objects.create( + location_id='loc2', location_type=cls.loc_type, domain=cls.domain) def setUp(self) -> None: super().setUp() @@ -332,3 +386,25 @@ def test_update_user_role_does_not_logs_change(self): update(self.user, 'role', 'edit-data', user_change_logger=self.user_change_logger) self.assertNotIn(ROLE_FIELD, self.user_change_logger.change_messages.keys()) + + def test_update_location_logs_change(self): + update(self.user, 'location', + {'primary_location': self.loc1.location_id, + 'locations': [self.loc1.location_id, self.loc2.location_id]}, + user_change_logger=self.user_change_logger) + self.assertIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys()) + + def test_update_location_without_include_location_fields_does_not_log_change(self): + update(self.user, 'location', + {'primary_location': None, 'locations': None}, + user_change_logger=self.user_change_logger) + self.assertNotIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys()) + + def test_update_location_with_current_locations_does_not_log_change(self): + self.user.set_location(self.loc2) + self.user.set_location(self.loc1) + update(self.user, 'location', + {'primary_location': self.loc1.location_id, + 'locations': [self.loc1.location_id, self.loc2.location_id]}, + user_change_logger=self.user_change_logger) + self.assertNotIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys()) diff --git a/corehq/apps/api/user_updates.py b/corehq/apps/api/user_updates.py index 433fc6106f93..a823c80a4e52 100644 --- a/corehq/apps/api/user_updates.py +++ b/corehq/apps/api/user_updates.py @@ -1,5 +1,6 @@ from django.core.exceptions import ValidationError from django.utils.translation import gettext as _ +from corehq.apps.locations.models import SQLLocation from dimagi.utils.couch.bulk import get_docs @@ -35,6 +36,7 @@ def update(user, field, value, user_change_logger=None): 'phone_numbers': _update_phone_numbers, 'user_data': _update_user_data, 'role': _update_user_role, + 'location': _update_location, }.get(field) if not update_fn: @@ -165,3 +167,69 @@ def _log_phone_number_change(new_phone_numbers, old_phone_numbers, user_change_l if change_messages: user_change_logger.add_change_message({'phone_numbers': change_messages}) + + +def _update_location(user, location_object, user_change_logger): + primary_location_id = location_object.get('primary_location') + location_ids = location_object.get('locations') + + if primary_location_id is None and location_ids is None: + return + + current_primary_location_id = user.get_location_id(user.domain) + current_locations = user.get_location_ids(user.domain) + + if not primary_location_id and not location_ids: + _remove_all_locations(user, user_change_logger) + else: + if _validate_locations(primary_location_id, location_ids): + locations = _verify_location_ids(location_ids, user.domain) + if primary_location_id != current_primary_location_id: + _update_primary_location(user, primary_location_id, user_change_logger) + if set(current_locations) != set(location_ids): + _update_assigned_locations(user, locations, location_ids, user_change_logger) + + +def _validate_locations(primary_location_id, location_ids): + if not primary_location_id and not location_ids: + return False + if not primary_location_id or not location_ids: + raise UpdateUserException(_('Both primary_location and locations must be provided together.')) + if primary_location_id not in location_ids: + raise UpdateUserException(_('Primary location must be included in the list of locations.')) + return True + + +def _remove_all_locations(user, user_change_logger): + user.unset_location(commit=False) + user.reset_locations([], commit=False) + if user_change_logger: + user_change_logger.add_changes({'location_id': None}) + user_change_logger.add_info(UserChangeMessage.primary_location_removed()) + user_change_logger.add_changes({'assigned_location_ids': []}) + user_change_logger.add_info(UserChangeMessage.assigned_locations_info([])) + + +def _update_primary_location(user, primary_location_id, user_change_logger): + primary_location = SQLLocation.active_objects.get(location_id=primary_location_id) + user.set_location(primary_location, commit=False) + if user_change_logger: + user_change_logger.add_changes({'location_id': primary_location_id}) + user_change_logger.add_info(UserChangeMessage.primary_location_info(primary_location)) + + +def _verify_location_ids(location_ids, domain): + locations = SQLLocation.active_objects.filter(location_id__in=location_ids, domain=domain) + real_ids = [loc.location_id for loc in locations] + + if missing_ids := set(location_ids) - set(real_ids): + raise UpdateUserException(f"Could not find location ids: {', '.join(missing_ids)}.") + + return locations + + +def _update_assigned_locations(user, locations, location_ids, user_change_logger): + user.reset_locations(location_ids, commit=False) + if user_change_logger: + user_change_logger.add_changes({'assigned_location_ids': location_ids}) + user_change_logger.add_info(UserChangeMessage.assigned_locations_info(locations)) diff --git a/corehq/motech/repeaters/tests/test_repeater.py b/corehq/motech/repeaters/tests/test_repeater.py index c5457ee7321d..8841d41f13f0 100644 --- a/corehq/motech/repeaters/tests/test_repeater.py +++ b/corehq/motech/repeaters/tests/test_repeater.py @@ -947,6 +947,8 @@ def test_trigger(self): 'email': '', 'eulas': '[]', 'resource_uri': '/a/user-repeater/api/v0.5/user/{}/'.format(user._id), + 'locations': [], + 'primary_location': None, } )