Python AST

"Python AST(Abstract Syntax Trees)"

Posted by Stephen on April 6, 2020

前言

在了解Python的AST(Abstract Syntax Trees)抽象语法树之前,先了解Python源码和Python AST的关系, 官方网址

Python官方提供的CPython解释器对python源码的处理过程如下:

  1. Parse source code into a parse tree (Parser/pgen.c)
  2. Transform parse tree into an Abstract Syntax Tree (Python/ast.c)
  3. Transform AST into a Control Flow Graph (Python/compile.c)
  4. Emit bytecode based on the Control Flow Graph (Python/compile.c)

即实际python代码的处理过程如下:

源代码解析 –> 语法树 –> 抽象语法树(AST) –> 控制流程图 –> 字节码 -> 执行

环境

系统环境

Distributor ID:	Ubuntu
Description:	Ubuntu 18.04.4 LTS
Release:	18.04
Codename:	bionic
Linux version :       5.3.0-46-generic ( buildd@lcy01-amd64-013 ) 
Gcc version:         7.5.0  ( Ubuntu 7.5.0-3ubuntu1~18.04 )

软件信息

version : 	
     Python 3.7

正文

1. 简介

python源码首先被解析成语法树,随后又转换成抽象语法树。在抽象语法树中我们可以看到源码文件中的python的语法结构。

大部分时间编程可能都不需要用到抽象语法树,但是在特定的条件和需求的情况下,AST又有其特殊的方便性。

下面是一个抽象语法的简单实例。

Module(body=[
    Print(
          dest=None,
          values=[BinOp(left=Num(n=1),op=Add(),right=Num(n=2))],
          nl=True,
 )])                                

2. 创建AST

2.1 Compile函数

先简单了解一下compile函数。

compile(source, filename, mode[, flags[, dont_inherit]])

  • source – 字符串或者AST(Abstract Syntax Trees)对象。一般可将整个py文件内容file.read()传入。
  • filename – 代码文件名称,如果不是从文件读取代码则传递一些可辨认的值。
  • mode – 指定编译代码的种类。可以指定为 exec, eval, single。
  • flags – 变量作用域,局部命名空间,如果被提供,可以是任何映射对象。
  • flags和dont_inherit是用来控制编译源码时的标志。
func_def = \
"""
def add(x, y):
    return x + y
print add(3, 5)
"""

使用Compile编译并执行:

>>> cm = compile(func_def, '<string>', 'exec')
>>> exec cm
>>> 8

上面func_def经过compile编译得到字节码,cm即code对象,True == isinstance(cm, types.CodeType)。

compile(source, filename, mode, ast.PyCF_ONLY_AST) <==> ast.parse(source, *filename=’'*, *mode='exec'*)

2.2 生成ast

使用上面的func_def生成ast.

r_node = ast.parse(func_def)
print(astunparse.dump(r_node))    # print(ast.dump(r_node))

下面是func_def对应的ast结构:

Module(body=[
    FunctionDef(
        name='add',
        args=arguments(
            args=[Name(id='x',ctx=Param()),Name(id='y',ctx=Param())],
            vararg=None,
            kwarg=None,
            defaults=[]),
        body=[Return(value=BinOp(
            left=Name(id='x',ctx=Load()),
            op=Add(),
            right=Name(id='y',ctx=Load())))],
        decorator_list=[]),
    Print(
        dest=None,
        values=[Call(
                func=Name(id='add',ctx=Load()),
                args=[Num(n=3),Num(n=5)],
                keywords=[],
                starargs=None,
                kwargs=None)],
        nl=True)
  ])

除了ast.dump,有很多dump ast的第三方库,如astunparse, codegen, unparse等。这些第三方库不仅能够以更好的方式展示出ast结构,还能够将ast反向导出python source代码。

module Python version "$Revision$"
{
  mod = Module(stmt* body)| Expression(expr body)

  stmt = FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list)
        | ClassDef(identifier name, expr* bases, stmt* body, expr* decorator_list)
        | Return(expr? value)
        | Print(expr? dest, expr* values, bool nl)| For(expr target, expr iter, stmt* body, stmt* orelse)

  expr = BoolOp(boolop op, expr* values)
       | BinOp(expr left, operator op, expr right)| Lambda(arguments args, expr body)| Dict(expr* keys, expr* values)| Num(object n) -- a number as a PyObject.
       | Str(string s) -- need to specify raw, unicode, etc?| Name(identifier id, expr_context ctx)
       | List(expr* elts, expr_context ctx) 
        -- col_offset is the byte offset in the utf8 string the parser uses
        attributes (int lineno, int col_offset)

  expr_context = Load | Store | Del | AugLoad | AugStore | Param
  boolop = And | Or 
  operator = Add | Sub | Mult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
  arguments = (expr* args, identifier? vararg, identifier? kwarg, expr* defaults)
}

