Attention

At an abstract level, the QKV computation in attention produces a new representation of the original token sequence. The process has two steps (written out as a formula after the list):

  1. Q and K compute the pairwise relevance between tokens, expressed as attention scores
  2. The rows of V are averaged, weighted by those relevance scores
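In the standard scaled dot-product form (with $d_k$ the key dimension), the two steps read as:

$$\mathrm{Attention}(Q,K,V) = A\,V, \qquad A = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)$$

where $A$ holds the attention scores from step 1 and the product $AV$ is the weighted average from step 2.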

More concretely, in matrix terms: each row of Q carries the original information of one token, and its product with $K^T$ expresses how strongly that token attends to every other token (if the rows of K are reordered, the columns of $K^T$ are reordered as well, so in $QK^T$ the "other tokens" that a given token attends to appear in the permuted order). In the subsequent product with V, each row of the left-hand matrix still corresponds to an original token, so each row of the result is still that token's (new) representation. If the rows of K are permuted, the rows of V must be permuted in the same way; the only requirement is that the columns of $QK^T$ line up with the rows of V, so that each weight multiplies the value of the correct token.
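A quick sanity check of the reordering argument, writing the shared row permutation of K and V as a permutation matrix $P$:

$$\mathrm{softmax}\!\left(\frac{Q(PK)^{\top}}{\sqrt{d_k}}\right)(PV)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)P^{\top}P\,V=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

since the row-wise softmax commutes with a column permutation and $P^{\top}P=I$.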

Bottleneck of Attention in Different Tasks

This breakdown is not exhaustive; these are simply the more common scenarios in which the attention mechanism is applied.

| Task / Stage | Bottleneck Type |
| --- | --- |
| Non-Autoregressive Models | Compute bound |
| LLM Training | Compute bound |
| LLM Inference: 1. Prefilling | Compute bound |
| LLM Inference: 2. Decoding | Memory bound |

The different bottleneck types determine the optimization direction at each stage.
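A back-of-the-envelope arithmetic-intensity estimate makes the split plausible (a sketch assuming fp16 weights and a single weight matrix $W\in\mathbb{R}^{d\times d}$; real workloads add attention and KV-cache traffic on top):

$$\text{prefill (}N\text{ tokens at once):}\quad \frac{2Nd^{2}\ \text{FLOPs}}{2d^{2}\ \text{bytes}}\approx N\ \text{FLOPs/byte}\;\Rightarrow\;\text{compute bound}$$

$$\text{decode (1 token per step):}\quad \frac{2d^{2}\ \text{FLOPs}}{2d^{2}\ \text{bytes}}\approx 1\ \text{FLOP/byte}\;\Rightarrow\;\text{memory bound}$$

In decoding, the growing KV cache must also be re-read at every step, which pushes the intensity down even further.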

Classification of Attention

https://cdn.jsdelivr.net/gh/gaohongy/cloudImages@master/20251212143928379.png 1

| Major Category | Subcategory | Core Goal | Key Technique | Core Advantage | Main Limitation | Typical Methods |
| --- | --- | --- | --- | --- | --- | --- |
| 1. Hardware-efficient Attention | Prefilling stage | Raise compute throughput (for compute-bound scenarios) | Tiling, kernel fusion, low-bit quantization (INT4/8, FP8), asynchronous Tensor Core scheduling | Does not change the attention formula, no accuracy loss; tailored to GPU characteristics | Tied to hardware architecture (mostly NVIDIA GPUs); high development barrier | FlashAttention 1/2/3, SageAttention series |
| | Decoding stage | Accelerate KV cache I/O (for memory-bound scenarios) | KV cache partitioning / reallocation (paging, compact storage), low-bit quantization, dynamic scheduling | Reduces memory fragmentation, improves SM utilization; suits long-sequence decoding | Limited compatibility (requires custom kernels) | PagedAttention, FlashDecoding, KVQuant, FlashInfer |
| 2. Compact Attention | Head sharing / grouping | Compress KV cache storage (fewer independent KV heads) | Full sharing (MQA), grouped sharing (GQA), tensor-product generation (TPA) | Simple to implement, computation logic unchanged; suits LLM inference | Excessive sharing may reduce model expressiveness | MQA, GQA, TPA |
| | Low-rank decomposition / feature compression | Compress the KV feature dimension (down-project for storage, up-project for compute) | Low-rank projection (MLA), matrix factorization (MFA), spatio-temporal sparse compression | High memory compression ratio, preserves the core benefits of multi-head attention | The up-projection may introduce a slight accuracy loss | MLA, MFA |
| 3. Sparse Attention | Pattern-based | Skip predefined non-critical computation (training-free) | Fixed sparse masks (sliding window, attention sink, spatio-temporal blocking) | Training-free and plug-and-play; fits specific tasks (video / LLM) | Fixed masks lack flexibility; high sparsity easily loses information | StreamingLLM, DiTFastAttn, STA, NeighborAttn |
| | Dynamic sparse | Adaptively skip non-critical computation (some require training) | Runtime mask generation (Top-K selection, PCA dimensionality reduction, clustering, gated prediction) | Highly flexible, handles complex scenarios; accuracy loss is controllable | Some require training / fine-tuning; dynamic selection adds overhead | SpargeAttn, H2O, SeerAttention, VSA, NSA |
| 4. Linear Attention | Naive | Reduce complexity to O(N) (no gating) | Kernel functions replace softmax, block-wise recurrent accumulation | Extremely high compute/memory efficiency; fits non-autoregressive tasks | Poor long-sequence information retention; limited expressiveness | Linear Transformer, Lightning Attention |
| | With forget gate | Improve long-sequence information retention (control dependence on history) | Fixed / input-dependent forget gates (RetNet, GLA), channel-wise decay (RWKV series) | Balances efficiency and long-sequence modeling; supports autoregressive inference | Gate design is complex; some require tuning | RetNet, GLA, RWKV 4/5/6/7 |
| | With forget + select gates (dual gates) | Precisely control what information to keep (history + current) | Delta rule (DeltaNet), time-varying state space (Mamba), gated fusion | Expressiveness close to standard attention; no efficiency loss | Complex to implement; demanding hardware adaptation | DeltaNet, Mamba, Mamba2, gDeltaNet |
| | Test-time training (TTT) | Dynamically optimize the hidden state (adapt to the scenario) | Hidden state as "fast weights", updated by gradient descent (linear / MLP sub-network) | High accuracy for long-sequence modeling; multi-task friendly | Chunk-level optimization is complex; some nonlinear structures are hard to parallelize | TTT-Linear/MLP, Titans, LaCT |

GAT

In graph attention networks, the purpose of the multi-head mechanism is to let each head focus on learning a different type of relational feature. Accordingly, only the node feature data is partitioned across heads; nodes and edges are never partitioned, i.e. the graph topology is not split.
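The standard per-head GAT update makes this concrete: every head $k$ attends over the same neighborhood $\mathcal{N}_i$ (the topology is shared), and only the learned projection $W^{(k)}$ and attention vector $a^{(k)}$ differ per head:

$$\alpha_{ij}^{(k)}=\operatorname*{softmax}_{j\in\mathcal{N}_i}\Big(\mathrm{LeakyReLU}\big(a^{(k)\top}\big[W^{(k)}h_i\,\Vert\,W^{(k)}h_j\big]\big)\Big),\qquad h_i'=\big\Vert_{k=1}^{K}\,\sigma\Big(\sum_{j\in\mathcal{N}_i}\alpha_{ij}^{(k)}W^{(k)}h_j\Big)$$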

Flash Attention

All of the following forward-kernel source files are template specializations of the same kernel:

// base variants
flash_fwd_hdim64_bf16_sm90.cu
flash_fwd_hdim64_bf16_softcap_sm90.cu	// soft-cap for better numerical stability
flash_fwd_hdim64_bf16_split_sm90.cu		// sequence splitting (split-KV)

// extra features
flash_fwd_hdim64_bf16_paged_sm90.cu
flash_fwd_hdim64_bf16_packgqa_sm90.cu

// softcap combined with other features
flash_fwd_hdim64_bf16_paged_softcap_sm90.cu
flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu

// split combined with other features
flash_fwd_hdim64_bf16_split_softcap_sm90.cu
flash_fwd_hdim64_bf16_paged_split_sm90.cu
flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu
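To make the "template specialization" point concrete, here is a hypothetical, self-contained sketch (illustrative only, not the actual FlashAttention-3 sources; the template name and parameters are made up): each .cu file above amounts to one explicit instantiation of a shared kernel template with a fixed head dimension, dtype, and set of feature flags, so the variants can be compiled in parallel.

#include <cstdio>

// Shared kernel template; the real code would launch a CUDA kernel
// specialized for these compile-time parameters.
template <int kHeadDim, bool Softcap, bool Split, bool Paged>
void run_flash_fwd() {
    std::printf("hdim=%d softcap=%d split=%d paged=%d\n",
                kHeadDim, int(Softcap), int(Split), int(Paged));
}

// A file such as flash_fwd_hdim64_bf16_split_softcap_sm90.cu would then
// reduce to a single explicit instantiation like:
template void run_flash_fwd<64, /*Softcap=*/true, /*Split=*/true, /*Paged=*/false>();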
The call chain from the forward mainloop down to the wgmma PTX instruction passes through the following files:

  1. mainloop_fwd_sm90_tma_gmma_ws.hpp
// the P·V GEMM in the forward mainloop (zero_init=false accumulates into tOrO)
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, cute::conditional_return<MmaPV_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
  2. utils.h
for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) {
  cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}

The number of times cute::gemm() is invoked here is determined by std::min(kNumKIters, kMaxKIters).

static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA));
static constexpr int kMaxKIters = 16;
  3. gemm.hpp (TiledMMA)
// the TiledMMA gemm ultimately dispatches into MMA_Atom::call()
mma.call(D, A, B, C);

3.1 mainloop_fwd_sm90_tma_gmma_ws.hpp

using TiledMmaQK = decltype(cute::make_tiled_mma(
    std::conditional_t<
        !MmaQK_is_RS,
        /*SMEM, SMEM*/ decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
        /*register, SMEM*/ decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
    >{},
    AtomLayoutQK{}));

3.2 mma_sm90.hpp

// abridged: only the branch selected in this example is shown
ss_op_selector()
{
  return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
}

For brevity we do not reproduce the whole body of ss_op_selector(); the selection follows this process:

  1. the data type of the accumulator (F16, F32, S32)
  2. the data types of inputs A and B
  3. because the API fixes the values of M and K, only N has to be matched against a set of preset values; the check walks the multiples of 8 in descending order: 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8

This step therefore does use the value of N, but choosing the MMA API is not simply a matter of looking N up directly.
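A hedged sketch of that selection order (illustrative only, not the body of cute::GMMA::ss_op_selector; it assumes "matching" means the requested tile-N is divisible by the preset size):

// Walk the preset N values (multiples of 8, descending) and return the
// first one that divides the requested tile-N.
constexpr int pick_gmma_n(int tile_n) {
    for (int n = 256; n >= 8; n -= 8) {
        if (tile_n % n == 0) { return n; }
    }
    return 8;
}
static_assert(pick_gmma_n(128) == 128, "kBlockN = 128 -> the m64n128k16 atom");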

The actual value of N that gets used is kBlockN, which is produced by the following code:

static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);

Likewise, we can look at how the actual values of M and K are generated:

static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);

// only a preset set of head dimensions (which fix the value of K) is supported
#ifndef FLASHATTENTION_DISABLE_HDIM64
#ifndef FLASHATTENTION_DISABLE_HDIM96
#ifndef FLASHATTENTION_DISABLE_HDIM128
#ifndef FLASHATTENTION_DISABLE_HDIM192
#ifndef FLASHATTENTION_DISABLE_HDIM256

From the above analysis, we can see that kHeadDim is determined by which of these macro definitions are enabled, which in turn is controlled by environment variables at build time.

kBlockM and kBlockN, however, are determined by the kBlockMN_RS_IntraWGOverlap variable, which is computed by the tile_size_fwd_sm90 function.

if (headdim <= 64) {
  if (headdim_v == 512) {
      return {64, 64, false, false};
  } else if (headdim_v == 256) {
      return {128, 96, true, false};
  } else {
      // Switch to tile size 192 x 192 for now
      bool const use_blockN_128 = is_causal || is_local;
      return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true};
  }
}

The values of kBlockM and kBlockN are thus determined by headdim and headdim_v, the head dimensions of the actual K and V inputs.

Based on the information available so far, we cannot end up using the m64n16k16-shaped API.

  4. mma_atom.hpp (MMA_Atom)
  // Cast, check, and call fma
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr
  void
  call(Tensor<TD, DLayout>      & D,
       Tensor<TA, ALayout> const& A,
       Tensor<TB, BLayout> const& B,
       Tensor<TC, CLayout> const& C) const
  {
    static_assert(DLayout::rank == 1, "Expected rank-1 D tensor");
    static_assert(ALayout::rank == 1, "Expected rank-1 A tensor");
    static_assert(BLayout::rank == 1, "Expected rank-1 B tensor");
    static_assert(CLayout::rank == 1, "Expected rank-1 C tensor");

    // the implementation calls MMAOperation::fma, which is what links MMA_Atom to the MMAOperation
    return mma_unpack(static_cast<Traits const&>(*this), D, A, B, C);
  }
  5. mma_traits_sm90_gmma.hpp (MMA_Traits)
template <class MMA_Op, class... MMA_Args,
          class TD, class DLayout,
          class TA, class ALayout,
          class TB, class BLayout,
          class TC, class CLayout>
CUTE_HOST_DEVICE constexpr
void
mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
           Tensor<TD, DLayout>      & D,
           Tensor<TA, ALayout> const& A,
           Tensor<TB, BLayout> const& B,
           Tensor<TC, CLayout> const& C)
{
  static_assert(is_rmem<TD>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TA>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TB>::value, "Expected registers in MMA_Atom::call");
  static_assert(is_rmem<TC>::value, "Expected registers in MMA_Atom::call");

  // Register value types from the MMA_Operation register arrays
  using RegTypeA = typename remove_extent<typename MMA_Op::ARegisters>::type;
  using RegTypeB = typename remove_extent<typename MMA_Op::BRegisters>::type;
  using RegTypeC = typename remove_extent<typename MMA_Op::CRegisters>::type;

  // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased
  static_assert(is_same<typename TD::value_type, typename TC::value_type>::value, "GMMA C and D value_type must match.");
  static_assert(is_same<DLayout, CLayout>::value, "GMMA C and D layouts must match.");
  // assert((void*)&C == (void*)&D);

  Tensor rA = recast<RegTypeA>(A);
  Tensor rB = recast<RegTypeB>(B);
  Tensor rC = recast<RegTypeC>(D);  // NOTE: D and C are same, so use mutable D

  constexpr int RegNumA = extent<typename MMA_Op::ARegisters>::value;
  constexpr int RegNumB = extent<typename MMA_Op::BRegisters>::value;
  constexpr int RegNumC = extent<typename MMA_Op::CRegisters>::value;

  CUTE_STATIC_ASSERT_V(size(rA) == Int<RegNumA>{});
  CUTE_STATIC_ASSERT_V(size(rB) == Int<RegNumB>{});
  CUTE_STATIC_ASSERT_V(size(rC) == Int<RegNumC>{});

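  // detail::explode expands the register arrays rA/rB/rC into individual
  // arguments and forwards them, together with the accumulate flag, to
  // MMA_Op::fma (the wgmma wrapper shown in the next file).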
  detail::explode(MMA_Op::fma,
                  rA, make_int_sequence<RegNumA>{},
                  rB, make_int_sequence<RegNumB>{},
                  rC, make_int_sequence<RegNumC>{},
                  &(traits.accumulate_), seq<0>{});
}
  6. mma_sm90_gmma.hpp (MMAOperation)
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
>
struct MMA_64x128x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[64];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float         & d00, float         & d01, float         & d02, float         & d03,
      float         & d04, float         & d05, float         & d06, float         & d07,
      float         & d08, float         & d09, float         & d10, float         & d11,
      float         & d12, float         & d13, float         & d14, float         & d15,
      float         & d16, float         & d17, float         & d18, float         & d19,
      float         & d20, float         & d21, float         & d22, float         & d23,
      float         & d24, float         & d25, float         & d26, float         & d27,
      float         & d28, float         & d29, float         & d30, float         & d31,
      float         & d32, float         & d33, float         & d34, float         & d35,
      float         & d36, float         & d37, float         & d38, float         & d39,
      float         & d40, float         & d41, float         & d42, float         & d43,
      float         & d44, float         & d45, float         & d46, float         & d47,
      float         & d48, float         & d49, float         & d50, float         & d51,
      float         & d52, float         & d53, float         & d54, float         & d55,
      float         & d56, float         & d57, float         & d58, float         & d59,
      float         & d60, float         & d61, float         & d62, float         & d63,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %66, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 "
      "{%0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,   "
      " %8,   %9,   %10,  %11,  %12,  %13,  %14,  %15,  "
      " %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,  "
      " %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,  "
      " %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,  "
      " %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,  "
      " %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,  "
      " %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63},"
      " %64,"
      " %65,"
      " p,    %67,  %68,  %69,  %70;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

Example Flow

https://cdn.jsdelivr.net/gh/gaohongy/cloudImages@master/20251216000418161.png

https://cdn.jsdelivr.net/gh/gaohongy/cloudImages@master/20251215235854686.png 2

Sparse Attention

Pattern-based

Dynamic sparse

KV Cache

https://cdn.jsdelivr.net/gh/gaohongy/cloudImages@master/20251212094950487.png

The two core requirements of any cache: the cached data must stay unchanged, and it must be useful.

  1. Why the KV cache stays unchanged

https://cdn.jsdelivr.net/gh/gaohongy/cloudImages@master/20251212101429423.png

Causal mask => historical K/V stay unchanged

I have not yet worked out, at the level of the actual data, how the values change with and without the mask; but since understanding the KV cache does not help resolve my current main questions, I am setting this aside for now.

  2. Why the KV cache is useful

Autoregressive inference => historical K/V are useful
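A minimal single-head decode-step sketch in plain C++ (illustrative only, not any real inference framework; KVCache and decode_step are made-up names) shows where the two properties enter: each token's K/V is appended to the cache exactly once and is then only re-read on later steps.

#include <algorithm>
#include <cmath>
#include <vector>

struct KVCache {
    std::vector<std::vector<float>> K, V;   // one row per past token
};

// One autoregressive decode step: append the new token's K/V (written once,
// never recomputed thanks to the causal mask), then attend over the cache.
std::vector<float> decode_step(const std::vector<float>& q,
                               const std::vector<float>& k_new,
                               const std::vector<float>& v_new,
                               KVCache& cache) {
    cache.K.push_back(k_new);
    cache.V.push_back(v_new);

    const float scale = 1.0f / std::sqrt(static_cast<float>(q.size()));
    std::vector<float> score(cache.K.size());
    float max_s = -1e30f, denom = 0.0f;
    for (size_t i = 0; i < cache.K.size(); ++i) {        // every past K row is re-read
        float s = 0.0f;
        for (size_t j = 0; j < q.size(); ++j) s += q[j] * cache.K[i][j];
        score[i] = s * scale;
        max_s = std::max(max_s, score[i]);
    }
    for (float& s : score) { s = std::exp(s - max_s); denom += s; }

    std::vector<float> out(v_new.size(), 0.0f);
    for (size_t i = 0; i < cache.V.size(); ++i)          // every past V row is re-read
        for (size_t j = 0; j < out.size(); ++j)
            out[j] += (score[i] / denom) * cache.V[i][j];
    return out;
}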

KV Cache Feasibility Analysis

The cost of the KV cache has to be understood from two angles: I/O cost and storage cost.

Both pattern-based sparse attention and dynamic sparse attention involve a mask, so both reduce the I/O cost.

With pattern-based sparse attention the mask is fixed, so only the K/V entries selected by the mask need to be stored, and the storage cost drops. With dynamic sparse attention, if the mask is derived from the correlations among Q, K, and V, the complete K/V data must still be kept, so the storage cost does not drop.
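As a rough sizing example (a sketch with assumed parameters: fp16 K/V, head dimension $d$, sequence length $N$, sliding-window width $w$, counted per layer and per head):

$$\text{full cache: } 2\cdot Nd\cdot 2\,\text{B},\qquad \text{pattern-based (window): } 2\cdot wd\cdot 2\,\text{B},\qquad \text{dynamic sparse: still } 2\cdot Nd\cdot 2\,\text{B stored, only the selected rows read}$$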
