Reproducing the Transformer in PyTorch, Line by Line: A Beginner-Friendly Code Walkthrough from Attention to Multi-Head Attention

张开发
2026/4/18 16:47:15 · 15 min read


Since its introduction in 2017, the Transformer architecture has become a cornerstone of natural language processing. In this article we will implement a complete Transformer from scratch in PyTorch, focusing in particular on its core component: the Multi-Head Attention mechanism. Rather than simply calling high-level APIs, we will walk through the logic behind every line of code so that you truly grasp the design.

## 1. Environment Setup and Basic Modules

Before building the Transformer we need to set up the environment and implement a few basic components. Make sure PyTorch 1.8 and Matplotlib (used for visualization) are installed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import matplotlib.pyplot as plt
```

### 1.1 Residual Connections and Layer Normalization

The Transformer makes heavy use of residual connections and layer normalization, which are key to training deep networks. We implement these two basic components first:

```python
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """A residual connection followed by layer normalization.

    Note: for code simplicity, normalization is applied before the sub-layer
    (pre-norm), which differs from the order used in the original paper.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
```

Tip: residual connections effectively mitigate vanishing gradients in deep networks, while layer normalization keeps each layer's input distribution stable. Combined, they greatly improve training stability.

### 1.2 Positional Encoding

Because the Transformer has no recurrent structure, positional information about the sequence has to be injected explicitly:

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
```

The positional encoding can be visualized as follows, showing the sinusoidal patterns across different dimensions:

```python
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe(torch.zeros(1, 100, 20))
plt.plot(y[0, :, 4:8].data.numpy())
plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
```

## 2. Core Attention Implementation

### 2.1 Scaled Dot-Product Attention

This is the most fundamental computation in the Transformer, implementing query-key-value attention:

```python
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
```

Key points:

- Dividing by √d_k keeps the dot products from growing too large, which would otherwise push softmax into regions with vanishing gradients.
- The mask is used in the decoder to prevent information from future positions leaking in.
- The function returns not only the weighted sum but also the attention weights, which can be used for visualization.

### 2.2 Multi-Head Attention

Multi-head attention allows the model to attend to information from different representation subspaces simultaneously. The `clones` helper used here produces N identical copies of a module; a sketch of it is given after the table below.

```python
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Linear projections, then split into h heads
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) Concatenate the heads and apply the final linear layer
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
```

Three typical uses of multi-head attention:

| Scenario | Query source | Key/Value source | Role |
|---|---|---|---|
| Encoder self-attention | Output of the previous layer | Output of the previous layer | Captures global dependencies |
| Decoder self-attention | Output of the previous layer | Output of the previous layer | Preserves the autoregressive property |
| Encoder-decoder attention | Decoder output | Final encoder output | Aligns source and target sequences |
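The code above relies on a small `clones` utility. A minimal sketch, assuming the conventional implementation that deep-copies a module N times into an `nn.ModuleList`:

```python
def clones(module, N):
    """Produce N identical copies of a module, stored in an nn.ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
```

With `clones` in place, a quick shape check of the attention module (the batch size and sequence length below are illustrative values, not from the article):

```python
mha = MultiHeadedAttention(h=8, d_model=512)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = mha(x, x, x)            # self-attention: query = key = value
print(out.shape)              # torch.Size([2, 10, 512])
print(mha.attn.shape)         # torch.Size([2, 8, 10, 10]), per-head attention weights
```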
## 3. Feed-Forward Network and Embeddings

### 3.1 Position-wise Feed-Forward Network

Each attention sub-layer is followed by a fully connected feed-forward network:

```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
```

### 3.2 Embedding Layer

```python
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
```

Note: multiplying the embedding output by √d_model keeps its scale comparable to that of the positional encoding, avoiding a large magnitude mismatch at initialization.

## 4. Encoder and Decoder

### 4.1 Stacking Encoder Layers

```python
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)


class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
```

### 4.2 Decoder Layer

The decoder has one more sub-layer than the encoder: an encoder-decoder attention layer:

```python
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)


def subsequent_mask(size):
    """Create a mask that hides future positions to prevent information leakage."""
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return subsequent_mask == 0
```

## 5. Assembling the Complete Model

Now we combine all the components into a complete Transformer model:

```python
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)


def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Construct the full Transformer model."""
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        nn.Linear(d_model, tgt_vocab))

    # Initialize parameters with Xavier/Glorot uniform
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
```

In real projects I have found the initialization strategy to be critical for Transformer training. Xavier initialization, combined with an appropriate learning-rate warmup, effectively avoids gradient explosion at the start of training.
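The `Decoder` stack referenced by `make_model` mirrors the `Encoder` above. Below is a minimal sketch of it, assuming it simply stacks `DecoderLayer`s and applies a final `LayerNorm`, followed by a quick end-to-end shape check; the vocabulary sizes, batch size, and sequence lengths in the test are illustrative values, not taken from the article.

```python
class Decoder(nn.Module):
    """A stack of N decoder layers followed by a final LayerNorm."""
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


# Quick smoke test with hypothetical sizes
model = make_model(src_vocab=1000, tgt_vocab=1000, N=2)   # small N for a fast check
src = torch.randint(1, 1000, (2, 10))                     # (batch, src_len)
tgt = torch.randint(1, 1000, (2, 9))                      # (batch, tgt_len)
src_mask = torch.ones(2, 1, 10)                           # no padding in this toy example
tgt_mask = subsequent_mask(9)                             # (1, tgt_len, tgt_len) causal mask
out = model(src, tgt, src_mask, tgt_mask)
print(out.shape)   # torch.Size([2, 9, 512]); apply model.generator to get vocabulary logits
```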