上面是部分摘自官网的 Abstract Grammar,实际遍历ast Node过程中根据Node的类型访问其属性。

3. 遍历AST

python提供了两种方式来遍历整个抽象语法树。

3.1 ast.NodeTransfer

将func_def中的add函数中的加法运算改为减法,同时为函数实现添加调用日志。

 1 class CodeVisitor(ast.NodeVisitor):
 2     def visit_BinOp(self, node):
 3         if isinstance(node.op, ast.Add):
 4             node.op = ast.Sub()
 5         self.generic_visit(node)
 6 
 7     def visit_FunctionDef(self, node):
 8         print 'Function Name:%s'% node.name
 9         self.generic_visit(node)
10         func_log_stmt = ast.Print(
11             dest = None,
12             values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
13             nl = True,
14             lineno = 0,
15             col_offset = 0,
16         )
17         node.body.insert(0, func_log_stmt)
18 
19 r_node = ast.parse(func_def)
20 visitor = CodeVisitor()
21 visitor.visit(r_node)
22 # print astunparse.dump(r_node)
23 print astunparse.unparse(r_node)
24 exec compile(r_node, '<string>', 'exec')

运行结果:

Function Name:add
def add(x, y):
    print 'calling func: add'
    return (x - y)
print add(3, 5)
calling func: add
-2

3.2 ast.NodeTransformer

使用NodeVisitor主要是通过修改语法树上节点的方式改变AST结构,NodeTransformer主要是替换ast中的节点。

既然func_def中定义的add已经被改成一个减函数了,那么我们就彻底一点,把函数名和参数以及被调用的函数都在ast中改掉,并且将添加的函数调用log写的更加复杂一些,争取改的面目全非:-)

 1 class CodeTransformer(ast.NodeTransformer):
 2     def visit_BinOp(self, node):
 3         if isinstance(node.op, ast.Add):
 4             node.op = ast.Sub()
 5         self.generic_visit(node)
 6         return node
 7 
 8     def visit_FunctionDef(self, node):
 9         self.generic_visit(node)
10         if node.name == 'add':
11             node.name = 'sub'
12         args_num = len(node.args.args)
13         args = tuple([arg.id for arg in node.args.args])
14         func_log_stmt = ''.join(["print 'calling func: %s', " % node.name, "'args:'", ", %s" * args_num % args])
15         node.body.insert(0, ast.parse(func_log_stmt))
16         return node
17 
18     def visit_Name(self, node):
19         replace = {'add': 'sub', 'x': 'a', 'y': 'b'}
20         re_id = replace.get(node.id, None)
21         node.id = re_id or node.id22         self.generic_visit(node)
23         return node
24 
25 r_node = ast.parse(func_def)
26 transformer = CodeTransformer()
27 r_node = transformer.visit(r_node)
28 # print astunparse.dump(r_node)
29 source = astunparse.unparse(r_node)
30 print source
31 # exec compile(r_node, '<string>', 'exec')        # 新加入的node func_log_stmt 缺少lineno和col_offset属性
32 exec compile(source, '<string>', 'exec')
33 exec compile(ast.parse(source), '<string>', 'exec')

结果:

def sub(a, b):
    print 'calling func: sub', 'args:', a, b
    return (a - b)
print sub(3, 5)
calling func: sub args: 3 5
-2
calling func: sub args: 3 5
-2

代码中能够清楚的看到两者的区别。这里不再赘述。

4. AST应用

