From 2f16c98253f24553c6c9e5bed78bb5d24922390c Mon Sep 17 00:00:00 2001 From: nggit Date: Wed, 15 Nov 2023 20:58:43 +0700 Subject: [PATCH] release 0.0.301 (#32) * customizable exception page with @app.error(500) * implement app_handler_timeout * more efficient ASGILifespan impl. --------- Co-authored-by: nggit <12218311+nggit@users.noreply.github.com> --- setup.py | 2 +- tests/http_server.py | 4 +- tests/test_http_client.py | 12 ++++ tremolo/__init__.py | 2 +- tremolo/__main__.py | 10 +++- tremolo/asgi_lifespan.py | 49 ++++++++-------- tremolo/asgi_server.py | 39 +++++-------- tremolo/handlers.py | 11 ++++ tremolo/lib/http_protocol.py | 108 +++++++++++++++++++++++------------ tremolo/tremolo.py | 9 ++- 10 files changed, 155 insertions(+), 91 deletions(-) diff --git a/setup.py b/setup.py index 761d061..e29756b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name='tremolo', - version='0.0.300', + version='0.0.301', license='MIT', author='nggit', author_email='contact@anggit.com', diff --git a/tests/http_server.py b/tests/http_server.py index c4cf4c8..f62dc75 100644 --- a/tests/http_server.py +++ b/tests/http_server.py @@ -270,6 +270,8 @@ async def timeouts(request=None, **_): # attempt to read body on a GET request # should raise a TimeoutError and ended up with a RequestTimeout await request.recv(100) + elif request.query_string == b'handler': + await asyncio.sleep(10) @app.route('/reload') @@ -284,7 +286,7 @@ async def reload(request=None, **_): # test multiple ports app.listen(HTTP_PORT + 1, request_timeout=2, keepalive_timeout=2) -app.listen(HTTP_PORT + 2) +app.listen(HTTP_PORT + 2, app_handler_timeout=1) # test unix socket # 'tremolo-test.sock' diff --git a/tests/test_http_client.py b/tests/test_http_client.py index ccf9e26..58fbff7 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -501,6 +501,18 @@ def test_recvtimeout(self): b'HTTP/1.1 408 Request Timeout') self.assertEqual(body, b'Request Timeout') + def test_handlertimeout(self): + header, body = getcontents( + host=HTTP_HOST, + port=HTTP_PORT + 2, + raw=b'GET /timeouts?handler HTTP/1.1\r\n' + b'Host: localhost:%d\r\n\r\n' % (HTTP_PORT + 2) + ) + + self.assertEqual(header[:header.find(b'\r\n')], + b'HTTP/1.1 500 Internal Server Error') + self.assertEqual(body, b'Internal Server Error') + def test_download_10(self): header, body = getcontents(host=HTTP_HOST, port=HTTP_PORT + 2, diff --git a/tremolo/__init__.py b/tremolo/__init__.py index 75f964f..2b36d3f 100644 --- a/tremolo/__init__.py +++ b/tremolo/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.300' +__version__ = '0.0.301' from .tremolo import Tremolo # noqa: E402 from . import exceptions # noqa: E402,F401 diff --git a/tremolo/__main__.py b/tremolo/__main__.py index f436ddc..e4899a8 100644 --- a/tremolo/__main__.py +++ b/tremolo/__main__.py @@ -50,6 +50,12 @@ print(' --keepalive-timeout Defaults to 30 (seconds)') print(' --keepalive-connections Maximum number of keep-alive connections') # noqa: E501 print(' Defaults to 512 (connections/worker)') # noqa: E501 + print(' --app-handler-timeout Kill the app if it takes too long to finish') # noqa: E501 + print(' Upgraded connection/scope will not be affected') # noqa: E501 + print(' Defaults to 120 (seconds)') + print(' --app-close-timeout Kill the app if it does not exit within this timeframe,') # noqa: E501 + print(' from when the client is disconnected') # noqa: E501 + print(' Defaults to 30 (seconds)') print(' --server-name Set the "Server" field in the response header') # noqa: E501 print(' --root-path Set the ASGI root_path. Defaults to ""') # noqa: E501 print(' --help Show this help and exit') @@ -73,7 +79,9 @@ '--client-max-header-size', '--request-timeout', '--keepalive-timeout', - '--keepalive-connections'): + '--keepalive-connections', + '--app-handler-timeout', + '--app-close-timeout'): try: options[sys.argv[i - 1].lstrip('-').replace('-', '_')] = int(sys.argv[i]) # noqa: E501 except ValueError: diff --git a/tremolo/asgi_lifespan.py b/tremolo/asgi_lifespan.py index 16d0da0..9e687bc 100644 --- a/tremolo/asgi_lifespan.py +++ b/tremolo/asgi_lifespan.py @@ -10,26 +10,26 @@ def __init__(self, app, **kwargs): self._loop = kwargs['loop'] self._logger = kwargs['logger'] - scope = { - 'type': 'lifespan', - 'asgi': {'version': '3.0'} - } - self._queue = asyncio.Queue() - self._task = self._loop.create_task( - app(scope, self.receive, self.send) - ) - self._complete = False + self._waiter = self._loop.create_future() + self._task = self._loop.create_task(self.main(app)) - def startup(self): - self._complete = False + async def main(self, app): + try: + scope = { + 'type': 'lifespan', + 'asgi': {'version': '3.0'} + } + await app(scope, self.receive, self.send) + finally: + self._waiter.cancel() + + def startup(self): self._queue.put_nowait({'type': 'lifespan.startup'}) self._logger.info('lifespan: startup') def shutdown(self): - self._complete = False - self._queue.put_nowait({'type': 'lifespan.shutdown'}) self._logger.info('lifespan: shutdown') @@ -42,7 +42,7 @@ async def receive(self): async def send(self, data): if data['type'] in ('lifespan.startup.complete', 'lifespan.shutdown.complete'): - self._complete = True + self._waiter.set_result(None) self._logger.info(data['type']) elif data['type'] in ('lifespan.startup.failed', 'lifespan.shutdown.failed'): @@ -56,10 +56,14 @@ async def send(self, data): raise LifespanProtocolUnsupported async def exception(self, timeout=30): - for _ in range(timeout): - if self._complete: - return + timer = self._loop.call_at(self._loop.time() + timeout, + self._waiter.cancel) + + try: + await self._waiter + self._waiter = self._loop.create_future() + except asyncio.CancelledError: try: exc = self._task.exception() @@ -72,12 +76,9 @@ async def exception(self, timeout=30): else: self._logger.info( '%s: %s' % (LifespanProtocolUnsupported.message, - str(exc)) + str(exc) or repr(exc)) ) - - return except asyncio.InvalidStateError: - await asyncio.sleep(1) - - if not self._complete: - self._logger.warning('lifespan: timeout after %gs' % timeout) + self._logger.warning('lifespan: timeout after %gs' % timeout) + finally: + timer.cancel() diff --git a/tremolo/asgi_server.py b/tremolo/asgi_server.py index 57e4a57..274ee2a 100644 --- a/tremolo/asgi_server.py +++ b/tremolo/asgi_server.py @@ -26,22 +26,16 @@ class ASGIServer(HTTPProtocol): - __slots__ = ('_app', - '_scope', + __slots__ = ('_scope', '_read', - '_task', '_timer', - '_timeout', '_websocket', '_http_chunked') - def __init__(self, _app=None, **kwargs): - self._app = _app + def __init__(self, **kwargs): self._scope = None self._read = None - self._task = None self._timer = None - self._timeout = 30 self._websocket = None self._http_chunked = None @@ -88,37 +82,32 @@ async def header_received(self): await self._handle_http() self._read = self.request.stream() - self._task = self.loop.create_task(self.app()) + self.handler = self.loop.create_task(self.main()) def connection_lost(self, exc): - if (self._task is not None and not self._task.done() and - self._timer is None): - self._timer = self.loop.call_at(self.loop.time() + self._timeout, - self._task.cancel) + if self.handler is not None and not self.handler.done(): + self._set_app_close_timeout() super().connection_lost(exc) - async def app(self): + async def main(self): try: - await self._app(self._scope, self.receive, self.send) + await self.options['_app'](self._scope, self.receive, self.send) if self._timer is not None: self._timer.cancel() - except asyncio.CancelledError: - self.logger.warning( - 'task: ASGI application is cancelled due to timeout' - ) - except Exception as exc: + except (asyncio.CancelledError, Exception) as exc: if (self.request is not None and self.request.upgraded and self._websocket is not None): exc = WebSocketServerClosed(cause=exc) await self.handle_exception(exc) - def _set_app_timeout(self): + def _set_app_close_timeout(self): if self._timer is None: self._timer = self.loop.call_at( - self.loop.time() + self._timeout, self._task.cancel + self.loop.time() + self.options['_app_close_timeout'], + self.handler.cancel ) async def receive(self): @@ -152,7 +141,7 @@ async def receive(self): if self.request is not None: self.print_exception(exc) - self._set_app_timeout() + self._set_app_close_timeout() return { 'type': 'websocket.disconnect', 'code': code @@ -176,12 +165,12 @@ async def receive(self): self.request.body_size < self.request.content_length ) } - except Exception as exc: + except (asyncio.CancelledError, Exception) as exc: if not (self.request is None or isinstance(exc, StopAsyncIteration)): self.print_exception(exc) - self._set_app_timeout() + self._set_app_close_timeout() return {'type': 'http.disconnect'} async def send(self, data): diff --git a/tremolo/handlers.py b/tremolo/handlers.py index 04c3daf..44e2dc0 100644 --- a/tremolo/handlers.py +++ b/tremolo/handlers.py @@ -1,5 +1,7 @@ # Copyright (c) 2023 nggit +import traceback + from .exceptions import BadRequest from .utils import html_escape @@ -27,3 +29,12 @@ async def error_404(request=None, **_): b'
%s
' b'' % request.protocol.options['server_info']['name'] ) + + +async def error_500(request=None, exc=None, **_): + if request.protocol.options['debug']: + return '' % '
  • '.join( + traceback.TracebackException.from_exception(exc).format() + ) + + return str(exc) diff --git a/tremolo/lib/http_protocol.py b/tremolo/lib/http_protocol.py index c315134..d6bd17b 100644 --- a/tremolo/lib/http_protocol.py +++ b/tremolo/lib/http_protocol.py @@ -1,7 +1,6 @@ # Copyright (c) 2023 nggit import asyncio -import traceback from urllib.parse import quote, unquote @@ -30,6 +29,7 @@ class HTTPProtocol(asyncio.Protocol): '_request', '_response', '_watermarks', + 'handler', '_header_buf', '_waiters') @@ -45,6 +45,7 @@ def __init__(self, context, loop=None, logger=None, worker=None, **kwargs): self._response = None self._watermarks = {'high': 65536, 'low': 8192} + self.handler = None self._header_buf = None self._waiters = {} @@ -102,6 +103,22 @@ def connection_made(self, transport): timeout_cb=self.request_timeout)) ) + def abort(self, exc=None): + if (self._transport is not None and + not self._transport.is_closing()): + self._transport.abort() + + if exc: + self.print_exception(exc, 'abort') + + def close(self): + if (self._transport is not None and + not self._transport.is_closing()): + if self._transport.can_write_eof(): + self._transport.write_eof() + + self._transport.close() + async def request_timeout(self, timeout): self._logger.info('request timeout after %gs' % timeout) @@ -122,9 +139,7 @@ async def set_timeout(self, waiter, timeout=30, timeout_cb=None): if callable(timeout_cb): await timeout_cb(timeout) finally: - if (self._transport is not None and - not self._transport.is_closing()): - self._transport.abort() + self.abort() finally: timer.cancel() @@ -171,6 +186,17 @@ async def put_to_queue( async def header_received(self): return + def handler_timeout(self): + if (self._request is None or self._request.upgraded or + self.handler is None): + return + + self.handler.cancel() + self._logger.error('handler timeout after %gs. consider increasing ' + 'the value of app_handler_timeout' % + self._options['app_handler_timeout']) + self.handler = None + def print_exception(self, exc, *args): self._logger.error( ': '.join((*args, exc.__class__.__name__, str(exc))), @@ -181,6 +207,8 @@ async def handle_exception(self, exc): if (self._request is None or self._response is None or (self._response.headers_sent() and not self._request.upgraded)): + # it's here for redundancy + self.abort(exc) return self.print_exception( @@ -205,30 +233,35 @@ async def handle_exception(self, exc): elif not isinstance(exc, HTTPException): exc = InternalServerError(cause=exc) - encoding = 'utf-8' + if self._request is not None and self._response is not None: + self._response.set_status(exc.code, exc.message) + self._response.set_content_type(exc.content_type) + data = b'' + + try: + data = await self._options['_routes'][0][-1][1]( + request=self._request, response=self._response, exc=exc) + + if data is None: + data = b'' + finally: + if isinstance(data, str): + encoding = 'utf-8' - for v in exc.content_type.split(';'): - v = v.lstrip() + for v in exc.content_type.split(';'): + v = v.lstrip() - if v.startswith('charset='): - charset = v[len('charset='):].strip() + if v.startswith('charset='): + charset = v[len('charset='):].strip() - if charset != '': - encoding = charset + if charset != '': + encoding = charset - break + break - if self._options['debug']: - data = b'' % '
  • '.join( - traceback.TracebackException.from_exception(exc).format() - ).encode(encoding) - else: - data = str(exc).encode(encoding) + data = data.encode(encoding) - if self._response is not None: - self._response.set_status(exc.code, exc.message) - self._response.set_content_type(exc.content_type) - await self._response.end(data, keepalive=False) + await self._response.end(data, keepalive=False) async def _handle_request_header(self, data, header_size): header = ParseHeader(data, @@ -296,8 +329,17 @@ async def _handle_request_header(self, data, header_size): 'keepalive') and not fut.done(): fut.set_result(None) - await self.header_received() - except Exception as exc: + self.handler = self._loop.create_task(self.header_received()) + timer = self._loop.call_at( + self._loop.time() + self._options['app_handler_timeout'], + self.handler_timeout) + + try: + await self.handler + finally: + if self._options['_app'] is None: + timer.cancel() + except (asyncio.CancelledError, Exception) as exc: await self.handle_exception(exc) async def _receive_data(self, data, waiter): @@ -338,11 +380,11 @@ def data_received(self, data): self._header_buf = None elif header_size > self._options['client_max_header_size']: self._logger.info('request header too large') - self._transport.abort() + self.abort() elif not (header_size == -1 and len(self._header_buf) <= self._options['client_max_header_size']): self._logger.info('bad request') - self._transport.abort() + self.abort() return @@ -390,10 +432,7 @@ async def _send_data(self): self._request.clear_body() - if self._transport.can_write_eof(): - self._transport.write_eof() - - self._transport.close() + self.close() return # send data @@ -422,9 +461,7 @@ async def _send_data(self): except asyncio.CancelledError: pass except Exception as exc: - if self._transport is not None: - self._transport.abort() - self.print_exception(exc) + self.abort(exc) def _handle_keepalive(self): if 'request' in self._waiters: @@ -432,10 +469,7 @@ def _handle_keepalive(self): self._options['_connections'][self] = None if self not in self._options['_connections']: - if self._transport.can_write_eof(): - self._transport.write_eof() - - self._transport.close() + self.close() self._logger.info( 'a keepalive connection is kicked out of the list' ) diff --git a/tremolo/tremolo.py b/tremolo/tremolo.py index 8e7c6d2..c3a5804 100644 --- a/tremolo/tremolo.py +++ b/tremolo/tremolo.py @@ -35,7 +35,8 @@ def __init__(self): 0: [ (400, handlers.error_400, {}), (404, handlers.error_404, dict(status=(404, b'Not Found'), - stream=False)) + stream=False)), + (500, handlers.error_500, {}) ], 1: [ ( @@ -336,10 +337,16 @@ async def _serve(self, host, port, **options): keepalive_timeout=options.get( 'keepalive_timeout', 30 ), + app_handler_timeout=options.get( + 'app_handler_timeout', 120 + ), server_info=server_info, _connections=connections, _pools=pools, _app=options['app'], + _app_close_timeout=options.get( + 'app_close_timeout', 30 + ), _root_path=options.get('root_path', ''), _routes=options['_routes'], _middlewares=options['_middlewares']),