Skip to content

Commit

Permalink
pythongh-106905: Use a struct allocated on the stack to track recursi…
Browse files Browse the repository at this point in the history
…on depths in PyAST_mod2obj.
  • Loading branch information
yilei committed Dec 12, 2023
1 parent 81a15ea commit d1eb0fd
Show file tree
Hide file tree
Showing 3 changed files with 405 additions and 339 deletions.
2 changes: 0 additions & 2 deletions Include/internal/pycore_ast_state.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 28 additions & 25 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
class PyTypesDeclareVisitor(PickleVisitor):

def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes:
Expand All @@ -752,7 +752,7 @@ def visitSum(self, sum, name):
ptype = "void*"
if is_simple(sum):
ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types:
self.visitConstructor(t, name)

Expand Down Expand Up @@ -984,15 +984,16 @@ def visitModule(self, mod):
/* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{
Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n);
PyObject *value;
if (!result)
return NULL;
for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) {
Py_DECREF(result);
return NULL;
Expand All @@ -1002,7 +1003,7 @@ def visitModule(self, mod):
return result;
}
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{
PyObject *op = (PyObject*)o;
if (!op) {
Expand All @@ -1014,7 +1015,7 @@ def visitModule(self, mod):
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{
return PyLong_FromLong(b);
}
Expand Down Expand Up @@ -1116,8 +1117,6 @@ def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
return 0;
}
'''))
Expand Down Expand Up @@ -1260,25 +1259,25 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name):
ctype = get_c_type(name)
self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1)
self.emit("PyTypeObject *tp;", 1)
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2)
self.emit("}", 1)

def func_end(self):
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1)
Expand All @@ -1296,15 +1295,15 @@ def visitSum(self, sum, name):
self.visitConstructor(t, i + 1, name)
self.emit("}", 1)
for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2)
self.emit('Py_DECREF(value);', 1)
self.func_end()

def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0)
self.emit("switch(o) {", 1)
for t in sum.types:
Expand All @@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
for field in prod.fields:
self.visitField(field, name, 1, True)
for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2)
Expand Down Expand Up @@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
Expand All @@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)


class PartingShots(StaticVisitor):
Expand All @@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
if (!tstate) {
return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);
PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) {
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result;
Expand Down Expand Up @@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
f.write(' PyObject *' + s + ';\n')
f.write('};')
Expand Down Expand Up @@ -1538,6 +1536,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration
static int init_types(struct ast_state *state);
Expand Down
Loading

0 comments on commit d1eb0fd

Please sign in to comment.