追踪图中的变压器

发布时间:2026/5/31 3:26:11

追踪图中的变压器 原文towardsdatascience.com/tracing-the-transformer-in-diagrams-95dbeb68160c?sourcecollection_archive---------2-----------------------#2024-11-07具体要输入什么输出什么并且你是如何生成文本的https://medium.com/eric.silberstein?sourcepost_page---byline--95dbeb68160c--------------------------------https://towardsdatascience.com/?sourcepost_page---byline--95dbeb68160c-------------------------------- Eric Silberstein·发表于Towards Data Science ·15 分钟阅读 ·2024 年 11 月 7 日–上周我在听一集关于 Nvidia 的 Acquired播客。这集播客讲到了变压器GPT 中的T以及 21 世纪可能最大的发明之一。在走下 Beacon 街的时候我一边听着一边想着我理解变压器了对吧在训练过程中你会屏蔽掉一些 token你有这些学习如何连接文本中概念的注意力头然后你预测下一个单词的概率。我从 Hugging Face 下载了 LLM 并进行过一些尝试。在早期我使用过 GPT-3还没出现“聊天”功能时。在 Klaviyo我们甚至开发了首批基于 GPT 的生成性 AI 功能之一——我们的主题行助手。很久以前我也曾参与过一个基于旧式语言模型的语法检查器的开发。所以也许吧。变压器是由谷歌的一个团队发明的他们在做自动化翻译比如从英语到德语。它在 2017 年通过那篇现在广为人知的论文《Attention Is All You Need》向世界介绍。我调出了这篇论文并看了图 1https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/09e8fb2cab4c2a77260296d1a6647181.png来自《Attention Is All You Need》的图 1嗯……如果我理解的话那只是最初的粗略了解。当我越来越仔细地看着图表读着论文时我才意识到我并没有完全理解细节。以下是我写下的一些问题在训练过程中输入是英语的标记化句子输出是德语的标记化句子吗在一个训练批次中具体每一项是什么为什么要将输出输入到模型中以及“掩码多头注意力”是如何足以防止通过学习输出结果来作弊的多头注意力究竟是什么损失是如何计算的不可能是先输入源语言句子翻译整个句子再计算损失这样不合理。训练后究竟该输入什么来生成翻译为什么有三条箭头指向多头注意力模块我敢肯定对于两类人来说这些问题既简单又显得幼稚。第一类是那些已经在使用类似模型例如 RNN、编码器-解码器做类似事情的人。当他们读到这篇论文时肯定立刻就理解了 Google 团队的成就以及他们是如何做到的。第二类是过去七年里意识到 Transformer 重要性的更多人并且花时间学习其中的细节。好吧我想学习这个于是我觉得最好的方法是从头开始构建模型。我很快就迷失了方向决定追踪别人写的代码。我找到了一份很棒的笔记它解释了这篇论文并在 PyTorch 中实现了模型。我复制了代码并训练了模型。我将所有内容输入、批次、词汇、维度都做得非常小这样我就可以追踪每一步发生了什么。我发现在图表上标注维度和张量帮助我理清了思路。到我完成时我已经对上面所有的问题有了相当不错的答案接下来我会在讲解图表之后回答它们。这是我整理过的笔记版本。本部分的内容是为了训练一个单一的、非常小的批次这意味着不同图表中的所有张量都是一组的。为了让内容更易于跟随并借鉴了笔记中的一个想法我们将训练模型来复制标记。例如一旦训练完成“dog run” 应该翻译成 “dog run”。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/26ed6a4129b8375201b96f830ac22a33.png换句话说https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/de8a85de6f19d7b8e64e7bc1bf538b4d.png下面试图用文字解释一下图表中张量维度以紫色显示到目前为止的含义https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b2cea2931c64c5377a07ba2f7a7a619d.png其中一个超参数是d-model在论文中的基础模型中它是 512。这个例子中我设定为 8。这意味着我们的嵌入向量长度为 8。这里再次展示主图标记了多个地方的维度https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f43ecc3439ffa117a867954961932e90.png让我们放大看看编码器的输入https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/7460fef236e794002d6f3e98912d05e8.pnghttps://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/040c9af92d654c390791d91f6cae1e7e.png图中显示的大部分模块加法与归一化、前馈神经网络、最终线性变换只作用于最后一个维度8。如果仅仅是这样的话那么模型只能利用序列中单一位置的信息来预测单一位置。某个地方必须“混合”位置之间的信息这个魔法发生在多头注意力模块中。让我们放大查看编码器中的多头注意力模块。在接下来的图示中请记住在我的例子中我将超参数h头数设置为2。在论文中的基础模型中它是 8。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/4cb5d05dc31aae173d6351e44e372b7e.png来自《Attention Is All You Need》的图 2带有作者注释(2,3,8)是如何变成(2,2,3,4)的我们进行了线性变换然后将结果拆分成头数8 / 2 4并重新排列张量的维度使得我们的第二个维度就是头。让我们来看一些实际的张量https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/245a1c4c1f62b88cbbd0470faf87ee5b.png我们仍然没有做任何会在位置之间混合信息的操作。那将发生在接下来的缩放点积注意力模块中。维度“4”和维度“3”最终会接触在一起。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/50b88c4d8fc98f9b2eb44e751463caf1.png来自《Attention Is All You Need》的图 2带有作者注释让我们看看这些张量但为了更容易理解我们只关注批次中的第一个项目和第一个头。换句话说就是 Q[0,0]K[0,0]等等。其他三个头也会进行同样的操作。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fe5e540d35d0f111c4b2d70a834c1bd6.png让我们看一下软最大输出与 V 之间的最终矩阵乘法https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/09e63c393da6f1a8f830b3ea6646d45a.png从最开始开始回溯我们可以看到在这次乘法之前V 中的三个位置一直都是独立操作的一直到我们原始句子“ dog run”中。这次乘法第一次将来自其他位置的信息混合在一起。回到多头注意力的示意图我们可以看到concat 操作将每个头的输出重新组合在一起因此每个位置现在由长度为 8 的向量表示。注意concat 后但在线性变换前的张量中的1.8和**-1.1与上面显示的经过缩放点积注意力后批次中第一个项目、第一头第一个位置的向量中的1.8和-1.1**相匹配。接下来的两个数字也匹配只不过它们被省略号隐藏了。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/9c1ba8f87c1319445aff846595fac271.png现在让我们回到整个编码器的视图https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/5abbdf515d94462e1db8e6a8d9b89627.png起初我以为我需要详细追踪前馈网络块。论文中称它为“位置-wise 前馈网络”我以为这意味着它可能会将信息从一个位置传递到右侧的其他位置。然而事实并非如此。“位置-wise”意味着它在每个位置上独立运算。它对每个位置进行线性变换从 8 个元素变换到 32 个元素然后进行 ReLU取 0 和数字中的最大值接着再做一次线性变换回到 8 个元素。这是在我们的小例子中。在论文中的基础模型中它从 512 变到 2048再回到 512。这里有很多参数可能是学习发生的主要地方前馈网络的输出回到2,3,8。先暂时离开我们的简化模型看看论文中基础模型里的编码器是怎样的。输入和输出的维度匹配真是太好了https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/bfb4c198389ce4988f4c2ef1a25f003f.png现在让我们拉远视角看看解码器。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/4693f18a10a7f0c7d5278fa79359f84f.png我们不需要追踪解码器的绝大部分内容因为它和我们刚刚在编码器端看到的非常相似。然而我标记为A和B的部分是不同的。A不同是因为我们做了掩蔽的多头注意力。这应该是避免在训练时“作弊”的关键。B稍后我们会再回到。首先让我们隐藏内部细节保持对解码器输出结果的大致图景。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/39a6c1f6a5b02411b6791edc07a68b6f.png为了更好地强调这一点假设我们的英文句子是“she pet the dog”而翻译成 Pig Latin 后的句子是“eshay etpay ethay ogday”。如果模型已经有了“eshay etpay ethay”并且正在试图推测下一个词“ogday”和“atcay”都是高概率选择。考虑到完整英文句子“she pet the dog”的上下文模型应该能够选择“ogday”。然而如果模型在训练期间能够看到“ogday”它就不需要通过上下文来预测它只需要学会复制。让我们看看掩码是如何做到这一点的。我们可以跳过一些步骤因为A的第一部分和之前一样都是应用线性变换并将东西分割成头部。唯一不同的是进入缩放点积注意力部分的维度是2,2,2,4而不是2,2,3,4因为我们原始的输入序列长度是 2。这里是缩放点积注意力部分。正如我们在编码器端所做的那样我们只看批次中的第一个项目和第一个头部。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d07ad1d278dca4dbffcae5da009037af.png这次我们有一个掩码。让我们看看 softmax 的输出和 V 之间的最终矩阵乘法https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e988447d3a58a2c9be0ff1169542fc14.png现在我们准备好看看B解码器中的第二个多头注意力块。与其他两个多头注意力块不同我们并没有输入三个相同的张量因此我们需要思考 V、K 和 Q 分别代表什么。我用红色标出了输入。可以看到V 和 K 来自编码器的输出并且维度是2,3,8。Q 的维度是2,2,8。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f093df7fc5079f0340500b5cfd65ea92.png和之前一样我们跳到缩放点积注意力部分。V 和 K 的维度是2,2,3,4——批量中有两个项目两个头三个位置长度为四的向量而 Q 的维度是2,2,2,4。这很合理但也有些令人困惑。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ce9a1ee8a761e37d6139e0ce59702fcd.png即使我们是在“读取”编码器输出其中“序列”长度为三所有的矩阵计算也能顺利进行我们最终得到了所需的维度2,2,2,4。让我们来看一下最终的矩阵乘法https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/5e07f351a7453ad38a0f298f6b45b7c6.png每个多头注意力块的输出会被加在一起。让我们跳过到解码器的输出部分并将其转换为预测https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/61b29eaddad690ae6ad30dacaeea6448.png线性变换将我们从**(2,2,8)转换为(2,2,5)**。可以把它看作是反向嵌入除了我们不是从长度为 8 的向量转到单个标记的整数标识符而是转到一个包含 5 个标记的词汇表上的概率分布。在我们这个小示例中数字看起来有点奇怪。在论文中这更像是从大小为 512 的向量转到包含 37,000 个词汇的词汇表当时他们做的是英语到德语的翻译。稍后我们将计算损失。不过即使是匆匆一瞥你也可以大致感知模型的表现。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8141557f54786f52abad6e4552d81d08.png它正确预测了一个标记。毫不奇怪因为这是我们的第一个训练批次而且一切都是随机的。这个图的一个优点是它清晰地表明这是一个多类分类问题。类别是词汇表在这个例子中有 5 个类别这正是我之前感到困惑的地方我们对翻译句子中的每个标记做出并评分一个预测而不是对每个句子做一个预测。让我们进行实际的损失计算。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/929b25a262696ce2b80901d89c52e87c.png比如说如果-3.2 变成了-2.2那么我们的损失将减少到 5.7朝着我们希望的方向移动因为我们希望模型学习到第一个标记的正确预测是 4。上面的图省略了标签平滑。在实际论文中损失计算会平滑标签并使用 KL 散度损失。我认为当没有平滑时这样的损失计算结果与交叉熵相同或相似。下面是与上图相同的图但添加了标签平滑。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/5472835199850c82a2efa3d615797cbd.png让我们也快速看一下在编码器和解码器中学习的参数数量https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6a40913435a4b778cb4ba27d450fac77.png作为一个检查我们的玩具模型中的前馈块有一个从 8 到 32 再回到 8 的线性变换如上所述因此是 8 * 32权重 32偏置 32 * 8权重 8偏置 52。请记住在论文中的基础模型中d-model是 512d-ff是 2048并且有 6 个编码器和 6 个解码器因此会有更多的参数。使用训练好的模型现在让我们看看如何将源语言文本输入并得到翻译后的文本。我这里仍然使用一个玩具模型通过复制 token 来“翻译”但与上面的例子不同这里使用的是大小为 11 的词汇表并且d-model为 512。上面我们有一个大小为 5 的词汇表d-model是 8。首先让我们做一次翻译。然后我们再来看它是如何工作的。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2fa933c517521e3d0104917cda11f368.png第一步是将源句子输入编码器并保留其输出在本例中是一个维度为(1, 10, 512)的张量。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/7868757361bb9e69583a456a57341d56.png第二步是将输出的第一个 token 输入解码器并预测第二个 token。我们知道第一个 token因为它总是 1。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/43afa5686eb5be64c9782fdb9da98ce5.png在本文中他们使用了束搜索beam search束大小为 4这意味着我们将考虑此时概率最高的 4 个 token。为了简化我将改用贪心搜索。你可以把它看作是束大小为 1 的束搜索。因此从图表的顶部读取概率最高的 token 是编号5。上面的输出是概率的对数。概率最高的仍然是最大数字。在这个例子中是-0.0实际上是-0.004但我只显示到一位小数。模型非常确定 5 是正确的exp(-0.004) 99.6%现在我们将[1,5]输入解码器。如果我们在进行束搜索并且束大小为 2我们可以将包含[1,5]和[1,4]下一个最可能的 token的批次输入这样可以得到下一步的结果。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/3cf05eff5a17c139511dfce64664eb24.png现在我们将[1,5,4]输入https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fcbe3b72d487fe2611d647531bd68e66.png并输出3。一直这样进行直到我们得到一个表示句子结束的 token在我们的示例词汇表中不存在或者达到最大长度。回到上面的问题现在我大致可以回答我最初的问题了。在训练过程中输入是用英语分词的句子输出是用德语分词的句子吗是的大致如此。训练批次中的每个项目究竟是什么每个项目对应一个翻译后的句子对。项目的“x”有两部分。第一部分是源句子的所有 token。第二部分是目标句子的所有 token除了最后一个。项目的“y”标签是目标句子的所有 token除了第一个。由于源句子和目标句子的第一个 token 总是所以我们并没有浪费或丢失任何训练数据。有一点比较微妙如果这是一个分类任务例如模型需要接收一张图像并输出一个类别比如房子、汽车、兔子等我们会认为批次中的每个项目都会对损失计算贡献一个“分类”。然而在这里批次中的每个项目将会对损失计算贡献目标句子的 tokens 数量 — 1个“分类”。为什么要将输出数据输入模型且“掩码多头注意力”又如何足以防止模型通过学习输出直接生成输出结果你将输出数据馈送给模型以便模型能够基于源句子的意思和目前已翻译的单词预测翻译。虽然模型中有很多事情在进行但信息在各个位置之间传递的唯一时机是在注意力步骤中。尽管我们确实将翻译后的句子输入解码器但第一次注意力计算使用掩码将所有超出当前预测位置的信息清零。什么是多头注意力我可能应该先问一下什么是注意力机制因为它是更核心的概念。多头注意力意味着将向量切分成若干组对每组进行注意力计算然后再将这些组合并起来。例如如果向量的大小是 512且有 8 个头那么注意力将独立地在 8 个组上进行每组包含一个完整批次的所有位置每个位置的向量大小为 64。如果你稍微思考一下你会发现每个头可以学习集中注意力于某些相关的概念正如那些著名的可视化展示所示头部会学习代词指代的是哪个词。损失究竟是如何计算的不可能是将源语言句子翻译完整个句子后再计算损失这样不合理吧。对的。我们不是一次性翻译整个句子然后计算整个句子的相似度或类似的东西。损失的计算方式与其他多分类问题类似。类别就是我们词汇表中的 token。诀窍在于我们独立地预测目标句子中每个 token 的类别且仅使用此时应当拥有的信息。标签是我们目标句子中的实际 token。通过使用预测和标签我们利用交叉熵计算损失。实际上我们对标签进行了“平滑”以考虑到它们不是绝对的有时同义词也能起到同样的作用。训练完成后生成翻译时究竟输入什么呢你不能直接输入某些内容并让模型在一次评估中输出翻译结果。你需要多次使用模型。首先将源句子输入到模型的编码器部分得到表示句子含义的编码版本这种表示是以某种抽象、深层次的方式进行的。然后将该编码信息和起始标记输入到解码器部分这样你就可以预测目标句子的第二个标记。接着你将和第二个标记输入预测第三个标记。如此反复直到你得到完整的翻译句子。实际上你会考虑每个位置多个高概率的标记每次输入多个候选序列并根据总概率和长度惩罚选择最终的翻译句子。为什么有三条箭头指向多头注意力块我猜有三个原因。1展示解码器中第二个多头注意力块的输入部分来自编码器和解码器前一个块的输入。2暗示注意力算法是如何工作的。3暗示每个输入都在实际进行注意力计算之前经历独立的线性变换。结论这太美妙了如果它不这么有用我可能不会这么想。我现在能理解人们第一次看到这个工作原理时的感受。这个优雅且可训练的模型用极少的代码就能表达学会了如何翻译人类语言并打败了那些花费几十年构建的复杂机器翻译系统。它令人惊叹、聪明且难以置信。你可以看到下一步就是抛开翻译句对开始将这种技术应用到互联网上的每一段文字——大型语言模型LLM由此诞生我猜上面有一些错误请告诉我。除非另有说明所有图片均由作者提供或为作者在Attention Is All You Need中对图示的注释。

相关新闻