注意力机制(Attention Mechanism)是现代深度学习中最核心的突破之一。从早期在 Seq2Seq 模型中引入的简单对齐机制,到 Transformer 中的多头自注意力,再到近年来为了效率和长序列而生的各类变体——Attention 的演进史就是大语言模型(LLM)发展的缩影。
注意力机制的核心思想源自认知科学中的"视觉注意力"——人类在处理视觉信息时,会有选择性地关注某些区域,而非平等处理全部信息。借鉴这一思路,Attention 让模型能够动态地加权输入序列的不同部分,从而聚焦于与当前任务最相关的信息。
从数学上,Attention 可以抽象为一个查询(Query)在一组键值对(Key-Value pairs)上的加权求和:
其中:
在 Attention 机制出现之前,Seq2Seq 模型使用 RNN/LSTM 的最后一个隐藏状态作为上下文向量,存在以下问题:
| 问题 | 描述 | 影响 |
|---|---|---|
| 信息瓶颈 | 固定长度的上下文向量无法容纳长序列信息 | 长序列翻译质量随长度急剧下降 |
| 梯度消失 | RNN 在反向传播时梯度随距离指数衰减 | 难以捕获长距离依赖关系 |
| 顺序约束 | 必须按时间步依次计算 | 无法并行化,训练速度慢 |
Attention 机制从根本上解决了这些问题:不再依赖固定维度的向量压缩全部信息,而是让解码器在每一步都能"回看"编码器的全部输出。
| 时间 | 论文/工作 | 核心贡献 |
|---|---|---|
| 2014 | Bahdanau et al. "Neural Machine Translation by Jointly Learning to Align and Translate" | 首次在 NMT 中加入注意力机制(Additive Attention) |
| 2015 | Luong et al. "Effective Approaches to Attention-based Neural Machine Translation" | 提出 Global/Local Attention 和多种评分函数 |
| 2016 | Xu et al. "Show, Attend and Tell" | 将注意力引入图像描述生成 |
| 2017 | Vaswani et al. "Attention Is All You Need" | Transformer + Scaled Dot-Product + Multi-Head Attention |
| 2019 | Transformer-XL / XLNet | 相对位置编码与分段循环 |
| 2019 | ALiBi (Press et al.) | 基于距离的线性偏置位置编码 |
| 2020 | Longformer / BigBird | 稀疏注意力机制(Sliding Window + Global) |
| 2020 | Reformer (Kitaev et al.) | LSH 哈希近似注意力 |
| 2022 | Multi-Query Attention (MQA) | 共享 KV 头,加速推理 |
| 2023 | Grouped-Query Attention (GQA) | 折中方案,MQA 与 MHA 的中间态 |
| 2023 | FlashAttention (Dao et al.) | IO-aware 注意力加速算法 |
| 2023-2024 | Linear Attention、Mamba | 线性复杂度替代方案,状态空间模型 |
Attention 的核心在于如何计算 Query 和 Key 之间的相关性分数。以下是主流评分函数:
使用一个前馈网络计算对齐分数:
直接计算点积并进行缩放:
假设 和 的元素来自均值为 0、方差为 1 的随机变量。那么:
标准差不缩放时,softmax 输入会因为方差大而进入饱和区(只有极少数位置有显著梯度)。缩放 后:
| 函数 | 公式 | 特点 |
|---|---|---|
| 一般形式 | 引入可学习矩阵实现非线性变换 | |
| 拼接加性 | Bahdanau 的变体,将 Q 和 K 拼接 | |
| Cosine | 仅关注方向相似度,忽略长度 |
Soft Attention(确定性注意力)
Hard Attention(随机注意力)
实践中通常将注意力权重矩阵 可视化(常见于机器翻译论文中),图中可以看到源语言和目标语言之间的词对齐关系。在高层次,注意力矩阵形成了一种隐式的"对齐"映射。
Self-Attention 中 Q、K、V 全部来自同一个输入序列:
其中 是输入序列, 是可学习的投影矩阵。
Self-Attention 通过三个线性投影将输入映射到不同的表示空间。每个位置的输出是所有位置的 V 的加权和,权重由 Q 和 K 的相似度决定。
Self-Attention 的信息流:每个位置都能直接"看到"所有其他位置,形成全连接图。与 RNN 相比,Self-Attention 的路径长度为 1(任意两个位置之间只需一步),完美解决了长距离依赖问题。
Multi-Head Attention 是 Transformer 的关键创新:
其中每个 head 独立计算:
关于 head 数 与维度 :通常 ,即当 增加时,每个 head 的维度相应减小,保持总计算量基本不变。例如 GPT-3 使用 96 个 head,每个 head 的维度为 128。
为什么需要多头?
有研究(Clark et al., 2019)通过对 BERT 的注意力分析发现:
实际配置参考:
| 模型 | Head 数 | 每 head 维度 | |
|---|---|---|---|
| BERT-Base | 768 | 12 | 64 |
| BERT-Large | 1024 | 16 | 64 |
| GPT-3 | 12288 | 96 | 128 |
| LLaMA-7B | 4096 | 32 | 128 |
| LLaMA-65B | 8192 | 64 | 128 |
在语言建模(自回归生成)中,模型不能看到未来的 token。因此引入 Mask:
其中 是一个上三角矩阵,元素为 (或一个非常大的负数):
这样在 softmax 后,未来位置的权重变为 0,模型只能关注当前及之前的位置。这种 Mask 在训练和推理时都使用。
Cross-Attention 在编码器-解码器结构中使用:Q 来自一个序列(如解码器输出),K 和 V 来自另一个序列(如编码器输出)。
典型应用:
在 Transformer 解码器中,每个解码器层包含两层 Attention:
Self-Attention 本身是置换不变的(permutation-invariant)——如果将输入序列打乱,输出的每个位置的表示也打乱。这意味着模型完全不知道词语的顺序。因此必须注入位置信息。
Sinusoidal Positional Encoding(Transformer 原始方案)
特点:
Learned Positional Encoding(BERT/ALBERT)
RoPE(Rotary Position Embedding)是目前最流行的位置编码方案,被 LLaMA、Qwen、GLM 等多数主流模型采用。
核心思想:通过旋转矩阵将位置信息注入 Query 和 Key:
其中 是旋转矩阵,对向量实施与位置 相关的旋转:
RoPE 的核心性质: 的值仅依赖于相对位置 ,而非绝对位置。这带来了很好的长度外推能力。
实际使用时的实现技巧:通过复数乘法高效实现,不需要显式构建大矩阵:
def apply_rope(x, position_ids, theta=10000.0):
"""x: (batch, seq_len, num_heads, head_dim)"""
seq_len = x.shape[1]
head_dim = x.shape[-1]
# 预计算频率
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
# 计算每个位置的旋转角度
positions = position_ids.float()
angles = positions[:, None] * freqs[None, :] # (batch, head_dim/2)
# 复数旋转实现
cos = torch.cos(angles).unsqueeze(-1).repeat(1, 1, 2)
sin = torch.sin(angles).unsqueeze(-1).repeat(1, 1, 2)
# 奇偶交错旋转
x_rotated = torch.cat([
-x[..., 1::2], x[..., ::2]
], dim=-1)
return x * cos + x_rotated * sin
ALiBi 是另一种位置编码方案,简单高效:
核心思想:去掉位置编码,在注意力分数上直接加上基于距离的线性偏置。 是每个 head 的斜率(按几何级数递减),距离越远的 token 偏置越大。
ALiBi 的优势:
ALiBi 被 BLOOM 等模型采用。
| 方案 | 外推能力 | 计算开销 | 主流应用 |
|---|---|---|---|
| Sinusoidal | 好 | 极低 | 原始 Transformer |
| Learned | 无 | 低 | BERT |
| RoPE | 优秀 | 中 | LLaMA, Qwen, GLM, Mistral |
| ALiBi | 优秀 | 极低 | BLOOM, MPT |
| XL/RoPE + NTK-aware | 极佳 | 中低 | 上下文扩展到 128K+ |
由于训练序列长度有限,推理时需要处理更长的上下文。常见的长度外推方法:
实践证明,NTK-aware + YaRN 组合能够将 4K 训练的模型外推到 128K 以上。
随着模型层数和上下文长度增长,标准 Attention 的 复杂度成为关键瓶颈。以下技术应对这一挑战。
MQA 是 LLaMA 原始版本采用的高效方案之一,在推理场景下显著加速。
标准 Multi-Head Attention:每 head 有独立的 K 和 V 投影
MQA:所有 head 共享相同的 K 和 V
为什么 MQA 节约推理内存?
性能影响:
GQA 是 MQA 与 MHA 的折中方案,由 Google 在 2023 年提出。
核心思想:将 个 head 分成 组,组内共享 KV:
GQA 的自由度:
实际配置对比:
| 模型 | 参数量 | Attention 类型 | KV Head 数 |
|---|---|---|---|
| LLaMA-7B | 7B | MHA | 32 |
| LLaMA-2-70B | 70B | GQA (8组) | 8 |
| LLaMA-3-8B | 8B | GQA (8组) | 8 |
| LLaMA-3-70B | 70B | GQA (8组) | 8 |
| Falcon-180B | 180B | MQA | 1 |
| Mistral-7B | 7B | GQA (8组) | 8 |
从 LLaMA-2 开始,GQA 已成为大型 LLM 的默认选择。
FlashAttention(Dao et al., 2022, 2023, FlashAttention 2 & 3)是革命性的 Attention 加速算法,不改变数学结果但大幅加速计算。
核心洞察:标准 Attention 的计算瓶颈不在 GPU 算力(FLOPs),而在显存带宽(memory bandwidth)。Attention 计算分为两个步骤:
在标准实现中, 矩阵需要写入 HBM(高带宽显存),再读回来用于第二步。这浪费了大量带宽。
FlashAttention 的做法:
FlashAttention 的数学技巧——安全 softmax:
传统 softmax 计算:
在线 softmax 分块计算:
# FlashAttention 的在线 softmax 核心思想(简化版)
def online_softmax_attn(Q_block, K_block, V_block, prev_out, prev_max, prev_sum):
"""合并在线 softmax 的计算"""
# 计算当前块
S_block = Q_block @ K_block.T
local_max = S_block.max(dim=-1, keepdim=True) # 按行的最大值
# 指数化
exp_S = torch.exp(S_block - local_max)
local_sum = exp_S.sum(dim=-1, keepdim=True)
# 合并:调整 prev 的状态
new_max = torch.max(prev_max, local_max)
reweight = torch.exp(prev_max - new_max)
# 计算当前块对输出的贡献
output_block = exp_S @ V_block * (local_sum ** -1)
# 合并 prev_out 和 output_block
# … (实际实现更复杂,含递归合并)
FlashAttention 版本演进:
| 版本 | 改进 | 加速比(vs PyTorch) |
|---|---|---|
| v1 | 基本 tiling + online softmax | 2-4x |
| v2 | 减少非矩阵乘操作,优化 warp 调度 | 3-6x |
| v3 | 利用 Hopper GPU 新特性 | 6-9x |
使用方式:FlashAttention 已被集成到主流框架中。
# PyTorch 2.0+ 原生支持(基于 FlashAttention)
import torch.nn.functional as F
# 自动使用 FlashAttention(满足条件时)
with torch.backends.cuda.sdp_kernel(enable_flash=True):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
FlashAttention 已成为现代 Attention 加速的标准工具,直接写 CUDA 的效率天花板。
在自回归推理中,为每个新生成 token 重新计算所有历史位置的 K 和 V 是浪费的。KV Cache 通过存储已计算的 K,V 值来避免重复计算。
KV Cache 工作原理:
KV Cache 大小估算:
以 LLaMA-3-70B 为例(FP16, 80层, 8 KV heads, 128 head_dim):
这就是为什么长上下文推理对显存需求极大。
KV Cache 优化技术:
PagedAttention 是 vLLM 推理框架的核心创新,灵感来自操作系统的虚拟内存分页。
核心思想:KV Cache 的内存管理与操作系统虚拟内存类似:
带来的好处:
性能提升:vLLM 的 PagedAttention 相比 HuggingFace Transformers 的原始实现,吞吐量提升 2-4 倍,成为最流行的推理引擎之一。
对于超长序列(如 100K+ tokens),即使是 FlashAttention 优化的 也难以为继。以下方法尝试降低复杂度。
稀疏 Attention 的核心思想是:一个 token 不需要关注所有其他 token,可以选择性地关注。
只关注当前位置前后固定窗口内的 token:
滑动窗口 Attention 的信息传播范围:经过 L 层 Transformer,每个位置可以"看到"最远 距离内的信息。
类似 CNN 的膨胀卷积,每层使用不同的步长:
效果:以对数级代价覆盖全局感受野。
混合两种模式:
[CLS])关注全序列BigBird 进一步扩展为三种模式:
BigBird 复杂度:从 降至 ,在长文本分类、QA 任务上效果接近 Full Attention。
线性 Attention 尝试完全移除 softmax,从而允许改变计算顺序:
如果去掉 softmax:
关联三个矩阵乘法有结合律:
的复杂度为 而不是 !这就是线性 Attention 的核心。
关键问题:softmax 不能直接拆开。解决方案是用核函数近似 softmax:
其中 是特征映射函数(如 )。
Linear Transformer(Katharopoulos et al., 2020) 使用:
Performer(Choromanski et al., 2020) 使用:
线性 Attention 的局限:
| 问题 | 说明 |
|---|---|
| 近似误差 | 核函数近似劣于精确 softmax |
| 训练不稳定 | 在大模型上训练容易发散 |
| 知识保留 | 在复杂的上下文理解任务上不如标准 Attention |
| 社区接受度 | 被 Mamba 等 SSM 替代方案超越 |
结论:纯粹的线性 Attention 在大模型时代并未成为主流,被状态空间模型(Mamba, Mamba-2)和混合架构取代。
严格来说 SSM 不是 Attention 的变体,但它是替代 Attention 解决长序列问题的重要方向。
Mamba(Gu & Dao, 2023) 的核心思路:
Mamba 与 Attention 的对比:
| 特性 | Self-Attention | Mamba |
|---|---|---|
| 训练复杂度 | ||
| 推理复杂度 | (带 KV Cache) | |
| 长序列处理 | 需 KV Cache 管理 | 状态持续 |
| 记忆强度 | 可直接访问所有位置 | 间接通过隐藏状态 |
| SoTA 模型 | GPT-4, Claude 3, LLaMA 3 | Mamba-2, Jamba, Samba |
混合架构:最新趋势是结合 Attention 和 SSM,即在一个模型中交替使用两种层:
Hugo 的实践笔记:在实际部署中,对于 8K 以内上下文,标准 FlashAttention 完全够用。128K+ 长上下文推理才需要关注稀疏或线性方案。混合架构(Attention + SSM)是趋势,但成熟度还不足。
传统推理中,batch 填充到最大序列长度会导致显著浪费。Continuous Batching 动态管理:
# 传统的填充方式(浪费大)
max_len = 2048
batch = pad_sequences(sequences, max_len) # 短序列被填充
# Continuous Batching(按实际长度)
batch = [seq1, seq2, ...] # 每个序列独立处理,互不阻塞
效果:吞吐量提升 2-3 倍。
两者都与 Attention 机制紧密相关:KV Cache 共享和注意力计算优化是这些技术的基础。
对极长序列推理的另一种思路:只保留"最近窗口 + 少数关键 token(Attention Sinks)"的 KV Cache:
# Streaming LLM 简化的实现思路
class StreamingLLM:
def __init__(self, window_size, sink_size=4):
self.window = [] # 滑动窗口
self.sink_tokens = [] # 最开始的 4 个 token(Attention Sink)
def generate(self, tokens):
for token in tokens:
# 只在这两个集合上计算 Attention
kv_subset = self.sink_tokens + self.window
next_token = self.model.forward(kv_subset)
self.window.append(token)
if len(self.window) > window_size:
self.window.pop(0)
Attention Sink 现象:语言模型的前几个 token(特别是 BOS token)会吸收大量未使用的注意力,即使它们在语义上不相关。保留这些 token 可以稳定长上下文推理。
以下内容来自实际项目经验,供团队参考
QK 维度匹配:在做 Attention 改造(如 MHA→GQA)时,一定要注意 Q 和 K 最后一个维度的对齐。GQA 中 Q heads > K heads,需要将 Q 分组后分别与对应 K 做点积。
MQA 训练陷阱:直接用 MQA 从头训练大模型可能质量下降明显。推荐的实践路径:先用 MHA 训练(或者 MHA 预训练 + MQA fine-tune),再转换为 MQA 推理。
FlashAttention 要求:
RoPE 实现细节:
分支 Attention 测试:在将 MHA 切换为 GQA 时,务必逐层验证 KV 对齐,用数值验证(check_numerics 或对比 output)确认分组逻辑正确。一次群里就有同事因为分组维度没对齐,调了 3 天才找到 bug。
推理显存估算公式:
推理最小显存 = 模型权重 (FP16) + KV Cache + 激活值
模型权重 ≈ 参数量 × 2 bytes (FP16)
KV Cache ≈ 2 × L × H_kv × L_seq × d_head × 2 bytes
示例:LLaMA-3-8B (FP16, batch=1, seq=4096)
- 模型权重: 8B × 2 = 16 GB
- KV Cache: 2 × 32 × 8 × 4096 × 128 × 2 = 0.5 GB
- 合计: ~16.5 GB → 一张 24G 的 4090 刚好
# 简化版 Scaled Dot-Product Attention
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(q, k, v, mask=None, causal=True):
"""Scaled Dot-Product Attention with optional causal mask
Args:
q: (batch, heads, seq_q, d_k)
k: (batch, heads, seq_k, d_k)
v: (batch, heads, seq_k, d_v)
causal: whether to apply causal masking
"""
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
if causal:
seq_len = q.size(-2)
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
scores = scores + mask.to(scores.device)
if mask is not None:
scores = scores + mask
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, v), attn_weights
# 使用 PyTorch 2.0 SDPA(自动启用 FlashAttention)
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
2024-2025 年,混合架构成为主流:
初步结果表明:混合架构在保持 推理的同时,能接近纯 Transformer 的质量。
随着 AI Agent 的发展,Attention 不再仅关注 token 级别的相关性:
此页面为 AI 知识体系 的一部分,内容持续更新中。