AST模块实际编程中很少用到,但是作为一种源代码辅助检查手段是非常有意义的;语法检查,调试错误,特殊字段检测等。

上面通过为函数添加调用日志的信息是一种调试python源代码的一种方式,不过实际中我们是通过parse整个python文件的方式遍历修改源码。

4.1 汉字检测

下面是中日韩字符的unicode编码范围

CJK Unified Ideographs

  • Range: 4E00— 9FFF
  • Number of characters: 20992
  • Languages: chinese, japanese, korean, vietnamese

使用 unicode 范围 \u4e00 - \u9fff 来判别汉字,注意这个范围并不包含中文字符(e.g. u’;’ == u’\uff1b’) .

下面是一个判断字符串中是否包含中文字符的一个类CNCheckHelper:

 1 class CNCheckHelper(object):
 2     # 待检测文本可能的编码方式列表
 3     VALID_ENCODING = ('utf-8', 'gbk')
 4 
 5     def _get_unicode_imp(self, value, idx = 0):
 6         if idx < len(self.VALID_ENCODING):
 7             try:
 8                 return value.decode(self.VALID_ENCODING[idx])
 9             except:
10                 return self._get_unicode_imp(value, idx + 1)
11 
12     def _get_unicode(self, from_str):
13         if isinstance(from_str, unicode):
14             return None
15         return self._get_unicode_imp(from_str)
16 
17     def is_any_chinese(self, check_str, is_strict = True):
18         unicode_str = self._get_unicode(check_str)
19         if unicode_str:
20             c_func = any if is_strict else all
21             return c_func(u'\u4e00' <= char <= u'\u9fff' for char in unicode_str)
22         return False

接口is_any_chinese有两种判断模式,严格检测只要包含中文字符串就可以检查出,非严格必须全部包含中文。

下面我们利用ast来遍历源文件的抽象语法树,并检测其中字符串是否包含中文字符。

 1 class CodeCheck(ast.NodeVisitor):
 2 
 3     def __init__(self):
 4         self.cn_checker = CNCheckHelper()
 5 
 6     def visit_Str(self, node):
 7         self.generic_visit(node)
 8         # if node.s and any(u'\u4e00' <= char <= u'\u9fff' for char in node.s.decode('utf-8')):
 9         if self.cn_checker.is_any_chinese(node.s, True):
10             print 'line no: %d, column offset: %d, CN_Str: %s' % (node.lineno, node.col_offset, node.s)
11 
12 project_dir = './your_project/script'
13 for root, dirs, files in os.walk(project_dir):
14     print root, dirs, files
15     py_files = filter(lambda file: file.endswith('.py'), files)
16     checker = CodeCheck()
17     for file in py_files:
18         file_path = os.path.join(root, file)
19         print 'Checking: %s' % file_path
20         with open(file_path, 'r') as f:
21             root_node = ast.parse(f.read())
22             checker.visit(root_node)

上面这个例子比较的简单,但大概就是这个意思。

关于CPython解释器执行源码的过程可以参考官网描述:PEP 339

4.2 Closure 检查

一个函数中定义的函数或者lambda中引用了父函数中的local variable,并且当做返回值返回。特定场景下闭包是非常有用的,但是也很容易被误用。

关于python闭包的概念可以参考我的另一篇文章:

这里简单介绍一下如何借助ast来检测lambda中闭包的引用。代码如下:

 1 class LambdaCheck(ast.NodeVisitor):
 2 
 3     def __init__(self):
 4         self.illegal_args_list = []
 5         self._cur_file = None
 6         self._cur_lambda_args = []
 7 
 8     def set_cur_file(self, cur_file):
 9         assert os.path.isfile(cur_file) cur_file
