Skip to content

Commit

Permalink
Merge pull request #35100 from dimagi/jc/update-location-through-user…
Browse files Browse the repository at this point in the history
…-api

Update location through user api
  • Loading branch information
jingcheng16 authored Sep 17, 2024
2 parents 9749254 + 2359fd5 commit 5353a42
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 6 deletions.
16 changes: 15 additions & 1 deletion corehq/apps/api/resources/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def detail_uri_kwargs(self, bundle_or_obj):


class CommCareUserResource(v0_1.CommCareUserResource):
primary_location = fields.CharField()
locations = fields.ListField()

class Meta(v0_1.CommCareUserResource.Meta):
detail_allowed_methods = ['get', 'put', 'delete']
Expand Down Expand Up @@ -320,10 +322,22 @@ def obj_delete(self, bundle, **kwargs):
deleted_via=USER_CHANGE_VIA_API)
return ImmediateHttpResponse(response=http.HttpAccepted())

def dehydrate_primary_location(self, bundle):
return bundle.obj.get_location_id(bundle.obj.domain)

def dehydrate_locations(self, bundle):
return bundle.obj.get_location_ids(bundle.obj.domain)

@classmethod
def _update(cls, bundle, user_change_logger=None):
errors = []
for key, value in bundle.data.items():

location_object = {'primary_location': bundle.data.pop('primary_location', None),
'locations': bundle.data.pop('locations', None)}

items_to_update = list(bundle.data.items()) + [('location', location_object)]

for key, value in items_to_update:
try:
update(bundle.obj, key, value, user_change_logger)
except UpdateUserException as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from corehq.apps.es.tests.utils import es_test
from corehq.apps.es.users import user_adapter
from corehq.apps.groups.models import Group
from corehq.apps.locations.models import LocationType, SQLLocation
from corehq.apps.users.analytics import update_analytics_indexes
from corehq.apps.users.audit.change_messages import UserChangeMessage
from corehq.apps.users.model_log import UserModelAction
Expand Down Expand Up @@ -70,6 +71,11 @@ def setUpClass(cls):
definition=cls.definition,
)
cls.profile.save()
cls.loc_type = LocationType.objects.create(domain=cls.domain.name, name='loc_type')
cls.loc1 = SQLLocation.objects.create(
location_id='loc1', location_type=cls.loc_type, domain=cls.domain.name)
cls.loc2 = SQLLocation.objects.create(
location_id='loc2', location_type=cls.loc_type, domain=cls.domain.name)

@classmethod
def tearDownClass(cls):
Expand All @@ -81,6 +87,8 @@ def test_get_list(self):
commcare_user = CommCareUser.create(domain=self.domain.name, username='fake_user', password='*****',
created_by=None, created_via=None)
self.addCleanup(commcare_user.delete, self.domain.name, deleted_by=None)
commcare_user.set_location(self.loc1)
commcare_user.set_location(self.loc2)
backend_id = commcare_user.get_id
update_analytics_indexes()

Expand All @@ -100,15 +108,22 @@ def test_get_list(self):
'last_name': '',
'phone_numbers': [],
'resource_uri': '/a/qwerty/api/v0.5/user/{}/'.format(backend_id),
'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': ''},
'username': 'fake_user'
'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': '',
'commcare_location_id': self.loc2.location_id,
'commcare_primary_case_sharing_id': self.loc2.location_id,
'commcare_location_ids': f'{self.loc1.location_id} {self.loc2.location_id}'},
'username': 'fake_user',
'primary_location': self.loc2.location_id,
'locations': [self.loc1.location_id, self.loc2.location_id]
})

@flaky
def test_get_single(self):
commcare_user = CommCareUser.create(domain=self.domain.name, username='fake_user', password='*****',
created_by=None, created_via=None)
self.addCleanup(commcare_user.delete, self.domain.name, deleted_by=None)
commcare_user.set_location(self.loc1)
commcare_user.set_location(self.loc2)
backend_id = commcare_user._id

