diff --git a/consulate/__init__.py b/consulate/__init__.py index 2f1f12f..2a56375 100644 --- a/consulate/__init__.py +++ b/consulate/__init__.py @@ -5,11 +5,13 @@ from consulate.client import Consul from consulate.exceptions import (ConsulateException, + ClientError, ServerError, ACLDisabled, Forbidden, NotFound, - LockFailure) + LockFailure, + RequestError) import logging from logging import NullHandler @@ -24,9 +26,11 @@ __version__, Consul, ConsulateException, + ClientError, ServerError, ACLDisabled, Forbidden, NotFound, - LockFailure + LockFailure, + RequestError ] diff --git a/consulate/adapters.py b/consulate/adapters.py index c7141e0..44e8bf5 100644 --- a/consulate/adapters.py +++ b/consulate/adapters.py @@ -70,8 +70,7 @@ def delete(self, uri): """ LOGGER.debug("DELETE %s", uri) - return self._process_response( - self.session.delete(uri, timeout=self.timeout)) + return api.Response(self.session.delete(uri, timeout=self.timeout)) def get(self, uri, timeout=None): """Perform a HTTP get @@ -84,9 +83,9 @@ def get(self, uri, timeout=None): """ LOGGER.debug("GET %s", uri) try: - return self._process_response( - self.session.get(uri, timeout=timeout or self.timeout)) - except (requests.exceptions.ConnectionError, + return api.Response(self.session.get( + uri, timeout=timeout or self.timeout)) + except (requests.exceptions.RequestException, OSError, socket.error) as err: raise exceptions.RequestError(str(err)) @@ -100,15 +99,12 @@ def get_stream(self, uri): LOGGER.debug("GET Stream from %s", uri) try: response = self.session.get(uri, stream=True) - except (requests.exceptions.ConnectionError, + except (requests.exceptions.RequestException, OSError, socket.error) as err: raise exceptions.RequestError(str(err)) - if response.encoding is None: - response.encoding = 'utf-8' if utils.response_ok(response): - for line in response.iter_lines(): - if line: - yield line.decode('utf-8') + for line in response.iter_lines(): # pragma: no cover + yield line.decode('utf-8') @prepare_data def put(self, uri, data=None, timeout=None): @@ -127,31 +123,16 @@ def put(self, uri, data=None, timeout=None): if utils.is_string(data) else CONTENT_JSON } try: - return self._process_response( + return api.Response( self.session.put( uri, data=data, headers=headers, timeout=timeout or self.timeout)) - except (requests.exceptions.ConnectionError, + except (requests.exceptions.RequestException, OSError, socket.error) as err: raise exceptions.RequestError(str(err)) - @staticmethod - def _process_response(response): - """Build an api.Response object based upon the requests response - object. - - :param requests.response response: The requests response - :rtype: consulate.api.Response - - """ - try: - return api.Response( - response.status_code, response.content, response.headers) - except (requests.exceptions.HTTPError, OSError, socket.error) as err: - raise exceptions.RequestError(str(err)) - -class UnixSocketRequest(Request): +class UnixSocketRequest(Request): # pragma: no cover """Use to communicate with Consul over a Unix socket""" def __init__(self, timeout=None): diff --git a/consulate/api/base.py b/consulate/api/base.py index 6a1ffb9..46ac6ff 100644 --- a/consulate/api/base.py +++ b/consulate/api/base.py @@ -6,7 +6,7 @@ import json try: from urllib.parse import urlencode # Python 3 -except ImportError: # pragma: no cover +except ImportError: from urllib import urlencode # Python 2 from consulate import utils @@ -106,7 +106,7 @@ def _put_no_response_body(self, url_parts, query=None, payload=None): def _put_response_body(self, url_parts, query=None, payload=None): response = self._adapter.put( - self._build_uri(url_parts, query), payload) + self._build_uri(url_parts, query), data=payload) if utils.response_ok(response): return response.body @@ -123,17 +123,15 @@ class Response(object): body = None headers = None - def __init__(self, status_code, body, headers): + def __init__(self, response): """Create a new instance of the Response class. - :param int status_code: HTTP Status code - :param str body: The response body - :param dict headers: Response headers + :param requests.response response: The requests response """ - self.status_code = status_code - self.body = self._demarshal(body) - self.headers = headers + self.status_code = response.status_code + self.body = self._demarshal(response.content) + self.headers = response.headers def _demarshal(self, body): """Demarshal the request payload. diff --git a/tests/acl_tests.py b/tests/acl_tests.py index e83353a..ffad0b0 100644 --- a/tests/acl_tests.py +++ b/tests/acl_tests.py @@ -5,9 +5,11 @@ import json import uuid -import consulate import httmock +import consulate +from consulate import exceptions + from . import base ACL_RULES = """key "" { @@ -25,6 +27,16 @@ class TestCase(base.TestCase): def uuidv4(): return str(uuid.uuid4()) + def test_bootstrap_request_exception(self): + + @httmock.all_requests + def response_content(_url_unused, _request): + raise OSError + + with httmock.HTTMock(response_content): + with self.assertRaises(exceptions.RequestError): + self.consul.acl.bootstrap() + def test_bootstrap_success(self): expectation = self.uuidv4() @@ -96,6 +108,11 @@ def test_info_acl_id_not_found(self): with self.assertRaises(consulate.NotFound): self.forbidden_consul.acl.info(self.uuidv4()) + def test_list_request_exception(self): + with httmock.HTTMock(base.raise_oserror): + with self.assertRaises(exceptions.RequestError): + self.consul.acl.list() + def test_replication(self): result = self.forbidden_consul.acl.replication() self.assertFalse(result['Enabled']) diff --git a/tests/agent_tests.py b/tests/agent_tests.py index 54a99ad..c46bc44 100644 --- a/tests/agent_tests.py +++ b/tests/agent_tests.py @@ -4,6 +4,8 @@ """ import uuid +import httmock + import consulate from consulate import utils from consulate.models import agent @@ -56,9 +58,17 @@ def test_metrics_forbidden(self): self.forbidden_consul.agent.metrics() def test_monitor(self): - for line in self.consul.agent.monitor(): + for offset, line in enumerate(self.consul.agent.monitor()): self.assertTrue(utils.is_string(line)) - break + self.consul.agent.metrics() + if offset > 1: + break + + def test_monitor_request_exception(self): + with httmock.HTTMock(base.raise_oserror): + with self.assertRaises(consulate.RequestError): + for _line in self.consul.agent.monitor(): + break def test_monitor_forbidden(self): with self.assertRaises(consulate.Forbidden): diff --git a/tests/base.py b/tests/base.py index 4868362..714abff 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,6 +4,8 @@ import unittest import uuid +import httmock + import consulate from consulate import exceptions @@ -21,6 +23,11 @@ def _decorator(self, *args, **kwargs): return _decorator +@httmock.all_requests +def raise_oserror(_url_unused, _request): + raise OSError + + class TestCase(unittest.TestCase): def setUp(self): self.consul = consulate.Consul(