011-Multi-Head-Attention

发布时间:2026/5/15 14:26:12

011-Multi-Head-Attention Multi-Head Attention让 AI 拥有多个大脑同时思考摘要Multi-Head Attention多头注意力通过将注意力分成多个并行子空间让模型能同时捕获语法、语义、长距离等不同类型的关系是 Transformer 表达能力的关键。引言还记得上一篇文章讲的 Self-Attention 吗它让每个词能关注句子中的其他词。但如果一个模型只有一个注意力头就像让一个专家同时做所有事情——他能做到的事其实很有限。想象你在分析一篇文章你需要关注语法关系形容词修饰名词你需要关注语义关系苹果和水果的关联你需要关注长距离关系第 1 段的他指代第 3 段的人名你需要关注局部关系相邻词的搭配如果只用一个注意力头模型只能学习一种注意力模式。但现实中的语言关系如此多样一个头怎么可能兼顾所有这就是Multi-Head Attention多头注意力要解决的问题——让模型拥有多个大脑每个头从不同角度分析信息最后综合判断。核心概念什么是 Multi-Head AttentionMulti-Head Attention 的核心思想很简单把 Self-Attention 复制多份每份在独立的子空间中计算注意力最后把结果拼接起来。通俗类比单头注意力就像一个专家只用一种方法看问题。多头注意力就像一个委员会语法专家关注词性和句法结构语义专家关注词义相似度位置专家关注词的相对位置长距离专家关注远距离的词关联每个专家独立发表意见最后综合所有人的判断结果通常比任何单一专家都准确。为什么多头比单个大头好有人可能会问为什么不直接用一个大头但维度更大比如 8 个 64 维的头 vs 1 个 512 维的大头计算量差不多效果却不同对比维度单头大头多头小头注意力模式只有一种多种独立模式容错能力学偏了就全错即使某几个头效果不好其他头仍可补救表达能力受限多个子空间的组合表达能力更强分工协作无法分工训练后会自发分化出不同的注意力模式实验发现多头在训练过程中会自发分化出不同的关注模式语法头关注相邻词捕捉局部语法关系句法头关注主谓关系、介词短语附着长距离头忽略距离关注语义相似的词位置偏移头固定关注前一个或后几个词特殊标记头专门关注[CLS]等全局标记为什么每个头维度是 d_model/n_headTransformer 的设计非常巧妙每个头的维度不是保持完整的 d_model而是d_k d_model / n_head。这样做有三个原因保持总计算量不变8 个 64 维的头总计算量 ≈ 1 个 512 维的大头防止过参数化低维空间迫使每个头专注特定模式而不是冗余学习保持输出维度8 个 64 维拼接 512 维自然匹配后续层输入总参数量 3 × d_model × d_k × n_head 3 × 512 × 64 × 8 3 × 512²和单头完全一样但表达能力却强得多。原理深入计算公式MultiHead(Q, K, V) Concat(head_1, head_2, ..., head_h) W^O 其中 head_i Attention(Q W_i^Q, K W_i^K, V W_i^V)4 个关键步骤多组投影每组头有独立的 W_iQ、W_iK、W_i^V独立注意力每个头独立计算 Scaled Dot-Product Attention拼接Concat 所有头的输出线性变换通过 W^O 映射回原始维度残差连接Skip Connection在 Transformer 中每个子层后都有残差连接输出 子层(输入) 输入为什么需要残差连接解决梯度消失提供短路路径让梯度直接流回浅层缓解退化问题让网络更容易学习恒等映射加深网络时性能不退化平滑优化使损失函数更平滑更容易用梯度下降找到最优解层归一化Layer Normalization层归一化对单个样本的所有特征计算均值和方差然后标准化LN(x) γ × (x - μ)/σ β为什么用 LayerNorm 而不是 BatchNorm不依赖批次统计适合可变长度序列小批量时依然稳定与残差配合重新居中缩放为下一层提供稳定输入Pre-Norm vs Post-NormPost-Norm原始 Transformer子层 → 残差 → LNPre-NormGPT/BERT 常用LN → 子层 → 残差Pre-Norm 的梯度流更直接训练更稳定现代实现多采用 Pre-Norm。代码示例让我们用 PyTorch 实现一个完整的 Multi-Head Attention 模块包含残差连接和 LayerNorm。示例 1Multi-Head Attention 完整实现importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimportmathclassMultiHeadAttention(nn.Module):Multi-Head Attention 完整实现def__init__(self,d_model512,num_heads8,dropout0.1):super().__init__()assertd_model%num_heads0,d_model 必须能被 num_heads 整除self.d_modeld_model self.num_headsnum_heads self.d_kd_model//num_heads# 每个头的维度 512/8 64# Q/K/V 投影矩阵self.w_qnn.Linear(d_model,d_model)self.w_knn.Linear(d_model,d_model)self.w_vnn.Linear(d_model,d_model)# 输出投影self.w_onn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)self.layer_normnn.LayerNorm(d_model)defscaled_dot_product_attention(self,Q,K,V,maskNone,dropout_fnNone):Scaled Dot-Product Attention# Q K^T 计算注意力分数scorestorch.matmul(Q,K.transpose(-2,-1))# 缩放scoresscores/math.sqrt(self.d_k)# 应用 mask生成时用ifmaskisnotNone:scoresscores.masked_fill(mask0,-1e9)# Softmax Dropoutattn_weightsF.softmax(scores,dim-1)ifdropout_fnisnotNone:attn_weightsdropout_fn(attn_weights)# 加权 Valuecontexttorch.matmul(attn_weights,V)returncontext,attn_weightsdefforward(self,x,maskNone):batch_size,seq_lenx.size(0),x.size(1)residualx# 保存残差# Pre-Norm先归一化xself.layer_norm(x)# 投影Qself.w_q(x)Kself.w_k(x)Vself.w_v(x)# 分割成多头: (batch, seq, d_model) → (batch, heads, seq, d_k)QQ.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)KK.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)VV.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)# 多头注意力context,attn_weightsself.scaled_dot_product_attention(Q,K,V,mask,self.dropoutifself.trainingelseNone)# 拼接多头: (batch, heads, seq, d_k) → (batch, seq, d_model)contextcontext.transpose(1,2).contiguous().view(batch_size,seq_len,self.d_model)# 输出投影 残差连接outputself.w_o(context)residualreturnoutput,attn_weights示例 2运行测试与形状验证torch.manual_seed(42)# 虚拟输入batch_size2, seq_len6, d_model512batch_size,seq_len,d_model2,6,512num_heads8xtorch.randn(batch_size,seq_len,d_model)# 创建模块mhaMultiHeadAttention(d_modeld_model,num_headsnum_heads,dropout0.1)output,attn_weightsmha(x)print( Multi-Head Attention 测试结果 )print(f输入形状{x.shape})# (2, 6, 512)print(f输出形状{output.shape})# (2, 6, 512)print(f注意力权重形状{attn_weights.shape})# (2, 8, 6, 6)print(f每个头的维度{d_model//num_heads})# 64# 观察不同头的注意力模式forhead_idxinrange(min(3,num_heads)):head_attnattn_weights[0,head_idx].mean(dim0)entropy(-head_attn*torch.log(head_attn1e-9)).sum()print(f第{head_idx}头注意力分布熵值{entropy:.2f})# 残差连接验证print(f\n输入标准差{x.std().item():.4f})print(f输出标准差{output.std().item():.4f})形状解读(2, 8, 6, 6)2 个样本8 个头每个头关注 6 个位置对其他 6 个位置熵值熵值越高注意力分布越均匀熵值越低注意力越集中在少数位置实战应用配置建议模型规模d_modelnum_headsd_k典型应用小模型256464轻量级 NLP 任务标准模型512864BERT-base, GPT-2 small大模型7681264BERT-large, GPT-2 medium超大模型10241664GPT-3, T5-large注意力头数越多越好吗不一定头数和模型效果的关系呈现边际收益递减从 1 头到 8 头效果提升明显从 8 头到 16 头提升较小从 16 头到 32 头几乎无提升甚至可能下降过拟合实践建议从 8 个头开始根据任务复杂度和计算资源调整。可视化不同头的关注模式不同头在训练后会自发分化# 可视化第 0 头语法头的注意力importmatplotlib.pyplotasplt plt.rcParams[font.sans-serif][SimHei]plt.rcParams[axes.unicode_minus]Falsetokens[我,昨天,去,了,北京,出差]attn_matrixattn_weights[0,0].detach().numpy()plt.figure(figsize(8,6))plt.imshow(attn_matrix,cmapviridis)plt.colorbar(labelAttention Weight)plt.xticks(range(len(tokens)),tokens)plt.yticks(range(len(tokens)),tokens)plt.show()通过观察不同头的热力图你会发现有的头主要关注相邻词局部语法有的头关注远距离词语义关联有的头形成对角线模式自关注最佳实践头数选择常用 8 个头d_model512每个头维度 64d_model 必须是 num_heads 的整数倍残差连接多头注意力输出后必须加残差连接防止梯度消失Pre-Norm 架构先 LayerNorm 再计算注意力梯度流更直接训练更稳定Dropout注意力权重上应用 Dropout不是输出上防止过拟合Mask 处理生成文本时使用因果 Mask防止模型看到未来信息熵值分析通过计算注意力分布熵值判断不同头的关注模式是否分化总结Multi-Head Attention 的核心要点多头分工每个头在独立子空间中计算注意力自发学习不同的关注模式维度设计d_k d_model / n_head保持计算量不变的同时增强表达能力残差连接让梯度直接流回浅层解决深层网络的梯度消失问题Pre-Norm先归一化再计算训练更稳定容错与鲁棒多个头互相补充比单头更稳定可靠理解了 Multi-Head Attention你就掌握了 Transformer 为什么能同时捕捉语法、语义、位置等多种关系的核心秘密。

相关新闻