response = self._assert_auth_get_resource(self.single_endpoint(backend_id))
Expand All @@ -126,8 +141,15 @@ def test_get_single(self):
'last_name': '',
'phone_numbers': [],
'resource_uri': '/a/qwerty/api/v0.5/user/{}/'.format(backend_id),
'user_data': {'commcare_project': 'qwerty', PROFILE_SLUG: '', 'imaginary': ''},
'user_data': {'commcare_project': 'qwerty',
PROFILE_SLUG: '',
'imaginary': '',
'commcare_location_id': self.loc2.location_id,
'commcare_primary_case_sharing_id': self.loc2.location_id,
'commcare_location_ids': f'{self.loc1.location_id} {self.loc2.location_id}'},
'username': 'fake_user',
'primary_location': self.loc2.location_id,
'locations': [self.loc1.location_id, self.loc2.location_id]
})

def test_create(self):
Expand All @@ -153,7 +175,10 @@ def test_create(self):
],
"user_data": {
"chw_id": "13/43/DFA"
}
},
'locations': [self.loc1.location_id, self.loc2.location_id],
'primary_location': self.loc1.location_id

}
response = self._assert_auth_post_resource(self.list_endpoint,
json.dumps(user_json),
Expand All @@ -171,6 +196,9 @@ def test_create(self):
self.assertEqual(user_back.get_group_ids()[0], group._id)
self.assertEqual(user_back.get_user_data(self.domain.name)["chw_id"], "13/43/DFA")
self.assertEqual(user_back.default_phone_number, "50253311399")
self.assertEqual(user_back.get_location_ids(self.domain.name),
[self.loc1.location_id, self.loc2.location_id])
self.assertEqual(user_back.get_location_id(self.domain.name), self.loc1.location_id)

@flag_enabled('COMMCARE_CONNECT')
def test_create_connect_user_no_password(self):
Expand Down Expand Up @@ -266,7 +294,9 @@ def test_update(self):
PROFILE_SLUG: self.profile.id,
"chw_id": "13/43/DFA"
},
"password": "qwerty1234"
"password": "qwerty1234",
'locations': [self.loc1.location_id, self.loc2.location_id],
'primary_location': self.loc1.location_id
}

backend_id = user._id
Expand All @@ -288,6 +318,9 @@ def test_update(self):
self.assertEqual(user_data.profile_id, self.profile.id)
self.assertEqual(user_data["imaginary"], "yes")
self.assertEqual(modified.default_phone_number, "50253311399")
self.assertEqual(modified.get_location_ids(self.domain.name),
[self.loc1.location_id, self.loc2.location_id])
self.assertEqual(modified.get_location_id(self.domain.name), self.loc1.location_id)

# test user history audit
user_history = UserHistory.objects.get(action=UserModelAction.UPDATE.value,
Expand All @@ -300,6 +333,8 @@ def test_update(self):
'last_name': 'last',
'first_name': 'test',
'user_data': {'chw_id': '13/43/DFA'},
'location_id': 'loc1',
'assigned_location_ids': ['loc1', 'loc2']
}
)
self.assertTrue("50253311398" in
Expand Down
76 changes: 76 additions & 0 deletions corehq/apps/api/tests/test_user_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
)
from corehq.apps.domain.shortcuts import create_domain
from corehq.apps.groups.models import Group
from corehq.apps.locations.models import LocationType, SQLLocation
from corehq.apps.user_importer.helpers import UserChangeLogger
from corehq.apps.users.audit.change_messages import (
GROUPS_FIELD,
LOCATION_FIELD,
PASSWORD_FIELD,
ROLE_FIELD,
)
Expand All @@ -31,6 +33,11 @@ def setUpClass(cls):
cls.domain = 'test-domain'
cls.domain_obj = create_domain(cls.domain)
cls.addClassCleanup(cls.domain_obj.delete)
cls.loc_type = LocationType.objects.create(domain=cls.domain, name='loc_type')
cls.loc1 = SQLLocation.objects.create(
location_id='loc1', location_type=cls.loc_type, domain=cls.domain)
cls.loc2 = SQLLocation.objects.create(
location_id='loc2', location_type=cls.loc_type, domain=cls.domain)

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -190,6 +197,48 @@ def _setup_profile(self):
profile.save()
return profile.id

