作用在于将 C++ 实现的函数封装为 module,供 python 代码调用。
想要理解 pybind 的封装逻辑,还需要理解 python 是如何使用 pybind 生成的内容的,即关键点在于 python 的 import 机制的基本原理。
Python 的模块加载机制
Python 源文件(.py),Python 解释器会加载并执行该文件
扩展模块(例如 .so、.dll 或 .dylib 文件),Python 解释器会将其当作动态库来加载,并通过相应的 Python C API 或 Python/C++ 接口进行交互。
pybind 对应的就是第二种模式
pybind “hello-world”
编写供 Python 程序调用的 C++ 代码 1
2
3
4
5
6
7
8
9
10
#include <pybind11/pybind11.h>
int add ( int i , int j ) {
return i + j ;
}
PYBIND11_MODULE ( example , m ) {
m . doc () = "pybind11 example plugin" ; // optional module docstring
m . def ( "add" , & add , "A function which adds two numbers" );
}
C++ 编译为动态库 1
c++ -O3 -Wall -shared -std= c++11 -fPIC $( python3 -m pybind11 --includes) example.cpp -o example$( python3-config --extension-suffix)
$(python3 -m pybind11 --includes): 负责获取 pybind 所需要使用的头文件$(python3-config --extension-suffix): 负责获取文件后缀 .cpython-39-x86_64-linux-gnu.so检查使用效果 1
2
3
4
5
6
7
>>> import example
>>> example
<module 'example' from '/work/gaohy/experiments/pybind/example.cpython-39-x86_64-linux-gnu.so' >
>>> example.add( 1,2)
3
对于 PYBIND11_MODULE 宏的分析
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#define PYBIND11_MODULE(name, variable) \
static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name) \
PYBIND11_MAYBE_UNUSED; \
PYBIND11_MAYBE_UNUSED \
static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
PYBIND11_ENSURE_INTERNALS_READY \
auto m = ::pybind11::module_::create_extension_module( \
PYBIND11_TOSTRING(name), nullptr, &PYBIND11_CONCAT(pybind11_module_def_, name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} \
PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ & (variable))
同 pip 结合
封装模式
通过 PYBIND11_MODULE 完成最底层的 C++ function 的封装,向 Python 提供了底层算子的支持,需要注意这里是一次潜在的修改 function name 的机会,向 Python 暴露的函数名称并不一定等同于底层的算子名称。
一般会在 C++ function 封装之上构建一个 Python package,完成一些参数处理,校验等预处理工作,在 package 向外导出函数时也是有可能修改名称的,如下面代码所示
1
2
3
4
5
6
7
8
9
10
# __init__.py
from .recompute_kascade_gqa_prefill import flashattn as prefill_recompute_kernel
from .reuse_kascade_gqa_prefill import flashattn as prefill_reuse_kernel
__all__ = [ "prefill_recompute_kernel" , "prefill_reuse_kernel" ]
# recompute_kascade_gqa_prefill.py
def flashattn ( batch , heads , seq_len , dim , tune = False , groups = 1 , kernel_type = "prefill" ):
# 由于这一部分代码使用了 python DSL,所以底层并没有 C++ function 了,如果是贯穿式工作则底层还应当存在一个 pybind 封装
package 向外暴露的函数名称是 prefill_recompute_kernel 和 prefill_reuse_kernel。然而实际上包含前面两个名称和 flashattn 在内都已经是在 python 层面的封装了。