|
| 1 | +"""HTTP server |
| 2 | +
|
| 3 | +:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. |
| 4 | +:license: http://www.apache.org/licenses/LICENSE-2.0.html |
| 5 | +""" |
| 6 | + |
| 7 | +from pykern.pkcollections import PKDict |
| 8 | +from pykern.pkdebug import pkdc, pkdlog, pkdp, pkdexc, pkdformat |
| 9 | +import inspect |
| 10 | +import msgpack |
| 11 | +import pykern.pkasyncio |
| 12 | +import pykern.pkcollections |
| 13 | +import pykern.pkconfig |
| 14 | +import pykern.quest |
| 15 | +import re |
| 16 | +import tornado.web |
| 17 | + |
| 18 | + |
| 19 | +#: Http auth header name |
| 20 | +_AUTH_HEADER = "Authorization" |
| 21 | + |
| 22 | +#: http auth header scheme bearer |
| 23 | +_AUTH_HEADER_SCHEME_BEARER = "Bearer" |
| 24 | + |
| 25 | +#: POSIT: Matches anything generated by `unique_key` |
| 26 | +_UNIQUE_KEY_CHARS_RE = r"\w+" |
| 27 | + |
| 28 | +#: Regex to test format of auth header and extract token |
| 29 | +_AUTH_HEADER_RE = re.compile( |
| 30 | + _AUTH_HEADER_SCHEME_BEARER + r"\s+(" + _UNIQUE_KEY_CHARS_RE + ")", |
| 31 | + re.IGNORECASE, |
| 32 | +) |
| 33 | + |
| 34 | +_CONTENT_TYPE_HEADER = "Content-Type" |
| 35 | +_CONTENT_TYPE = "application/msgpack" |
| 36 | + |
| 37 | +_VERSION_HEADER = "X-PyKern-HTTP-Version" |
| 38 | + |
| 39 | +_VERSION_HEADER_VALUE = "1" |
| 40 | + |
| 41 | +_API_NAME_RE = re.compile(rf"^{pykern.quest.API.METHOD_PREFIX}(\w+)") |
| 42 | + |
| 43 | + |
| 44 | +def server_start(api_classes, attr_classes, http_config): |
| 45 | + l = pykern.pkasyncio.Loop() |
| 46 | + _HTTPServer(l, api_classes, attr_classes, http_config) |
| 47 | + l.start() |
| 48 | + |
| 49 | + |
| 50 | +class Reply: |
| 51 | + |
| 52 | + def __init__(self, result=None, exc=None, api_error=None): |
| 53 | + def _exception(exc): |
| 54 | + if exc is None: |
| 55 | + pkdlog("ERROR: no reply and no exception") |
| 56 | + return 500 |
| 57 | + if isinstance(exc, NotFound): |
| 58 | + return 404 |
| 59 | + if isinstance(exc, Forbidden): |
| 60 | + return 403 |
| 61 | + pkdlog("untranslated exception={}", exc) |
| 62 | + return 500 |
| 63 | + |
| 64 | + if isinstance(result, Reply): |
| 65 | + self.http_status = result.http_status |
| 66 | + self.content = result.content |
| 67 | + elif result is not None or api_error is not None: |
| 68 | + self.http_status = 200 |
| 69 | + self.content = PKDict( |
| 70 | + api_error=api_error, |
| 71 | + api_result=result, |
| 72 | + ) |
| 73 | + else: |
| 74 | + self.http_status = _exception(exc) |
| 75 | + self.content = None |
| 76 | + |
| 77 | + |
| 78 | +class ReplyExc(Exception): |
| 79 | + """Raised to end the request. |
| 80 | +
|
| 81 | + Args: |
| 82 | + pk_args (dict): exception args that specific to this module |
| 83 | + log_fmt (str): server side log data |
| 84 | + """ |
| 85 | + |
| 86 | + def __init__(self, *args, **kwargs): |
| 87 | + super().__init__() |
| 88 | + if "pk_args" in kwargs: |
| 89 | + self.pk_args = kwargs["pk_args"] |
| 90 | + del kwargs["pk_args"] |
| 91 | + else: |
| 92 | + self.pk_args = PKDict() |
| 93 | + if args or kwargs: |
| 94 | + kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back |
| 95 | + pkdlog(*args, **kwargs) |
| 96 | + |
| 97 | + def __repr__(self): |
| 98 | + a = self.pk_args |
| 99 | + return "{}({})".format( |
| 100 | + self.__class__.__name__, |
| 101 | + ",".join( |
| 102 | + ("{}={}".format(k, a[k]) for k in sorted(a.keys())), |
| 103 | + ), |
| 104 | + ) |
| 105 | + |
| 106 | + def __str__(self): |
| 107 | + return self.__repr__() |
| 108 | + |
| 109 | + |
| 110 | +class APIError(ReplyExc): |
| 111 | + """Raised by server/client for application level errors""" |
| 112 | + |
| 113 | + def __init__(self, api_error_fmt, *args, **kwargs): |
| 114 | + super().__init__( |
| 115 | + pk_args=PKDict(api_error=pkdformat(api_error_fmt, *args, **kwargs)), |
| 116 | + ) |
| 117 | + |
| 118 | + |
| 119 | +class Forbidden(ReplyExc): |
| 120 | + """Raised for forbidden or protocol error""" |
| 121 | + |
| 122 | + pass |
| 123 | + |
| 124 | + |
| 125 | +class InvalidResponse(ReplyExc): |
| 126 | + """Raised when the reply is invalid (client)""" |
| 127 | + |
| 128 | + pass |
| 129 | + |
| 130 | + |
| 131 | +class NotFound(ReplyExc): |
| 132 | + """Raised for an object not found""" |
| 133 | + |
| 134 | + pass |
| 135 | + |
| 136 | + |
| 137 | +class HTTPClient: |
| 138 | + def __init__(self, http_config): |
| 139 | + self._uri = ( |
| 140 | + f"http://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" |
| 141 | + ) |
| 142 | + self._headers = PKDict( |
| 143 | + { |
| 144 | + _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {_auth_secret(http_config.auth_secret)}", |
| 145 | + _CONTENT_TYPE_HEADER: _CONTENT_TYPE, |
| 146 | + _VERSION_HEADER: _VERSION_HEADER_VALUE, |
| 147 | + } |
| 148 | + ) |
| 149 | + self._tornado = tornado.httpclient.AsyncHTTPClient(force_instance=True) |
| 150 | + |
| 151 | + async def post(self, api_name, api_arg): |
| 152 | + r = await self._tornado.fetch( |
| 153 | + self._uri, |
| 154 | + body=_pack_msg(PKDict(api_name=api_name, api_arg=api_arg)), |
| 155 | + headers=self._headers, |
| 156 | + method="POST", |
| 157 | + ) |
| 158 | + rv, e = _unpack_msg(r) |
| 159 | + if e: |
| 160 | + raise InvalidResponse(*e) |
| 161 | + if rv.api_error: |
| 162 | + raise APIError( |
| 163 | + "api_error={} api_name={} api_arg={}", rv.api_error, api_name, api_arg |
| 164 | + ) |
| 165 | + return rv.api_result |
| 166 | + |
| 167 | + |
| 168 | +class _HTTPRequestHandler(tornado.web.RequestHandler): |
| 169 | + def initialize(self, server): |
| 170 | + self.pykern_http_server = server |
| 171 | + |
| 172 | + async def get(self): |
| 173 | + await self.pykern_http_server.dispatch(self) |
| 174 | + |
| 175 | + async def post(self): |
| 176 | + await self.pykern_http_server.dispatch(self) |
| 177 | + |
| 178 | + |
| 179 | +class _HTTPServer: |
| 180 | + |
| 181 | + def __init__(self, loop, api_classes, attr_classes, http_config): |
| 182 | + def _api_class_funcs(): |
| 183 | + for c in api_classes: |
| 184 | + for n, o in inspect.getmembers(c, predicate=inspect.isfunction): |
| 185 | + yield PKDict(api_class=c, api_func=o, api_func_name=n) |
| 186 | + |
| 187 | + def _api_map(): |
| 188 | + rv = PKDict() |
| 189 | + for a in _api_class_funcs(): |
| 190 | + n = a.api_func_name |
| 191 | + if not ((m := _API_NAME_RE.search(n)) and n in a.api_class.__dict__): |
| 192 | + continue |
| 193 | + a.api_name = m.group(1) |
| 194 | + if a.api_name in rv: |
| 195 | + raise AssertionError( |
| 196 | + "duplicate api={a.api_name} class={a.api_class.__name__}" |
| 197 | + ) |
| 198 | + if not inspect.iscoroutinefunction(a.pkdel("api_func")): |
| 199 | + raise AssertionError( |
| 200 | + "api_func={n} is not async class={a.api_class.__name__}" |
| 201 | + ) |
| 202 | + rv[a.api_name] = a |
| 203 | + return rv |
| 204 | + |
| 205 | + h = http_config.copy() |
| 206 | + self.loop = loop |
| 207 | + self.api_map = _api_map() |
| 208 | + self.attr_classes = attr_classes |
| 209 | + self.auth_secret = _auth_secret(h.pkdel("auth_secret")) |
| 210 | + h.uri_map = ((h.api_uri, _HTTPRequestHandler, PKDict(server=self)),) |
| 211 | + self.api_uri = h.pkdel("api_uri") |
| 212 | + loop.http_server(h) |
| 213 | + |
| 214 | + async def dispatch(self, handler): |
| 215 | + async def _call(api, api_arg): |
| 216 | + with pykern.quest.start(api.api_class, self.attr_classes) as qcall: |
| 217 | + return await getattr(qcall, api.api_func_name)(api_arg) |
| 218 | + |
| 219 | + m = None |
| 220 | + try: |
| 221 | + try: |
| 222 | + self.loop.http_log(handler, "start") |
| 223 | + self._authenticate(handler) |
| 224 | + m, e = _unpack_msg(handler.request) |
| 225 | + if e: |
| 226 | + raise Forbidden(*e) |
| 227 | + if not (a := self.api_map.get(m.api_name)): |
| 228 | + raise NotFound("unknown api={}", m.api_name) |
| 229 | + r = Reply(result=await _call(a, m.api_arg)) |
| 230 | + except APIError as e: |
| 231 | + r = Reply(api_error=e.pk_args.api_error) |
| 232 | + except Exception as e: |
| 233 | + self.loop.http_log( |
| 234 | + handler, |
| 235 | + "error", |
| 236 | + fmt="exception={} msg={} stack={}", |
| 237 | + args=[e, m, pkdexc()], |
| 238 | + ) |
| 239 | + r = Reply(exc=e) |
| 240 | + self._send_reply(handler, r) |
| 241 | + except Exception as e: |
| 242 | + pkdlog("unhandled exception={} stack={}", e, pkdexc()) |
| 243 | + raise |
| 244 | + |
| 245 | + def _authenticate(self, handler): |
| 246 | + def _token(headers): |
| 247 | + if not (h := headers.get(_AUTH_HEADER)): |
| 248 | + return None |
| 249 | + if m := _AUTH_HEADER_RE.search(h): |
| 250 | + return m.group(1) |
| 251 | + return None |
| 252 | + |
| 253 | + if handler.request.method != "POST": |
| 254 | + raise Forbidden() |
| 255 | + if t := _token(handler.request.headers): |
| 256 | + if t == self.auth_secret: |
| 257 | + return |
| 258 | + raise Forbidden("token mismatch") |
| 259 | + raise Forbidden("no token") |
| 260 | + |
| 261 | + def _send_reply(self, handler, reply): |
| 262 | + if (c := reply.content) is None: |
| 263 | + m = b"" |
| 264 | + else: |
| 265 | + m = _pack_msg(c) |
| 266 | + handler.set_header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) |
| 267 | + handler.set_header(_VERSION_HEADER, _VERSION_HEADER_VALUE) |
| 268 | + handler.set_header("Content-Length", str(len(m))) |
| 269 | + handler.set_status(reply.http_status) |
| 270 | + handler.write(m) |
| 271 | + |
| 272 | + |
| 273 | +def _auth_secret(value): |
| 274 | + if value: |
| 275 | + if len(value) < 16: |
| 276 | + raise AssertionError("secret too short len={len(value)} (<16)") |
| 277 | + return value |
| 278 | + if pykern.pkconfig.in_dev_mode(): |
| 279 | + return "default_dev_secret" |
| 280 | + raise AssertionError("must supply http_config.auth_secret") |
| 281 | + |
| 282 | + |
| 283 | +def _pack_msg(content): |
| 284 | + p = msgpack.Packer(autoreset=False) |
| 285 | + p.pack(content) |
| 286 | + # TODO(robnagler) getbuffer() would be better |
| 287 | + return p.bytes() |
| 288 | + |
| 289 | + |
| 290 | +def _unpack_msg(request): |
| 291 | + def _header(name, value): |
| 292 | + if not (v := request.headers.get(name)): |
| 293 | + return ("missing header={}", name) |
| 294 | + if v != value: |
| 295 | + return ("unexpected {}={}", name, c) |
| 296 | + return None |
| 297 | + |
| 298 | + if e := ( |
| 299 | + _header(_VERSION_HEADER, _VERSION_HEADER_VALUE) |
| 300 | + or _header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) |
| 301 | + ): |
| 302 | + return None, e |
| 303 | + u = msgpack.Unpacker( |
| 304 | + object_pairs_hook=pykern.pkcollections.object_pairs_hook, |
| 305 | + ) |
| 306 | + u.feed(request.body) |
| 307 | + return u.unpack(), None |
0 commit comments