def test_update_locations_raises_if_primary_location_not_in_location_list(self):
with self.assertRaises(UpdateUserException) as e:
update(self.user, 'location',
{'primary_location': self.loc2.location_id, 'locations': [self.loc1.location_id]})

self.assertEqual(str(e.exception), 'Primary location must be included in the list of locations.')

def test_update_locations_raises_if_any_location_does_not_exist(self):
with self.assertRaises(UpdateUserException) as e:
update(self.user, 'location',
{'primary_location': 'fake_loc', 'locations': [self.loc1.location_id, 'fake_loc']})
self.assertEqual(str(e.exception), "Could not find location ids: fake_loc.")

def test_update_locations_raises_if_primary_location_not_provided(self):
with self.assertRaises(UpdateUserException) as e:
update(self.user, 'location', {'locations': [self.loc1.location_id]})
self.assertEqual(str(e.exception), 'Both primary_location and locations must be provided together.')

def test_update_locations_raises_if_locations_not_provided(self):
with self.assertRaises(UpdateUserException) as e:
update(self.user, 'location', {'primary_location': self.loc1.location_id})
self.assertEqual(str(e.exception), 'Both primary_location and locations must be provided together.')

def test_update_locations_do_nothing_if_nothing_provided(self):
self.user.set_location(self.loc1)
update(self.user, 'location', {'primary_location': None, 'locations': None})
self.assertEqual(self.user.get_location_ids(self.domain), [self.loc1.location_id])
self.assertEqual(self.user.get_location_id(self.domain), self.loc1.location_id)

def test_update_locations_removes_locations_if_empty_string_provided(self):
self.user.set_location(self.loc1)
update(self.user, 'location', {'primary_location': '', 'locations': []})
self.assertEqual(self.user.get_location_ids(self.domain), [])
self.assertEqual(self.user.get_location_id(self.domain), None)

def test_update_locations_succeeds(self):
update(self.user, 'location',
{'primary_location': self.loc1.location_id,
'locations': [self.loc1.location_id, self.loc2.location_id]})
self.assertEqual(self.user.get_location_ids(self.domain), [self.loc1.location_id, self.loc2.location_id])
self.assertEqual(self.user.get_location_id(self.domain), self.loc1.location_id)


class TestUpdateUserMethodsLogChanges(TestCase):

Expand All @@ -199,6 +248,11 @@ def setUpClass(cls):
cls.domain = 'test-domain'
cls.domain_obj = create_domain(cls.domain)
cls.addClassCleanup(cls.domain_obj.delete)
cls.loc_type = LocationType.objects.create(domain=cls.domain, name='loc_type')
cls.loc1 = SQLLocation.objects.create(
location_id='loc1', location_type=cls.loc_type, domain=cls.domain)
cls.loc2 = SQLLocation.objects.create(
location_id='loc2', location_type=cls.loc_type, domain=cls.domain)

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -332,3 +386,25 @@ def test_update_user_role_does_not_logs_change(self):
update(self.user, 'role', 'edit-data', user_change_logger=self.user_change_logger)

self.assertNotIn(ROLE_FIELD, self.user_change_logger.change_messages.keys())

def test_update_location_logs_change(self):
update(self.user, 'location',
{'primary_location': self.loc1.location_id,
'locations': [self.loc1.location_id, self.loc2.location_id]},
user_change_logger=self.user_change_logger)
self.assertIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys())

def test_update_location_without_include_location_fields_does_not_log_change(self):
update(self.user, 'location',
{'primary_location': None, 'locations': None},
user_change_logger=self.user_change_logger)
self.assertNotIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys())

