Skip to content

Commit 8a060d3

Browse files
authored
Fix #509 added pykern.http for quest-based API (#510)
- wrapped around pkasyncio and pykern.quest - pykern.quest.Attr and API are "real" (simplified from sirepo) - Fix #374 copied logging from sirepo.uri_router - pkconst: remove py2/setup.py cruft
1 parent 038f507 commit 8a060d3

File tree

7 files changed

+443
-826
lines changed

7 files changed

+443
-826
lines changed

pykern/http.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
"""HTTP server
2+
3+
:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved.
4+
:license: http://www.apache.org/licenses/LICENSE-2.0.html
5+
"""
6+
7+
from pykern.pkcollections import PKDict
8+
from pykern.pkdebug import pkdc, pkdlog, pkdp, pkdexc, pkdformat
9+
import inspect
10+
import msgpack
11+
import pykern.pkasyncio
12+
import pykern.pkcollections
13+
import pykern.pkconfig
14+
import pykern.quest
15+
import re
16+
import tornado.web
17+
18+
19+
#: Http auth header name
20+
_AUTH_HEADER = "Authorization"
21+
22+
#: http auth header scheme bearer
23+
_AUTH_HEADER_SCHEME_BEARER = "Bearer"
24+
25+
#: POSIT: Matches anything generated by `unique_key`
26+
_UNIQUE_KEY_CHARS_RE = r"\w+"
27+
28+
#: Regex to test format of auth header and extract token
29+
_AUTH_HEADER_RE = re.compile(
30+
_AUTH_HEADER_SCHEME_BEARER + r"\s+(" + _UNIQUE_KEY_CHARS_RE + ")",
31+
re.IGNORECASE,
32+
)
33+
34+
_CONTENT_TYPE_HEADER = "Content-Type"
35+
_CONTENT_TYPE = "application/msgpack"
36+
37+
_VERSION_HEADER = "X-PyKern-HTTP-Version"
38+
39+
_VERSION_HEADER_VALUE = "1"
40+
41+
_API_NAME_RE = re.compile(rf"^{pykern.quest.API.METHOD_PREFIX}(\w+)")
42+
43+
44+
def server_start(api_classes, attr_classes, http_config):
45+
l = pykern.pkasyncio.Loop()
46+
_HTTPServer(l, api_classes, attr_classes, http_config)
47+
l.start()
48+
49+
50+
class Reply:
51+
52+
def __init__(self, result=None, exc=None, api_error=None):
53+
def _exception(exc):
54+
if exc is None:
55+
pkdlog("ERROR: no reply and no exception")
56+
return 500
57+
if isinstance(exc, NotFound):
58+
return 404
59+
if isinstance(exc, Forbidden):
60+
return 403
61+
pkdlog("untranslated exception={}", exc)
62+
return 500
63+
64+
if isinstance(result, Reply):
65+
self.http_status = result.http_status
66+
self.content = result.content
67+
elif result is not None or api_error is not None:
68+
self.http_status = 200
69+
self.content = PKDict(
70+
api_error=api_error,
71+
api_result=result,
72+
)
73+
else:
74+
self.http_status = _exception(exc)
75+
self.content = None
76+
77+
78+
class ReplyExc(Exception):
79+
"""Raised to end the request.
80+
81+
Args:
82+
pk_args (dict): exception args that specific to this module
83+
log_fmt (str): server side log data
84+
"""
85+
86+
def __init__(self, *args, **kwargs):
87+
super().__init__()
88+
if "pk_args" in kwargs:
89+
self.pk_args = kwargs["pk_args"]
90+
del kwargs["pk_args"]
91+
else:
92+
self.pk_args = PKDict()
93+
if args or kwargs:
94+
kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back
95+
pkdlog(*args, **kwargs)
96+
97+
def __repr__(self):
98+
a = self.pk_args
99+
return "{}({})".format(
100+
self.__class__.__name__,
101+
",".join(
102+
("{}={}".format(k, a[k]) for k in sorted(a.keys())),
103+
),
104+
)
105+
106+
def __str__(self):
107+
return self.__repr__()
108+
109+
110+
class APIError(ReplyExc):
111+
"""Raised by server/client for application level errors"""
112+
113+
def __init__(self, api_error_fmt, *args, **kwargs):
114+
super().__init__(
115+
pk_args=PKDict(api_error=pkdformat(api_error_fmt, *args, **kwargs)),
116+
)
117+
118+
119+
class Forbidden(ReplyExc):
120+
"""Raised for forbidden or protocol error"""
121+
122+
pass
123+
124+
125+
class InvalidResponse(ReplyExc):
126+
"""Raised when the reply is invalid (client)"""
127+
128+
pass
129+
130+
131+
class NotFound(ReplyExc):
132+
"""Raised for an object not found"""
133+
134+
pass
135+
136+
137+
class HTTPClient:
138+
def __init__(self, http_config):
139+
self._uri = (
140+
f"http://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}"
141+
)
142+
self._headers = PKDict(
143+
{
144+
_AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {_auth_secret(http_config.auth_secret)}",
145+
_CONTENT_TYPE_HEADER: _CONTENT_TYPE,
146+
_VERSION_HEADER: _VERSION_HEADER_VALUE,
147+
}
148+
)
149+
self._tornado = tornado.httpclient.AsyncHTTPClient(force_instance=True)
150+
151+
async def post(self, api_name, api_arg):
152+
r = await self._tornado.fetch(
153+
self._uri,
154+
body=_pack_msg(PKDict(api_name=api_name, api_arg=api_arg)),
155+
headers=self._headers,
156+
method="POST",
157+
)
158+
rv, e = _unpack_msg(r)
159+
if e:
160+
raise InvalidResponse(*e)
161+
if rv.api_error:
162+
raise APIError(
163+
"api_error={} api_name={} api_arg={}", rv.api_error, api_name, api_arg
164+
)
165+
return rv.api_result
166+
167+
168+
class _HTTPRequestHandler(tornado.web.RequestHandler):
169+
def initialize(self, server):
170+
self.pykern_http_server = server
171+
172+
async def get(self):
173+
await self.pykern_http_server.dispatch(self)
174+
175+
async def post(self):
176+
await self.pykern_http_server.dispatch(self)
177+
178+
179+
class _HTTPServer:
180+
181+
def __init__(self, loop, api_classes, attr_classes, http_config):
182+
def _api_class_funcs():
183+
for c in api_classes:
184+
for n, o in inspect.getmembers(c, predicate=inspect.isfunction):
185+
yield PKDict(api_class=c, api_func=o, api_func_name=n)
186+
187+
def _api_map():
188+
rv = PKDict()
189+
for a in _api_class_funcs():
190+
n = a.api_func_name
191+
if not ((m := _API_NAME_RE.search(n)) and n in a.api_class.__dict__):
192+
continue
193+
a.api_name = m.group(1)
194+
if a.api_name in rv:
195+
raise AssertionError(
196+
"duplicate api={a.api_name} class={a.api_class.__name__}"
197+
)
198+
if not inspect.iscoroutinefunction(a.pkdel("api_func")):
199+
raise AssertionError(
200+
"api_func={n} is not async class={a.api_class.__name__}"
201+
)
202+
rv[a.api_name] = a
203+
return rv
204+
205+
h = http_config.copy()
206+
self.loop = loop
207+
self.api_map = _api_map()
208+
self.attr_classes = attr_classes
209+
self.auth_secret = _auth_secret(h.pkdel("auth_secret"))
210+
h.uri_map = ((h.api_uri, _HTTPRequestHandler, PKDict(server=self)),)
211+
self.api_uri = h.pkdel("api_uri")
212+
loop.http_server(h)
213+
214+
async def dispatch(self, handler):
215+
async def _call(api, api_arg):
216+
with pykern.quest.start(api.api_class, self.attr_classes) as qcall:
217+
return await getattr(qcall, api.api_func_name)(api_arg)
218+
219+
m = None
220+
try:
221+
try:
222+
self.loop.http_log(handler, "start")
223+
self._authenticate(handler)
224+
m, e = _unpack_msg(handler.request)
225+
if e:
226+
raise Forbidden(*e)
227+
if not (a := self.api_map.get(m.api_name)):
228+
raise NotFound("unknown api={}", m.api_name)
229+
r = Reply(result=await _call(a, m.api_arg))
230+
except APIError as e:
231+
r = Reply(api_error=e.pk_args.api_error)
232+
except Exception as e:
233+
self.loop.http_log(
234+
handler,
235+
"error",
236+
fmt="exception={} msg={} stack={}",
237+
args=[e, m, pkdexc()],
238+
)
239+
r = Reply(exc=e)
240+
self._send_reply(handler, r)
241+
except Exception as e:
242+
pkdlog("unhandled exception={} stack={}", e, pkdexc())
243+
raise
244+
245+
def _authenticate(self, handler):
246+
def _token(headers):
247+
if not (h := headers.get(_AUTH_HEADER)):
248+
return None
249+
if m := _AUTH_HEADER_RE.search(h):
250+
return m.group(1)
251+
return None
252+
253+
if handler.request.method != "POST":
254+
raise Forbidden()
255+
if t := _token(handler.request.headers):
256+
if t == self.auth_secret:
257+
return
258+
raise Forbidden("token mismatch")
259+
raise Forbidden("no token")
260+
261+
def _send_reply(self, handler, reply):
262+
if (c := reply.content) is None:
263+
m = b""
264+
else:
265+
m = _pack_msg(c)
266+
handler.set_header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE)
267+
handler.set_header(_VERSION_HEADER, _VERSION_HEADER_VALUE)
268+
handler.set_header("Content-Length", str(len(m)))
269+
handler.set_status(reply.http_status)
270+
handler.write(m)
271+
272+
273+
def _auth_secret(value):
274+
if value:
275+
if len(value) < 16:
276+
raise AssertionError("secret too short len={len(value)} (<16)")
277+
return value
278+
if pykern.pkconfig.in_dev_mode():
279+
return "default_dev_secret"
280+
raise AssertionError("must supply http_config.auth_secret")
281+
282+
283+
def _pack_msg(content):
284+
p = msgpack.Packer(autoreset=False)
285+
p.pack(content)
286+
# TODO(robnagler) getbuffer() would be better
287+
return p.bytes()
288+
289+
290+
def _unpack_msg(request):
291+
def _header(name, value):
292+
if not (v := request.headers.get(name)):
293+
return ("missing header={}", name)
294+
if v != value:
295+
return ("unexpected {}={}", name, c)
296+
return None
297+
298+
if e := (
299+
_header(_VERSION_HEADER, _VERSION_HEADER_VALUE)
300+
or _header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE)
301+
):
302+
return None, e
303+
u = msgpack.Unpacker(
304+
object_pairs_hook=pykern.pkcollections.object_pairs_hook,
305+
)
306+
u.feed(request.body)
307+
return u.unpack(), None

0 commit comments

Comments
 (0)