C++ Template

Metaprogramming 就是写一段程序,这段程序可以生成或者修改一段程序,Macro and Template 都是实现 Metaprogramming 的手段。

特化(specialization)和部分特化(Partial Specialization)或者说叫偏特化,都是为了让编译器明确执行路径或者说接口选择

特化

当定义的逻辑对所有类型的行为都是一致的时候,就不需要偏特化, 一套代码走天下,这就是全泛型

特化是“特殊化处理”的总称

函数模板是不支持偏特化的(只能全特化或重载),只有类模板才支持偏特化

全特化

为一组特定的、确定的类型提供一套完全不同的实现。此时,所有的模板参数都被指定了,不再有泛型空间

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
// 基础模板:对所有类型 T 打印 "Generic"
template <typename T>
struct Printer {
    void print() { std::cout << "Generic" << std::endl; }
};

// 全特化:只针对 int 类型,打印 "Integer"
template <>
struct Printer<int> {
    void print() { std::cout << "Integer" << std::endl; }
};

偏特化

指定了部分模板参数,或者对参数增加了某种限制条件,但它依然保持一定的泛型能力。

在保持接口统一的前提下,针对特定类型进行性能优化或差异化实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
// 基础模板:接受两个参数
template <typename T, typename U>
struct Storage { ... };

// 偏特化 A:当第二个参数固定为 int 时(指定部分参数)
template <typename T>
struct Storage<T, int> { ... };

// 偏特化 B:当参数是指针类型时(增加限制条件)
template <typename T>
struct Storage<T*> { ... };

显式实例化

显式实例化和特化的感觉比较像,但是目的完全不同。显式实例化是为了解决使用模版时分离编译导致的链接问题,特化则是满足功能性需求。

首先明确什么是“模板的分离编译“场景。分离编译是指将程序的源代码拆分为多个独立的编译单元(通常是 .cpp 或 .cu 文件),分别编译成二进制目标文件(.o 或 .obj),最后由链接器将它们组合成可执行文件。当这个场景遇到模板时,通常表现为以下布局:

  • kernel.h (头文件):仅包含函数模版的声明,例如 template <typename T> void compute();
  • kernel.cu (实现文件):包含函数模版的具体实现代码
  • main.cu (调用文件):通过 #include “kernel.h” 包含声明,并尝试调用该模板

对于编译器而言,因为模板不是真正的函数,如果没有人在当前文件里调用它,编译器就不会主动实例化,所以处理 kernel.cu 产生的 kernel.o 中不会包含函数的符号;处理 main.cu 时编译器可以看到对于函数的调用情况,但是它只能看到 kernel.h 中的声明,因此只能寄希望于链接器去别处找。

对于链接器而言,因为 kernel.o 中并没有包含 main.o 所需的符号,因此就会出现 undefined reference 的错误。

分离组织的形式是符合软件工程的设计逻辑的,但恰好遇到了模版“非调用不实例化"的性质,显式实例化就是为了解决这一问题。做法就是在 kernel.cu 模版实现后,类似于实际调用时指定所有参数,如以下代码所示

1
2
3
4
5
6
7
8
template<typename Element, int kHeadDim>
void run_qk_mma(QKMmaParams& params, cudaStream_t stream) {
  // 实现过程
}

// 显式实例化
template void run_qk_mma<cutlass::half_t, 32>(QKMmaParams&, cudaStream_t);
template void run_qk_mma<cutlass::half_t, 64>(QKMmaParams&, cudaStream_t);

从上面提到的编译链接流程不难看出,我们在 main.cu 中实际可以调用的函数模版类型的多少取决于我们显式实例化的类型的多少,这样的话模版在一定程度上就变为了伪模板。这实际上是C++ 工程实践中一个 tradeoff,分离编译 + 显式实例化带来了工程上的可维护性以及提升了编译速度,但是却丧失了一定灵活性。

调试技巧

  1. 编译时确定一个编译时常量的值

通过实例化一个没有实现的类模版,会报以下错误,从而我们就知道了一个编译时常量的值

error: incomplete type “cute::ValuePrinter<4>” is not allowed

1
2
3
4
template <int N> 
struct ValuePrinter;

ValuePrinter<RegNumA>{};

这里我们可以把 RegNumA 其他数据

  1. 编译时确定一个类型
1
2
template <typename T>
struct TypePrinter;

多模版并存时的符号决议全链路分析

这个问题来源于我们仅需要唯一的模版参数时,希望减少 FlashAttention 需要编译的 .cu 文件时,修改代码所遇到的编译和链接问题。具体而言,我们指定的输入数据对应到 kHeadDim=128 的分支中,但是在编译完成后,运行过程中报错缺少 kHeadDim=32 函数符号。

在 FlashAttention 中存在一些根据运行时值来判断执行路径的行为,我们很容易理解为执行哪个分支就只需要这个分析对应的函数被编译。

如果我们秉持着这个思路,我们会发现编译可以正常完成,但是在实际运行过程中可能会出现 undefined symbol 的错误。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
// flash_api.cpp
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
                }
            });
        });
    });
}

想要分析这个问题的核心在于理解 C++ template 在分离编译模式下,编译器和链接器各自的行为,因此我们分别站在编译器和链接器两个视角来分析。

编译器视角

在 C++ template 下,编译器采用的是一种冗余式生成策略。例如在我们上面的示例代码中,运行时会进行分支选择,但是编译器并不知道运行时的数值,所以编译器只能把所有可能涉及到的模版全部进行实例化。

同时,在分离编译模式下,由于编译器只看到了函数的声明,所以只生成了对于函数符号的一个未定义引用。

链接器视角

链接器在全局视角下,把所有编译器生成的未定义引用之间构建起联系。

因为编译器为所有分支的函数都生成了引用需求,如果实际上并没有给对对应的实现,就会在链接过程出现问题。

0%