Skip to content

Commit

Permalink
pythongh-106905: Use separate structs to track recursion depth in eac…
Browse files Browse the repository at this point in the history
…h PyAST_mod2obj call. (pythonGH-113035)


Co-authored-by: Gregory P. Smith [Google LLC] <[email protected]>
  • Loading branch information
2 people authored and ryan-duve committed Dec 26, 2023
1 parent 32b8f23 commit 1838687
Show file tree
Hide file tree
Showing 4 changed files with 412 additions and 339 deletions.
2 changes: 0 additions & 2 deletions Include/internal/pycore_ast_state.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Use per AST-parser state rather than global state to track recursion depth
within the AST parser to prevent potential race condition due to
simultaneous parsing.

The issue primarily showed up in 3.11 by multithreaded users of
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
triggered prevented the race condition from occurring.
53 changes: 28 additions & 25 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
class PyTypesDeclareVisitor(PickleVisitor):

def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes:
Expand All @@ -752,7 +752,7 @@ def visitSum(self, sum, name):
ptype = "void*"
if is_simple(sum):
ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types:
self.visitConstructor(t, name)

Expand Down Expand Up @@ -984,15 +984,16 @@ def visitModule(self, mod):
/* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{
Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n);
PyObject *value;
if (!result)
return NULL;
for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) {
Py_DECREF(result);
return NULL;
Expand All @@ -1002,7 +1003,7 @@ def visitModule(self, mod):
return result;
}
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{
PyObject *op = (PyObject*)o;
if (!op) {
Expand All @@ -1014,7 +1015,7 @@ def visitModule(self, mod):
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{
return PyLong_FromLong(b);
}
Expand Down Expand Up @@ -1116,8 +1117,6 @@ def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
return 0;
}
'''))
Expand Down Expand Up @@ -1260,25 +1259,25 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name):
ctype = get_c_type(name)
self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1)
self.emit("PyTypeObject *tp;", 1)
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2)
self.emit("}", 1)

def func_end(self):
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1)
Expand All @@ -1296,15 +1295,15 @@ def visitSum(self, sum, name):
self.visitConstructor(t, i + 1, name)
self.emit("}", 1)
for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2)
self.emit('Py_DECREF(value);', 1)
self.func_end()

def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0)
self.emit("switch(o) {", 1)
for t in sum.types:
Expand All @@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
for field in prod.fields:
self.visitField(field, name, 1, True)
for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2)
Expand Down Expand Up @@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
Expand All @@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)


class PartingShots(StaticVisitor):
Expand All @@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
if (!tstate) {
return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);
PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) {
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result;
Expand Down Expand Up @@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
f.write(' PyObject *' + s + ';\n')
f.write('};')
Expand Down Expand Up @@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration
static int init_types(struct ast_state *state);
Expand Down
Loading

0 comments on commit 1838687

Please sign in to comment.