def test_update_location_with_current_locations_does_not_log_change(self):
self.user.set_location(self.loc2)
self.user.set_location(self.loc1)
update(self.user, 'location',
{'primary_location': self.loc1.location_id,
'locations': [self.loc1.location_id, self.loc2.location_id]},
user_change_logger=self.user_change_logger)
self.assertNotIn(LOCATION_FIELD, self.user_change_logger.change_messages.keys())
68 changes: 68 additions & 0 deletions corehq/apps/api/user_updates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
from corehq.apps.locations.models import SQLLocation

from dimagi.utils.couch.bulk import get_docs

Expand Down Expand Up @@ -35,6 +36,7 @@ def update(user, field, value, user_change_logger=None):
'phone_numbers': _update_phone_numbers,
'user_data': _update_user_data,
'role': _update_user_role,
'location': _update_location,
}.get(field)

if not update_fn:
Expand Down Expand Up @@ -165,3 +167,69 @@ def _log_phone_number_change(new_phone_numbers, old_phone_numbers, user_change_l

if change_messages:
user_change_logger.add_change_message({'phone_numbers': change_messages})


def _update_location(user, location_object, user_change_logger):
primary_location_id = location_object.get('primary_location')
location_ids = location_object.get('locations')

if primary_location_id is None and location_ids is None:
return

current_primary_location_id = user.get_location_id(user.domain)
current_locations = user.get_location_ids(user.domain)

if not primary_location_id and not location_ids:
_remove_all_locations(user, user_change_logger)
else:
if _validate_locations(primary_location_id, location_ids):
locations = _verify_location_ids(location_ids, user.domain)
if primary_location_id != current_primary_location_id:
_update_primary_location(user, primary_location_id, user_change_logger)
if set(current_locations) != set(location_ids):
_update_assigned_locations(user, locations, location_ids, user_change_logger)


def _validate_locations(primary_location_id, location_ids):
if not primary_location_id and not location_ids:
return False
if not primary_location_id or not location_ids:
raise UpdateUserException(_('Both primary_location and locations must be provided together.'))
if primary_location_id not in location_ids:
raise UpdateUserException(_('Primary location must be included in the list of locations.'))
return True


def _remove_all_locations(user, user_change_logger):
user.unset_location(commit=False)
user.reset_locations([], commit=False)
if user_change_logger:
user_change_logger.add_changes({'location_id': None})
user_change_logger.add_info(UserChangeMessage.primary_location_removed())
user_change_logger.add_changes({'assigned_location_ids': []})
user_change_logger.add_info(UserChangeMessage.assigned_locations_info([]))


def _update_primary_location(user, primary_location_id, user_change_logger):
primary_location = SQLLocation.active_objects.get(location_id=primary_location_id)
user.set_location(primary_location, commit=False)
if user_change_logger:
user_change_logger.add_changes({'location_id': primary_location_id})
user_change_logger.add_info(UserChangeMessage.primary_location_info(primary_location))


def _verify_location_ids(location_ids, domain):
locations = SQLLocation.active_objects.filter(location_id__in=location_ids, domain=domain)
real_ids = [loc.location_id for loc in locations]

if missing_ids := set(location_ids) - set(real_ids):
raise UpdateUserException(f"Could not find location ids: {', '.join(missing_ids)}.")

return locations


def _update_assigned_locations(user, locations, location_ids, user_change_logger):
user.reset_locations(location_ids, commit=False)
if user_change_logger:
user_change_logger.add_changes({'assigned_location_ids': location_ids})
user_change_logger.add_info(UserChangeMessage.assigned_locations_info(locations))
2 changes: 2 additions & 0 deletions corehq/motech/repeaters/tests/test_repeater.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,8 @@ def test_trigger(self):
'email': '',
'eulas': '[]',
'resource_uri': '/a/user-repeater/api/v0.5/user/{}/'.format(user._id),
'locations': [],
'primary_location': None,
}
)

Expand Down

0 comments on commit 5353a42

Please sign in to comment.