代码优化分为前端优化和后端优化。前端优化是指在源代码层面进行优化,例如语法树优化、中间代码优化等。后端优化是指在目标代码层面进行优化,例如机器码优化等。
Python AST 语法树优化的实现在文件 Python/astopt.c 中,主要是通过调用 PyASTOptimize 函数来实现。PyAST_Optimize 函数的实现如下:
int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
PyThreadState *tstate;
int recursion_limit = Py_GetRecursionLimit();
int starting_recursion_depth;
/* Setup recursion depth check counters */
tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
}
/* Be careful here to prevent overflow. */
starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
state->recursion_depth = starting_recursion_depth;
state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
int ret = astfold_mod(mod, arena, state);
assert(ret || PyErr_Occurred());
/* Check that the recursion depth counting balanced correctly */
if (ret && state->recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST optimizer recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
return 0;
}
return ret;
}
astfold_mod
函数是 AST 优化的核心函数,对 body,stmt,expr 等节点进行优化。其实现如下:
static int
astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
switch (node_->kind) {
case Module_kind:
CALL(astfold_body, asdl_seq, node_->v.Module.body);
break;
case Interactive_kind:
CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
break;
case Expression_kind:
CALL(astfold_expr, expr_ty, node_->v.Expression.body);
break;
// The following top level nodes don't participate in constant folding
case FunctionType_kind:
break;
// No default case, so the compiler will emit a warning if new top level
// compilation nodes are added without being handled here
}
return 1;
}
如果语法树节点的类型是 Modulekind,那么调用 astfoldbody 函数。如果语法树的类型是 Interactivekind,那么调用 astfoldstmt 函数。如果语法树的类型是 Expressionkind,那么调用 astfoldexpr 函数。
语法树的优化是通过将语法树的节点进行折叠(折叠无效,累赘,重复,可编译时计算的节点),从而减少语法树的节点,最终达到减少执行指令的目的。
astfold_body body 代码块折叠优化
static int
astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
{
int docstring = _PyAST_GetDocString(stmts) != NULL;
CALL_SEQ(astfold_stmt, stmt, stmts);
if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
if (!values) {
return 0;
}
asdl_seq_SET(values, 0, st->v.Expr.value);
expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
st->end_lineno, st->end_col_offset,
ctx_);
if (!expr) {
return 0;
}
st->v.Expr.value = expr;
}
return 1;
}
astfold_body
函数的作用是对语法树的 body 进行优化,如果 body 中包含 docstring,那么调用 PyASTGetDocString 函数获取 docstring;然后调用 astfoldstmt 函数对 body 进行优化;最后,如果 body 中包含 docstring,那么调用 _PyASTJoinedStr 函数对 docstring 进行优化。
body 里面的每一个节点都会调用 astfold_stmt 函数进行优化。这是一个递归的过程,直到所有的节点都被优化。
astfold_stmt 语句折叠优化
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
switch (node_->kind) {
case FunctionDef_kind:
CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
}
break;
// ...
}
}
astfold_stmt
函数的作用是对语法树的 stmt 进行优化,如果 stmt 的类型是 FunctionDefkind,那么调用 astfoldarguments 函数对 args 进行优化,调用 astfoldbody 函数对 body 进行优化,调用 astfoldexpr 函数对 decoratorlist 进行优化,如果没有 COFUTUREANNOTATIONS,那么调用 astfoldexpr 函数对 returns 进行优化。
stmt 的优化是通过节点类型的数据结构进行逐步优化,这也是一个递归的过程,直到所有的节点都被优化。例如 if 包含 test, body, orelse 三个节点,test 是一个表达式,而 body 和 orelse 是一个语句序列(stmt),所以需要分别调用 astfoldexpr 和 astfoldstmt 函数进行优化。
case If_kind:
CALL(astfold_expr, expr_ty, node_->v.If.test);
CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
break;
astfold_expr 表达式折叠优化
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
if (++state->recursion_depth > state->recursion_limit) {
PyErr_SetString(PyExc_RecursionError,
"maximum recursion depth exceeded during compilation");
return 0;
}
switch (node_->kind) {
case BoolOp_kind: // bool
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
break;
case BinOp_kind: // bin?
CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
CALL(fold_binop, expr_ty, node_);
break;
// ...
}
}
astfold_expr
函数的作用是对语法树的 expr(表达式)进行优化,如果 expr 的类型是 BinOpkind,那么调用 astfoldexpr 函数对 left 和 right 进行优化,然后调用 foldbinop 函数对 expr 进行优化。astfoldexpr 是一个递归的操作(我们可以不用关心这里的递归),最终调用 fold_binop 函数对 expr 进行优化。
简单了解一下 fold_binop 的实现
static int
fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
{
expr_ty lhs, rhs;
lhs = node->v.BinOp.left;
rhs = node->v.BinOp.right;
if (lhs->kind != Constant_kind || rhs->kind != Constant_kind) {
return 1;
}
PyObject *lv = lhs->v.Constant.value;
PyObject *rv = rhs->v.Constant.value;
PyObject *newval = NULL;
switch (node->v.BinOp.op) {
case Add:
newval = PyNumber_Add(lv, rv);
break;
// ... 其他操作符
}
}
foldbinop 是对表达式的优化,如果表达式的左右节点都是常量,那么就可以对表达式进行折叠优化。如 1 + 2,可以直接折叠成 3。这个操作是通过调用 PyNumberAdd 函数来实现数值的相加,那么最终的结果就是一个常量。
接下来我们验证这个优化,可以通过 dis 模块来查看 Python 代码的字节码。
import dis
code = """
a = 1 + 2
"""
dis.dis(code)
可以看到 1 + 2 被优化成了 3。调用 a 变成了 LOADCONST,将 3 加载到栈中,然后调用 STORENAME 将 3 存储到 a 中。输出字节码如下:
2 0 LOAD_CONST 0 (3)
2 STORE_NAME 0 (a)
4 LOAD_CONST 1 (None)
6 RETURN_VALUE
总结
- 语法树的优化是通过将语法树中的节点进行折叠,从而减少语法树的节点数,最终达到减少执行指令的步骤。这样可以提高 Python 代码的执行效率。
- 优化的是一个递归的过程,直到所有的节点都被优化。
- 语法树优化的最小单元是表达式。