TR3周:Pytorch复现Transformer

发布时间:2026/5/25 18:00:32

TR3周:Pytorch复现Transformer 本文为365天深度学习训练营 中的学习记录博客 原作者K同学啊思维导图如下Transformer 模型完整实现Transformer 是2017年Google提出的革命性深度学习架构完全基于注意力机制摒弃了传统的RNN和CNN结构在机器翻译等序列到序列任务上取得了突破性成果。本代码实现了完整的Transformer模型包括编码器-解码器架构。importmathimporttorchimporttorch.nnasnn devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)classMultiHeadAttention(nn.Module):# n_heads:多头注意力的数量# hid_dim:每个词输出的向量维度def__init__(self,hid_dim,n_heads):super(MultiHeadAttention,self).__init__()self.hid_dimhid_dim self.n_headsn_heads#强制hid_dim必须整除 hasserthid_dim%n_heads0#定义W_q矩阵ceself.w_qnn.Linear(hid_dim,hid_dim)#定义W_k矩阵self.w_knn.Linear(hid_dim,hid_dim)#定义W_v矩阵self.w_vnn.Linear(hid_dim,hid_dim)self.fcnn.Linear(hid_dim,hid_dim)#缩放self.scaletorch.sqrt(torch.FloatTensor([hid_dim//n_heads]))defforward(self,query,key,value,maskNone):#Q,K,V的在句子这长度这一个维度的数值可以不一样可以一样#K:[64,10,300],假设batch_size为64有10个词每个词的Query向量是300维bszquery.shape[0]Qself.w_q(query)Kself.w_k(key)Vself.w_v(value)#这里把K Q V 矩阵拆分为多组注意力#最后一维就是是用self.hid_dim // self.n_heads 来得到的表示每组注意力的向量长度每个head的向量长度是:300/650#64表示batch size,6表示有6组注意力10表示有10词50表示每组注意力的词的向量长度#K: [64,10,300] 拆分多组注意力 - [64,10,6,50] 转置得到 - [64,6,10,50]#转置是为了把注意力的数量6放在前面把10和50放在后面方便下面计算QQ.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)KK.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)VV.view(bsz,-1,self.n_heads,self.hid_dim//self.n_heads).permute(0,2,1,3)#Q乘以K的转置除以scale#[64,6,12,50]*[64,6,50,10][64,6,12,10]#attention:[64,6,12,10]attentiontorch.matmul(Q,K.permute(0,1,3,2))/self.scale#如果mask不为空那么就把mask为0的位置的attention分数设置为-1e10,这里用‘0’来指示哪些位置的词向量不能被attention到,比如padding位置ifmaskisnotNone:attentionattention.masked_fill(mask0,-1e10)#第二步:计算上一步结果的softmax再经过dropout,得到attention#注意这里是对最后一维做softmax也就是在输入序列的维度做softmax#attention: [64,6,12,10]attentiontorch.softmax(attention,dim-1)#第三步,attention结果与V相乘得到多头注意力的结果#[64,6,12,10] * [64,6,10,50] [64,6,12,50]# x: [64,6,12,50]xtorch.matmul(attention,V)#因为query有12个词所以把12放在前面把50和6放在后面方便下面拼接多组的结果#x: [64,6,12,50] 转置 - [64,12,6,50]xx.permute(0,2,1,3).contiguous()#这里的矩阵转换就是把多头注意力的结果拼接起来#最后结果就是[64,12,300]# x:[64,12,6,50] - [64,12,300]xx.view(bsz,-1,self.n_heads*(self.hid_dim//self.n_heads))xself.fc(x)returnxclassFeedforward(nn.Module):def__init__(self,d_model,d_ff,dropout0.1):super(Feedforward,self).__init__()#两层线性映射和激活函数self.linear1nn.Linear(d_model,d_ff)self.dropoutnn.Dropout(dropout)self.linear2nn.Linear(d_ff,d_model)defforward(self,x):xtorch.nn.functional.relu(self.linear1(x))xself.dropout(x)xself.linear2(x)returnxclassPositionalEncoding(nn.Module):实现位置编码def__init__(self,d_model,dropout0.1,max_len5000):super(PositionalEncoding,self).__init__()self.dropoutnn.Dropout(pdropout)# 初始化Shape为(max_len, d_model)的PE (positional encoding)petorch.zeros(max_len,d_model).to(device)# 初始化一个tensor [[0, 1, 2, 3, ...]]positiontorch.arange(0,max_len).unsqueeze(1)# 这里就是sin和cos括号中的内容通过e和ln进行了变换div_termtorch.exp(torch.arange(0,d_model,2)*-(math.log(10000.0)/d_model))pe[:,0::2]torch.sin(position*div_term)# 计算PE(pos, 2i)pe[:,1::2]torch.cos(position*div_term)# 计算PE(pos, 2i1)pepe.unsqueeze(0)# 为了方便计算在最外面在unsqueeze出一个batch# 如果一个参数不参与梯度下降但又希望保存model的时候将其保存下来# 这个时候就可以用register_bufferself.register_buffer(pe,pe)defforward(self,x): x 为embedding后的inputs例如(1,7, 128)batch size为1,7个单词单词维度为128 # 将x和positional encoding相加。xxself.pe[:,:x.size(1)].requires_grad_(False)returnself.dropout(x)classEncoderLayer(nn.Module):def__init__(self,d_model,n_heads,d_ff,dropout0.1):super(EncoderLayer,self).__init__()#编码器层包含自注意机制和前馈神经网络self.self_attnMultiHeadAttention(d_model,n_heads)self.feedforwardFeedforward(d_model,d_ff,dropout)self.norm1nn.LayerNorm(d_model)self.norm2nn.LayerNorm(d_model)self.dropoutnn.Dropout(dropout)defforward(self,x,mask):#自注意力机制atten_outputself.self_attn(x,x,x,mask)xxself.dropout(atten_output)xself.norm1(x)#前馈神经网络ff_outputself.feedforward(x)xxself.dropout(ff_output)xself.norm2(x)returnxclassDecoderLayer(nn.Module):def__init__(self,d_model,n_heads,d_ff,dropout0.1):super(DecoderLayer,self).__init__()# 解码器层包含自注意力机制、编码器-解码器注意力机制和前馈神经网络self.self_attnMultiHeadAttention(d_model,n_heads)self.enc_attnMultiHeadAttention(d_model,n_heads)self.feedforwardFeedforward(d_model,d_ff,dropout)self.norm1nn.LayerNorm(d_model)self.norm2nn.LayerNorm(d_model)self.norm3nn.LayerNorm(d_model)self.dropoutnn.Dropout(dropout)defforward(self,x,enc_output,self_mask,context_mask):# 自注意力机制attn_outputself.self_attn(x,x,x,self_mask)xxself.dropout(attn_output)xself.norm1(x)# 编码器-解码器注意力机制attn_outputself.enc_attn(x,enc_output,enc_output,context_mask)xxself.dropout(attn_output)xself.norm2(x)# 前馈神经网络ff_outputself.feedforward(x)xxself.dropout(ff_output)xself.norm3(x)returnxclassTransformer(nn.Module):def__init__(self,vocab_size,d_model,n_heads,n_encoder_layers,n_decoder_layers,d_ff,dropout0.1):super(Transformer,self).__init__()# Transformer 模型包含词嵌入、位置编码、编码器和解码器self.embeddingnn.Embedding(vocab_size,d_model)self.positional_encodingPositionalEncoding(d_model)self.encoder_layersnn.ModuleList([EncoderLayer(d_model,n_heads,d_ff,dropout)for_inrange(n_encoder_layers)])self.decoder_layersnn.ModuleList([DecoderLayer(d_model,n_heads,d_ff,dropout)for_inrange(n_decoder_layers)])self.fc_outnn.Linear(d_model,vocab_size)self.dropoutnn.Dropout(dropout)defforward(self,src,trg,src_mask,trg_mask):# 词嵌入和位置编码srcself.embedding(src)srcself.positional_encoding(src)trgself.embedding(trg)trgself.positional_encoding(trg)# 编码器forlayerinself.encoder_layers:srclayer(src,src_mask)# 解码器forlayerinself.decoder_layers:trglayer(trg,src,trg_mask,src_mask)# 输出层outputself.fc_out(trg)returnoutput vocab_size10000d_model128n_heads8n_encoder_layers6n_decoder_layers6d_ff2048dropout0.1devicetorch.device(cpu)transformer_modelTransformer(vocab_size,d_model,n_heads,n_encoder_layers,n_decoder_layers,d_ff,dropout)# 定义输入srctorch.randint(0,vocab_size,(32,10))# 源语言句子trgtorch.randint(0,vocab_size,(32,20))# 目标语言句子src_mask(src!0).unsqueeze(1).unsqueeze(2)# 掩码用于屏蔽填充的位置trg_mask(trg!0).unsqueeze(1).unsqueeze(2)# 掩码用于屏蔽填充的位置# 模型前向传播outputtransformer_model(src,trg,src_mask,trg_mask)print(output.shape)#打印当前的时间fromdatetimeimportdatetimeprint(f当前时间{datetime.now()})输出如下

相关新闻