10         self._cur_file = os.path.realpath(cur_file)
11 
12     def visit_Lambda(self, node):
13         """
14         lambda 闭包检查原则:
15         只需检测lambda expr body中args是否引用了lambda args list之外的参数
16         """
17         self._cur_lambda_args =[a.id for a in node.args.args]
18         print astunparse.unparse(node)
19         # print astunparse.dump(node)
20         self.get_lambda_body_args(node.body)
21         self.generic_visit(node)
22 
23     def record_args(self, name_node):
24         if isinstance(name_node, ast.Name) and name_node.id not in self._cur_lambda_args:
25             self.illegal_args_list.append((self._cur_file, 'line no:%s' % name_node.lineno, 'var:%s' % name_node.id))
26 
27     def _is_args(self, node):
28         if isinstance(node, ast.Name):
29             self.record_args(node)
30             return True
31         if isinstance(node, ast.Call):
32             map(self.record_args, node.args)
33             return True
34         return False
35 
36     def get_lambda_body_args(self, node):
37         if self._is_args(node): return
38         # for cnode in ast.walk(node):
39         for cnode in ast.iter_child_nodes(node):
40             if not self._is_args(cnode):
41                 self.get_lambda_body_args(cnode)

遍历工程文件:

 1 project_dir = './your project/script'
 2 for root, dirs, files in os.walk(project_dir):
 3     py_files = filter(lambda file: file.endswith('.py'), files)
 4     checker = LambdaCheck()
 5     for file in py_files:
 6         file_path = os.path.join(root, file)
 7         checker.set_cur_file(file_path)
 8         with open(file_path, 'r') as f:
 9             root_node = ast.parse(f.read())
10             checker.visit(root_node)
11     res = '\n'.join([' ## '.join(info) for info in checker.illegal_args_list])
12     print res

4.3 作为type传参

下面是ast.literal_eval的文档:

Safely evaluate an expression node or a Unicode or Latin-1 encoded string containing a Python literal or container display. The string or node provided may only consist of the following Python literal structures: strings, numbers, tuples, lists, dicts, booleans, and None. This can be used for safely evaluating strings containing Python values from untrusted sources without the need to parse the values oneself. It is not capable of evaluating arbitrarily complex expressions, for example involving operators or indexing.

可以得知,ast.literal_eval可以从字符串中读取Python的string, numbers, tuples, lists, dicts, booleans and None类型的对象。所以我们只需指定当前argument的type为ast.literal_eval,就可以得到boolean类型的值了。但这种方法的问题在于,只有当参数输入为'False'时读取的值才为False,否则为True。如下面的例子所示:

parser.add_argument(
    '--flag',
    help='True or False flag, input should be either "True" or "False".',
    type=ast.literal_eval,
    dest='flag',
)

调用方法

python command --flag True
python command --flag False

4.3.1 ast.literal_eval与eval对比

例子:

先看一个比较场景的例子:

eval是Python的一个内置函数,这个函数的作用是,返回传入字符串的表达式的结果。想象一下变量赋值时,将等号右边的表达式写成字符串的格式,将这个字符串作为eval的参数,eval的返回值就是这个表达式的结果。

python中eval函数的用法十分的灵活,但也十分危险,安全性是其最大的缺点。

eval的语法格式如下:

eval(expression[, globals[, locals]])

expression : 字符串 globals : 变量作用域,全局命名空间,如果被提供,则必须是一个字典对象。 locals : 变量作用域,局部命名空间,如果被提供,可以是任何映射对象。

4.3.1.1 强大之处

结合globals和locals看看几个例子 传递globals参数值为{“age”:1822},

eval("{'name':'linux','age':age}",{"age":1822}) 输出结果:{‘name’: ‘linux’, ‘age’: 1822}

再加上locals变量

age=18
eval("{'name':'linux','age':age}",{"age":1822},locals())

根据上面两个例子可以看到当locals参数为空,globals参数不为空时,查找globals参数中是否存在变量,并计算。

当两个参数都不为空时,先查找locals参数,再查找globals参数,locals参数中同名变量会覆盖globals中的变量。

4.3.1.2 危险之处

eval虽然方便,但是要注意安全性,可以将字符串转成表达式并执行,就可以利用执行系统命令,删除文件等操作。 假设用户恶意输入。比如: eval("__import__('os').system('ls /Users/Downloads/')") 那么eval()之后,你会发现,当前文件夹文件都会展如今用户前面。这句其实相当于执行了 os.system('ls /Users/chunming.liu/Downloads/') 那就是删除rm和执行文件命令bin也是照样执行。很可怕。

后记

None