参考:从零构建大模型——第 3 章 编码自注意力机制

本文代码:ch03/01_main-chapter-code/ch03.ipynb

编码注意力机制

本章将要实现的不同注意力机制。我们将从一个简化版本的自注意力机制开始,然后逐 步加入可训练的权重。因果注意力机制在自注意力的基础上增加了额外掩码,使得大语 言模型可以一次生成一个单词。最后,多头注意力将注意力机制划分成多个头,从而使 模型能够并行捕获输入数据的各种特征

1. 长序列建模中的问题

注意力机制前的架构局限:编码器 - 解码器 RNN 问题总结

语言翻译的核心挑战与早期解决方案

  1. 翻译核心难题:源语言与目标语言语法结构不同,无法通过“逐词翻译”实现准确转换(如图 3-3)。
  2. 早期主流架构:采用“编码器 - 解码器”双模块深度神经网络,应对语法差异问题:
    • 编码器:读取并处理整个输入文本,生成包含文本含义的中间表示;
    • 解码器:基于编码器的中间表示,逐字生成翻译后的目标语言文本。
  3. 具体实现模型:在 Transformer 出现前,循环神经网络(RNN) 是该架构的主流载体,因 RNN 可将前一步输出作为当前步输入,天然适配文本这类序列数据。

在将文本从一种语言翻译成另一种语言 (比如从德语翻译成英语) 时,不能仅仅逐词翻 译。相反,翻译过程需要理解上下文和进行语法对齐

编码器 - 解码器 RNN 的工作逻辑:

  1. 编码器流程:逐步处理输入文本,每一步更新“隐藏状态”(隐藏层内部值),最终将整个输入文本的含义压缩到“最终隐藏状态” 中(可理解为文本的嵌入向量)。
  2. 解码器流程:以编码器的“最终隐藏状态”为初始依据,逐字生成输出文本,同时每一步更新自身隐藏状态,为下一个词的预测提供上下文信息(如图 3-4)。

在 Transformer 模型出现之前,编码器 - 解码器结构的 RNN 是机器翻译的常见选择。编码 器将源语言的一串词元序列作为输入,并通过隐藏状态 (一个中间神经网络层) 编码整 个输入序列的压缩表示。然后,解码器利用其当前的隐藏状态开始逐个词元进行翻译

因为理解这次讨论并不需 要了解 RNN 的具体工作原理。这里我们主要关注编码器 - 解码器架构的基本概念。

核心缺陷:上下文丢失,难以处理长距离依赖。这是编码器 - 解码器 RNN 最关键的局限,直接推动了注意力机制的诞生:

  1. 信息瓶颈:编码器仅通过“单一最终隐藏状态”传递全部输入信息,解码器在生成过程中无法直接访问编码器的早期隐藏状态,只能依赖当前自身的隐藏状态。
  2. 实际问题:当处理复杂句子或长文本时,“单一最终隐藏状态”难以承载全部上下文信息,易导致长距离依赖关系丢失(例如句子前后端的指代、逻辑关联无法有效传递),影响翻译等任务的准确性。

注意力机制(含自注意力)核心信息归纳总结

注意力机制的诞生背景:解决 RNN 的长文本处理痛点

  1. RNN 的核心缺陷: RNN 在处理长文本时效果差,关键问题在于:编码器需将全部输入信息压缩到单一隐藏状态中传递给解码器,导致解码器无法直接访问输入序列中靠前的词元,易丢失长距离上下文依赖(仅在翻译短句时表现良好)。

  2. 解决方案:Bahdanau 注意力机制(2014 年提出)

    • 核心改进:对编码器 - 解码器 RNN 进行修改,让解码器在每个解码步骤中可选择性访问输入序列的所有部分(而非仅依赖单一隐藏状态,见图 3-5)。
    • 关键逻辑:生成特定输出词元时,通过“注意力权重”区分输入词元的重要性,让模型聚焦于与当前输出关联度更高的输入信息。
    • 说明:Bahdanau 注意力机制是 RNN 框架下的方法,本书不深入其具体实现,重点关注其“选择性关注”的核心思想。

图 3-5 通过使用注意力机制,网络的生成文本解码器部分可以有选择地访问所有输入词元。这 意味着对于生成一个特定的输出词元,某些输入词元比其他输入词元更重要。这种重要 性由注意力权重决定,我们将在后面计算这些权重。需要注意的是,这里展示的是注意 力机制的基本概念,并未描述 Bahdanau 机制 (一种 RNN 方法,但其超出了本书的讨论 范畴) 的具体实现

从注意力到自注意力:Transformer 架构的关键创新

  1. 技术迭代:RNN 并非必需: 2017 年(Bahdanau 机制提出 3 年后),研究人员突破 RNN 框架,提出Transformer 架构,其核心创新之一是引入“自注意力机制”——灵感源自 Bahdanau 注意力的“选择性关注”思想,但摆脱了 RNN 的序列依赖限制。

  2. 自注意力机制的核心逻辑

    • 核心能力:允许输入序列中的每个位置,关注同一序列中的所有其他位置,通过权衡不同位置的重要性,生成更高效的序列表示。
    • 地位:是当代基于 Transformer 的 LLM(如 GPT 系列)的关键组成部分,直接决定模型对上下文依赖的捕捉能力。

自注意力

自注意力是 Transformer 模型中的一种机制,它通过允许一个序列中的每个位置与同一序列 中的其他所有位置进行交互并权衡其重要性,来计算出更高效的输入表示。

通过自注意力机制关注输入的不同部分

自注意力机制中的“自”

在自注意力机制中,“自”指的是该机制通过关联单个输入序列中的不同位置来计算注意力权重的能力。它可以评估并学习输入本身各个部分之间的关系和依赖,比如句子中的单词或图像中的像素。 这与传统的注意力机制形成对比。传统的注意力机制关注的是两个不同序列元素之间的关系,比如在序列到序列模型中,注意力可能在输入序列和输出序列之间。

自注意力机制的“自”,本质是对“单一输入序列内部不同位置”计算注意力权重,能自主学习序列自身各部分的关系与依赖(如句子中单词间的关联、图像中像素间的关联)。

其与传统注意力机制的核心区别在于:

  • 自注意力:关注“同一序列内”元素的关系(如输入序列内部单词互相关联);
  • 传统注意力:关注“两个不同序列间”元素的关系(如翻译任务中输入序列与输出序列的关联,见图 3-5)。

2. 没有可训练权重的简单自注意力机制

让我们首先实现一个不包含任何可训练权重的简化的自注意力机制变体,如图 3-7 所示。目 标是在引入可训练权重之前,阐明自注意力中的一些关键概念。

为理解基础逻辑,先从“无训练权重”的简化版本入手,核心目标是为输入序列中每个元素生成上下文向量(context vector)——融合序列所有元素信息的增强型嵌入向量。以下以输入句子“Your journey starts with one step.”(已转换为 3 维嵌入向量)为例,分步解析。

自注意力机制的目标是为每个输入元素计算一个上下文向量,该向量结合了其他所有输 入元素的信息。在该图的示例中,我们计算了上下文向量 。计算 时,各个输入元 素的重要性或贡献度由注意力权重 决定。这些注意力权重是针对输入元素 及其他所有输入元素计算的。

图 3-7 显示了一个输入序列,记为 ,它由 个元素组成,分别表示为 。这个序列通常代表文本(如一个句子),并且该文本已经被转换为词元嵌入。

例如,考虑输入文本 “Your journey starts with one step.” 在这种情况下,文本序列中的每个元素(如 )都对应一个 维的嵌入向量,该向量代表了一个特定的词元,比如 “Your”。在图 3-7 中,这些输入向量被表示为三维嵌入。

在自注意力机制中,我们的目标是为输入序列中的每个元素 计算上下文向量 上下文向量(context vector)可以被理解为一种包含了序列中所有元素信息的嵌入向量。

为了说明这个概念,我们重点关注第二个输入元素 (对应于词元 “journey”)的嵌入向量及其对应的上下文向量 ,如图 3-7 底部所示。这个增强的上下文向量 是一个嵌入,包含了关于 及其他所有输入元素( )的信息。

上下文向量在自注意力机制中起着关键作用。它们的目的是通过结合序列中其他所有元素的信息,为输入序列 (如一个句子) 中的每个元素创建丰富表示,如图 3-7 所示。这在大语言模型 中至关重要,因为这些模型需要理解句子中单词之间的关系和相关性。稍后我们将添加可训练的 权重,以帮助大语言模型学习如何构建这些上下文向量,使它们能用于生成下一个词元。但首先, 我们将实现一个简化的自注意力机制,逐步计算这些权重和上下文向量。

1. 输入准备:词元嵌入向量

输入为 6 个词元的 3 维嵌入张量(PyTorch 实现):

import torch 
inputs = torch.tensor( 
    [[0.43, 0.15, 0.89], # Your (x^1) 
     [0.55, 0.87, 0.66], # journey (x^2,后续作为查询向量) 
     [0.57, 0.85, 0.64], # starts (x^3) 
     [0.22, 0.58, 0.33], # with (x^4) 
     [0.77, 0.25, 0.10], # one (x^5) 
     [0.05, 0.80, 0.55]] # step (x^6) 
)

2. 步骤 1:计算注意力分数(Attention Score)

实现自注意力机制的第一步是计算中间值 ,即所谓注意力分数,如图 3-8 所示。

由于空间有限,图中显示的前几个输入张量的数值是截断后的版本,比如 0.87 被截断为 0.8。在这个截断 版本中,单词“journey”和“starts”的嵌入可能会由于随机原因看起来相似。

图 3-8 本节的总体目标是通过将第二个输入元素 作为查询,来演示上下文向量 的计算过程。该图展示了第一个中间步骤,即通过点积计算查询 与其他所有输入元素之间的注意力分数 (请注意,为了减少视觉混乱,数值被截断为小数点后一位数字)。

图 3-8 展示了如何计算查询词元与每个输入词元之间的中间注意力分数。我们通过计算查询 词元 与其他所有输入词元的点积来确定这些分数:

注意力分数

注意力分数衡量“查询元素”与序列中其他元素的关联度,通过向量点积计算(点积越大,元素相似度/对齐度越高,关联越强)。点积不仅被视为一种将两个向量转化为标量值的数学工具,而且也是度量相似度的一种方 式,因为它可以量化两个向量之间的对齐程度: 点积越大,向量之间的对齐程度或相似度就 越高。在自注意机制中,点积决定了序列中每个元素对其他元素的关注程度: 点积越大,两 个元素之间的相似度和注意力分数就越高。

x^2(“journey”)为查询向量,计算其与所有输入元素的注意力分数:

query = inputs[1]  # 选择x^2作为查询
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)  # 点积计算分数
print(attn_scores_2)
# 输出结果:tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
  • 结果解读:x^2 与自身(x^2)的分数最高(1.4950),与 x^3(“starts”)分数次高(1.4754),说明二者关联最紧密。

3. 步骤 2:注意力分数归一化(获注意力权重)

在下一步中,如图 3-9 所示,我们将对先前计算的每个注意力分数进行归一化处理。归一化的目的是让注意力权重总和为 1,便于解释为“相对重要性”,且提升训练稳定性。主流方法是Softmax 函数(保证权重为正,且极值处理更优)。以下是一种实现这一归一化步骤的简单方法。

(1)简单归一化(求和归一化,仅作对比)

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
# 输出:tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656]),总和=1.0

在计算完与输入查询 相关的注意力分数 之后,下一步是通过对这些注意力分数进行归一化,来获得注意力权重

(2)Softmax 归一化(实际应用首选)

在实际应用中,使用 softmax 函数进行归一化更为常见,而且是一种更可取的做法。这种方法更 好地处理了极值,并在训练期间提供了更有利的梯度特性。以下是用于归一化注意力分数的 softmax 函数的基础实现:

# 基础实现(这种简单的 softmax 实现在处理大输入值或小输入值时可能会遇到数值稳定性,实际用PyTorch内置函数)
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
 
# PyTorch优化实现(推荐)
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
# 输出:tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581]),总和=1.0
  • 另外,softmax 函数可以保证注意力权重总是正值,这使得输出可以被解释为概率或相对重要性, 其中权重越高表示重要程度越高。
  • 结果解读:x^2 的权重最高(0.2379),x^3 次之(0.2333),说明在生成 x^2 的上下文向量时,二者贡献最大。

4. 步骤 3:计算上下文向量(加权求和)

现在我们已经计算了归一化的注意力权重,接下来进入最后一步。如图 3-10 所示,通过将嵌入的输入词元 与相应的注意力权重相乘,再将得到的向量求和来计算上下文向量 。因此,上下文向量 是所有输入向量的加权总和,通过将每个输入向量与其对应的注意力权重相乘而获得:

上下文向量是“所有输入元素嵌入向量按注意力权重加权求和”的结果,融合了序列全部元素的关键信息。

x^2 的上下文向量 z^2 计算为例:

query = inputs[1] # 第二个输入词元 作为查询向量
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i  # 加权求和
print(context_vec_2)
# 输出结果:tensor([0.4419, 0.6515, 0.5683])
  • 意义:z^2 不仅包含 x^2 自身的信息,还融合了其他词元(如 x^3“starts”、x^6“step”)的关联信息,成为更丰富的语义表示。

在计算并归一化注意力分数以获取查询 的注意力权重之后,最后一步是计算上下文 向量 。该上下文向量是所有输入向量 按注意力权重加权的组合

5. 计算所有输入词元的注意力权重

到目前为止,我们已经计算了输入 2 的注意力权重和上下文向量,如图 3-11 中突出显示的 那一行所示。接下来,我们将扩展这个计算过程,以计算所有输入的注意力权重和上下文向量。

突出显示的那一行展示了将第二个输入元素作为查询时的注意力权重。本节将对获取 其他所有注意力权重的计算过程进行概括 (请注意,该图中的数值被截断为小数点后两 位数,以减少视觉混乱。每行中的值加起来应为 1.0 或 100%)

如图 3-12 所示,我们遵循与之前相同的 3 个步骤,唯一的区别是在代码中进行了一些修改, 以计算所有上下文向量,而不仅仅是第二个上下文向量

为高效计算序列中所有元素对的注意力分数,需从逐个计算升级为批量处理,核心是利用矩阵乘法替代嵌套循环,提升计算效率。

1. 嵌套循环实现(基础方法)

通过两层循环遍历所有输入元素对,计算点积得到注意力分数矩阵(形状为 为序列长度):

attn_scores = torch.empty(6, 6)  # 6个元素的序列,生成6×6分数矩阵
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)  # 计算x_i与x_j的点积
 
# 输出结果(6×6矩阵,每行i对应x_i与所有x_j的注意力分数):
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

在第 (1) 步中,我们添加一个额外的 for 循环来计算所有输入对的点积

张量中的每个元素表示每对输入之间的注意力分数,如图 3-11 所示。请注意,图中的值已经归 一化,这就是它们与之前张量中的未归一化注意力分数不同的原因。我们稍后会处理归一化的 问题。

2. 矩阵乘法优化(高效方法)

在计算前面的注意力分数张量时,我们使用了 Python 中的 for 循环。然而,for 循环通常 较慢,因此可以使用矩阵乘法来得到相同的结果:输入序列可表示为矩阵 (形状 为嵌入维度),通过矩阵乘法(转置),一次性得到所有元素对的注意力分数,结果与嵌套循环完全一致,但计算效率显著提升:

attn_scores = inputs @ inputs.T  # @表示矩阵乘法,等价于torch.matmul(inputs, inputs.T)
 
# 输出与循环法完全相同,验证优化正确性

3. 批量计算注意力权重:行归一化

对注意力分数矩阵的每一行进行 Softmax 归一化,使每行权重总和为 1(每行对应一个元素的注意力权重分布):

attn_weights = torch.softmax(attn_scores, dim=-1)  # dim=-1表示对最后一维(列)归一化
 
# 输出注意力权重矩阵(6×6,每行和为1):
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
 
# 验证每行和为1:
print("All row sums:", attn_weights.sum(dim=-1))  # 输出全为1.0000

4. 批量计算所有上下文向量:加权求和

通过注意力权重矩阵与输入矩阵的矩阵乘法,一次性得到所有元素的上下文向量(每行为一个元素的上下文向量):

all_context_vecs = attn_weights @ inputs  # 权重矩阵(6×6) × 输入矩阵(6×3) = 上下文矩阵(6×3)
 
# 输出所有上下文向量(6行3列,每行对应一个元素的上下文向量):
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],  # 与之前单独计算的z^(2)完全一致
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
 
# 验证第2行与之前单独计算的z^(2)一致:
print("Previous 2nd context vector:", context_vec_2)  # 输出tensor([0.4419, 0.6515, 0.5683])

3. 实现带可训练权重的自注意力机制

接下来,我们将实现在原始 Transformer 架构、GPT 模型和大多数其他流行的大语言模型中 使用的自注意机制。这种自注意力机制也被称为缩放点积注意力 (scaled dot-product attention)。

带有可训练权重的自注意力机制是建立在先前概念之上的: 我们希望将上 下文向量计算为某个特定输入元素对于序列中所有输入向量的加权和。你会看到,带有可训练权 重的自注意力机制与我们之前实现的基础自注意力机制只有些微的不同。最显著的区别是这里引入了在模型训练期间更新的权重矩阵。这些可训练的权重矩阵至关 重要,这样模型 (特别是模型内部的注意力模块) 才能学会产生“好的”上下文向量。

逐步计算注意力权重

本节将通过引入 3 个可训练的权重矩阵 ,一步一步地实现自注意力机制。这 3 个矩阵用于将嵌入的输入词元 分别映射为查询向量、键向量和值向量,如图 3-14 所示。

在实现具有可训练权重矩阵的自注意机制的第一步中,我们计算了输入元素 的查询向量()、键向量()和值向量()。与之前类似,我们将第二个输入元素 指定为查询输入。查询向量 是通过第二个输入元素 与权重矩阵 之间的矩阵乘法得到的。同样,我们通过包含权重矩阵 的矩阵乘法得到键向量和值向量。

之前,当我们通过注意力权重计算上下文向量 时,将第二个输入元素 定义为了查询。然后,我们将这一方法推广到了计算所有上下文向量 ,应用于 6 个词的输入句子 “Your journey starts with one step.”

同样,为了便于说明,这里我们只计算一个上下文向量 。之后我们会修改这段代码来计算所有上下文向量。

首先,定义几个变量:

x_2 = inputs[1]  # 第二个输入元素
d_in = inputs.shape[1]  # 输入嵌入维度 d_in=3
d_out = 2  # 输出嵌入维度 d_out=2

请注意,在类 GPT 模型中,输入和输出的维度通常是相同的,但为了便于理解计算过程,这里我们使用不同的输入维度(d_{\text{in}}=3)和输出维度(d_{\text{out}}=2)。

本节详细讲解了如何通过可训练权重矩阵实现自注意力机制,核心步骤如下:

然后,初始化图 3-14 中的 3 个权重矩阵

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

设置 requires_grad=False 以减少输出中的其他项,但如果要在模型训练中使用这些权重矩阵,就需要设置 requires_grad=True,以便在训练中更新这些矩阵。

接下来,计算查询向量、键向量和值向量:

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

因为我们通过 将对应的权重矩阵的列数设置为了 2,所以查询的输出结果是一个二维向量,输出为:

tensor([0.4306, 1.4551])

权重参数与注意力权重

在权重矩阵 中,“权重”是“权重参数”的简称,表示在训练过程中优化的神经网络参数。这与注意力权重是不同的。正如我们已经看到的,注意力权重决定了上下文向量对输入的不同部分的依赖程度(网络对输入的不同部分的关注程度)。总之,权重参数是定义网络连接的基本学习系数,而注意力权重是动态且特定于上下文的值。

虽然目前我们的目标只是计算一个上下文向量 ,但仍然需要所有输入元素的键向量和值向量,因为它们参与了计算相对于查询 的注意力权重(参见图 3-14)。

可以通过矩阵乘法得到所有的键向量和值向量:

keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

从输出中可以看出,我们成功地将 6 个输入词元从三维空间映射到了二维嵌入空间,输出为:

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])

接下来是计算注意力分数,如图 3-15 所示。

计算注意力分数

注意力分数的计算是一种点积计算,与 3.3 节中使用的方法类似。不同之处在于,我们 不是直接计算输入元素之间的点积,而是使用通过各自权重矩阵变换后的查询向量和 键向量进行计算

首先,计算出注意力分数

keys_2 = keys[1]  # 注意,Python从0开始进行索引
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

未归一化的注意力评分结果如下所示:

tensor(1.8524)

同样,可以通过矩阵乘法将这个计算推广到所有的注意力分数:

attn_scores_2 = query_2 @ keys.T  # 给定 query 的全部注意力分数
print(attn_scores_2)

如你所见,经过快速检查,输出中的第二个元素与之前计算的 attn_score_22 一致。

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

计算注意力权重

在计算完注意力分数 后,下一步是使用 softmax 函数对这些分数进行归一化,以获得 注意力权重

现在,我们想要将注意力分数转换为注意力权重,如图 3-16 所示。我们通过缩放注意力分数并应用 softmax 函数来计算注意力权重。不过,此时是通过将注意力分数除以键向量的嵌入维度的平方根来进行缩放(取平方根在数学上等同于以 0.5 为指数进行幂运算)。

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k** 0.5, dim=-1)
print(attn_weights_2)

缩放点积注意力的原理 对嵌入维度进行归一化是为了避免梯度过小,从而提升训练性能。例如,在类 GPT 大语 言模型中,嵌入维度通常大于 1000,这可能导致点积非常大,从而在反向传播时由于 softmax 函数的作用导致梯度非常小。当点积增大时,softmax 函数会表现得更像阶跃函数,导致梯度 接近零。这些小梯度可能会显著减慢学习速度或使训练停滞。 因此,通过嵌入维度的平方根进行缩放解释了为什么这种自注意力机制也被称为缩放点 积注意力机制。

计算上下文向量

在自注意力计算的最后一步,通过注意力权重将所有值向量进行加权求和,从而计算 上下文向量

与计算上下文向量时对输入向量进行加权求和 (参见 3.3 节) 的方式类似,现在通过对值向 量进行加权求和来计算上下文向量。在这里,注意力权重作为加权因子,用于权衡每个值向量的 重要性。和之前一样,可以使用矩阵乘法一步获得输出结果:

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

所生成的向量内容如下所示: tensor([0.3061, 0.8210]) 到目前为止,我们只计算了一个上下文向量 。在 3.4.2 节中,我们将扩展代码来计算输入序列中 的所有上下文向量。

为什么要用查询、键和值

在注意力机制中,“键”(key)、“查询”(query) 和“值”(value) 这些术语借用自信息检 索和数据库领域,这些领域使用类似的概念来进行信息存储、搜索和检索。 查询类似于数据库中的搜索查询。它代表了模型当前关注或试图理解的项 (比如句子中的 一个单词或词元)。查询用于探测输入序列中的其他部分,以确定对它们的关注程度。 键类似于用于数据库索引和搜索的键。在注意力机制中,输入序列中的每个项 (比如句 子中的每个单词) 都有一个对应的键。这些键用于与查询进行匹配。 在这种背景下,值类似于数据库中键 - 值对中的值。它表示输入项的实际内容或表示。一 旦模型确定哪些键以及哪些输入部分与查询 (当前关注的项) 最相关,它就会检索相应的值。

  • 查询(query):类似搜索关键词,代表当前关注的项
  • 键(key):类似索引,用于与查询匹配
  • 值(value):类似实际内容,根据匹配结果被检索和加权

扩展总结

到目前为止,我们已经完成了多个步骤来计算自注意力的输出。这些步骤主要是为了演示清晰,以便逐步了解每个环节。在实际操作中,为了实现第 4 章中的大语言模型,最好将这些代码组织成一个 Python 类,如代码清单 3-1 所示。

代码清单 3-1 一个简化的自注意力类

import torch.nn as nn
 
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
 
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]** 0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

在这段 PyTorch 代码中,SelfAttention_v1 是一个从 nn.Module 派生出来的类。nn.Module 是 PyTorch 模型的一个基本构建块,它为模型层的创建和管理提供了必要的功能。

  • __init__ 方法初始化了可训练的权重矩阵(W_queryW_keyW_value),这些矩阵用于查询向量、键向量和值向量,每个矩阵将输入维度 d_{\text{in}} 转换为输出维度 d_{\text{out}}
  • 在前向传播过程中,我们通过使用 forward 方法将查询向量和键向量相乘来计算注意力分数(attn_scores),然后使用 softmax 对这些分数进行归一化。最后,我们通过使用这些归一化的注意力分数对值向量进行加权来创建上下文向量。

可以通过以下方式来使用这个类:

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

由于输入包含 6 个嵌入向量,因此我们会得到一个用于保存这 6 个上下文向量的矩阵:

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

通过快速检查,可以看到第 2 行([0.3061, 0.8210])与 3.4.1 节中的 context_vec_2 内容相符。

图 3-18 总结了我们刚刚实现的自注意力机制。

图 3-18 在自注意力机制中,我们用 3 个权重矩阵 ( W_q )、( W_k ) 和 ( W_v ) 来变换输入矩阵 ( X ) 中的输入向量。新方法根据所得查询矩阵(( Q ))和键矩阵(( K ))计算注意力权重矩阵。然后,使用注意力权重矩阵和值矩阵(( V ))计算上下文向量(( Z ))。为了视觉清晰,我们关注具有 ( n ) 个词元的单个输入文本,而不是一批多个输入。因此,在这种情况下,三维输入张量被简化为二维矩阵。这种方法允许更直观地可视化和理解所涉及的过程。为了与后面的图保持一致,注意力矩阵中的值不代表真正的注意力权重(该图中的数值被截断为小数点后两位,以减少视觉混乱。每行中的值加起来应为 1.0 或 100%)。

自注意力机制包含了可训练的权重矩阵 ( W_q )、( W_k ) 和 ( W_v )。这些矩阵将输入数据转换为查询向量、键向量和值向量,这些组件在注意力机制中至关重要。随着模型在训练中接触更多数据,它会调整这些可训练的权重,后续章节会对此进行介绍。

代码清单 3-2 一个使用 PyTorch 线性层的自注意力类

可以通过使用 PyTorch 的 nn.Linear 层来进一步优化 SelfAttention_v1 的实现,当偏置单元被禁用时,nn.Linear 层可以有效地执行矩阵乘法。相比手动实现 nn.Parameter(torch.rand(...)),使用 nn.Linear 的一个重要优势是它提供了优化的权重初始化方案,从而有助于模型训练的稳定性和有效性,如代码清单 3-2 所示。

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
 
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
 
        context_vec = attn_weights @ values
        return context_vec

可以像使用 SelfAttention_v1 一样使用 SelfAttention_v2

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

输出结果如下所示:

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

请注意,SelfAttention_v1SelfAttention_v2 因为使用了不同的初始权重矩阵而给出了不同的输出,这是由 nn.Linear 使用了一个更复杂的权值初始化方案所导致的。

接下来,我们将改进自注意力机制,重点是在机制中引入因果机制和多头机制。因果机制的作用是调整注意力机制,防止模型访问序列中未来的信息,这在语言建模等任务中尤为重要,因为每个词的预测只能依赖之前出现的词。

多头机制涉及将注意力机制分成多个“头”。每个头会学习数据的不同特征,使模型能够在不同的位置同时关注来自不同表示子空间的信息。这能够提升模型在复杂任务中的性能。

4. 利用因果注意力隐藏未来词汇

对于许多大语言模型任务,你希望自注意力机制在预测序列中的下一个词元时仅考虑当前位 置之前的词元。因果注意力 (也称为掩码注意力) 是一种特殊的自注意力形式。它限制模型在处 理任何给定词元时,只能基于序列中的先前和当前输入来计算注意力分数,而标准的自注意力机 制可以一次性访问整个输入序列。 现在,我们将通过修改标准自注意力机制来创建因果注意力机制,这是在后续章节中开发 大语言模型的关键步骤。要在类 GPT 模型中实现这一点,对于每个处理的词元,需要掩码当前 词元之后的后续词元,如图 3-19 所示。我们会掩码对角线以上的注意力权重,并归一化未掩码 的注意力权重,使得每一行的权重之和为 1。稍后,我们将通过代码来实现这一掩码和归一化 过程。

在因果注意力机制中,我们掩码了对角线以上的注意力权重,以确保在计算上下文向 量时,大语言模型无法访问未来的词元。例如,对于第 2 行的单词“journey”,仅保留 当前词 (“journey”) 和之前词 (“Your”) 的注意力权重

因果注意力的掩码实现

在大语言模型的自回归生成任务(如预测下一个词元)中,需确保模型仅基于“当前及之前的词元”计算注意力,不能访问“未来词元”(否则会破坏生成的逻辑性与公平性)。因果注意力(掩码注意力)就是为实现这一约束而设计的自注意力变体。

  1. 计算标准注意力权重: 先通过 Softmax 得到无约束的注意力权重矩阵(可复用 SelfAttention_v2 的查询、键矩阵):

    queries = sa_v2.W_query(inputs)
    keys = sa_v2.W_key(inputs)
    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

    输出为 6×6 的注意力权重矩阵(每行对应一个词元的权重分布)。

  2. 生成下三角掩码: 用 torch.tril 生成“对角线及以下为 1,对角线以上为 0”的掩码,屏蔽未来词元的权重:

    context_length = attn_scores.shape[0]
    mask_simple = torch.tril(torch.ones(context_length, context_length))
    # 输出示例:
    # tensor([[1., 0., 0., 0., 0., 0.],
    #         [1., 1., 0., 0., 0., 0.],
    #         [1., 1., 1., 0., 0., 0.],
    #         [1., 1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 1., 0.],
    #         [1., 1., 1., 1., 1., 1.]])
  3. 掩码并重新归一化: 先将掩码与注意力权重逐元素相乘(对角线以上权重置 0),再对每行重新归一化(保证每行和为 1):

    masked_simple = attn_weights * mask_simple
    row_sums = masked_simple.sum(dim=-1, keepdim=True)
    masked_simple_norm = masked_simple / row_sums

    输出为“对角线以上权重为 0,每行和为 1”的因果注意力权重(如第 2 行仅保留“当前词元”和“前一个词元”的权重)。

信息泄露 当我们应用掩码并重新归一化注意力权重时,初看起来,未来的词元 (打算掩码的) 可能 仍然会影响当前的词元,因为它们的值会参与 softmax 计算。然而,关键的见解是,在掩码后 重新归一化时,我们实际上是在对一个较小的子集重新计算 softmax(因为被掩码的位置不参 与 softmax 计算)。 softmax 函数在数学上的优雅之处在于,尽管最初所有位置都在分母中,但掩码和重新归 一化之后,被掩码的位置的效果被消除——它们不会以任何实际的方式影响 softmax 分数。 简而言之,掩码和重新归一化之后,注意力权重的分布就像最初仅在未掩码的位置计算 一样。这保证了不会有来自未来或其他被掩码的词元的信息泄露。

预掩码(负无穷填充)+ Softmax(更高效,推荐)

尽管可以在此时完成对因果注意力的实现,但我们仍然可以进行改进。让我们利用 softmax 函数的数学特性,以更少的步骤更高效地计算掩码后的注意力权重,如图 3-21 所示。

利用 Softmax负无穷(-inf)输入输出 0的特性,直接在 Softmax 前屏蔽未来词元的注意力分数,一步完成掩码与归一化:

  1. 生成上三角掩码(填充 -∞): 用 torch.triu 生成“对角线以上为 1,其余为 0”的掩码,再将这些位置的注意力分数替换为 -inf

    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked_attn_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)

    输出中,未来词元对应的注意力分数被置为 -inf(如第 1 行,仅第 1 列保留有效分数,其余列为 -inf)。

  2. 直接 Softmax 得到因果权重: 对预掩码的注意力分数应用 Softmax,自动完成归一化且屏蔽未来词元:

    causal_attn_weights = torch.softmax(masked_attn_scores / keys.shape[-1]**0.5, dim=1)

    输出与“方法 1 最终结果”一致,但计算更高效(无需后处理归一化)。

利用 Dropout 掩码额外的注意力权重

dropout 是深度学习中的一种技术,通过在训练过程中随机忽略一些隐藏层单元来有效地“丢 弃”它们。这种方法有助于减少模型对特定隐藏层单元的依赖,从而避免过拟合。需要强调的是, dropout 仅在训练期间使用,训练结束后会被取消。

在 Transformer 架构中,一些包括 GPT 在内的模型通常会在两个特定时间点使用注意力机制 中的 dropout: 一是计算注意力权重之后,二是将这些权重应用于值向量之后。如图 3-22 所示, 我们将在计算注意力权重之后应用 dropout 掩码,因为这是实践中更常见的做法。

下面的代码示例中使用了 50% 的 dropout 率,这意味着掩码一半的注意力权重。(当我们在接下来的章节中训练 GPT 模型时,将使用较低的 dropout 率,比如 10% 或 20%。)为了便于操作,我们首先将 PyTorch 的 dropout 实现应用于一个由 1 组成的 6×6 张量:

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

如你所见,大约有一半的值被置 0 了:

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

在这里创建一个全 1 矩阵。在对注意力权重矩阵应用 50% 的 dropout 率时,矩阵中有一半的元素会随机被置为 0。为了补偿减少的活跃元素,矩阵中剩余元素的值会按 ( 1/0.5 = 2 ) 的比例进行放大。这种放大对于维持注意力权重的整体平衡非常重要,可以确保在训练和推理过程中,注意力机制的平均影响保持一致。

现在,对注意力权重矩阵进行 dropout 操作:

torch.manual_seed(123)
print(dropout(attn_weights))

在处理后的注意力权重矩阵中,部分元素已被置为 0,其余元素则被重新缩放:

tensor([[2.0000, 0.0000, 0 .0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]], grad_fn=<MulBackward0>

请注意,由于操作系统的差异,最终的 dropout 输出可能会有所不同。有关这种不一致性的详细信息,请查看 PyTorch 的 issue tracker。

理解了因果注意力和 dropout 掩码之后,现在我们可以开发一个简洁的 Python 类。这个类的目的是高效地应用这两种技术。

实现一个简化的因果注意力类

为支持数据加载器产生的批量样本,需确保模型能处理三维张量输入(批次维度 + 序列长度 + 嵌入维度)。可通过复制单样本输入模拟批量数据:

# 模拟批量输入(2个样本,每个含6个词元,每个词元为3维嵌入)
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)  # 输出:torch.Size([2, 6, 3])

代码清单 3-3 展示了整合因果注意力和 Dropout 的完整类,支持批量处理,是后续多头注意力的基础:

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Q、K、V的线性投影层(无偏置可选)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout层(防止过拟合)
        self.dropout = nn.Dropout(dropout)
        # 注册因果掩码为缓冲区(自动随模型移动设备)
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
 
    def forward(self, x):
        b, num_tokens, d_in = x.shape  # b:批次大小,num_tokens:序列长度
        
        # 计算Q、K、V(保留批次维度)
        keys = self.W_key(x)       # 形状:[b, num_tokens, d_out]
        queries = self.W_query(x)  # 形状:[b, num_tokens, d_out]
        values = self.W_value(x)   # 形状:[b, num_tokens, d_out]
        
        # 计算注意力分数(转置K的最后两维,保持批次维度)
        attn_scores = queries @ keys.transpose(1, 2)  # 形状:[b, num_tokens, num_tokens]
        
        # 应用因果掩码(仅保留当前及之前词元的分数)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],  # 截取与序列长度匹配的掩码
            -torch.inf
        )
        
        # 计算注意力权重(缩放+Softmax)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,  # 除以√d_out,防止分数过大
            dim=-1
        )
        
        # 应用Dropout
        attn_weights = self.dropout(attn_weights)
        
        # 计算上下文向量
        context_vec = attn_weights @ values  # 形状:[b, num_tokens, d_out]
        return context_vec

虽然此时所有新增的代码行都应该是熟悉的,但我们在 __init__ 方法中增加了一个 self.register_buffer() 调用。虽然在 PyTorch 中使用 register_buffer 并非所有情况下都是必需的,但在这里具有一些优势。例如,当我们在大语言模型中使用 CausalAttention 类时,缓冲区会与模型一起自动移动到适当的设备 (CPU 或 GPU),这在训练大语言模型时非常重要。这意味着我们无须手动确保这些张量与模型参数在同一设备上,从而避免了设备不匹配的错误。

可以按照之前使用 SelfAttention 类的方式来使用 CausalAttention 类:

torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

最终生成的上下文向量是一个三维张量,其中每个词元现在用二维嵌入来表示。

context_vecs.shape: torch.Size([2, 6, 2])
  • register_buffer 调用也是一个新版本 (下文提供了更多信息)
  • 与之前的 SelfAttention_v1 类相比,添加了一个 dropout 层
  • 将维度 1 和 2 转置,将批维度保持在第一个位置 (0)
  • 在 PyTorch 中,带有尾随下划线的操作将就地执行,从而避免了不必要的内存副本

到目前为止,我们完成了以下步骤: 从一个简化的注意力机制开始,添加了可训练的权重,然后引入了因果注意力掩码。接下来,我们将扩展因果注意力机制并编写多头注意力模块,以在我们的大语言模型中使用

5. 将单头注意力扩展到多头注意力

在本节中,我们将进行最后一步操作,即把先前实现的因果注意力类扩展到多个头上。这也被称为多头注意力。

” 多头 ” 这一术语指的是将注意力机制分成多个 ” 头 “,每个 ” 头 ” 独立工作。在这种情况下,单个因果注意力模块可以被看作单头注意力,因为它只有一组注意力权重按顺序处理输入。

我们将从因果注意力扩展到多头注意力。首先,我们将直观地通过堆叠多个 CausalAttention 模块来构建多头注意力模块。然后,我们将用一种更复杂但计算上更高效的方式来实现这个多头注意力模块。

叠加多个单头注意力层

在实际操作中,实现多头注意力需要构建多个自注意力机制的实例 (参见图 3-18),每个实例都有其独立的权重,然后将这些输出进行合成。虽然这种方法的计算量可能会非常大,但它对诸如基于 Transformer 的大语言模型之类的模型的复杂模式识别是非常重要的。

图 3-24 展示了多头注意力模块的结构,它是由图 3-18 所示的多个单头注意力模块依次叠加在一起组成的。

图 3-24 多头注意力模块包含两个堆叠在一起的单头注意力模块。因此,我们不是使用一个单一的矩阵 来计算值矩阵,而是在一个有两个头的多头注意模块中,现在有两个值权重矩阵: 。这同样适用于其他的权重矩阵,比如 。我们得到了两组上下文向量 ,最终可以将它们合并成一个单一的上下文向量矩阵

正如前面提到的,多头注意力的主要思想是多次 (并行) 运行注意力机制,每次使用学到的不同的线性投影——这些投影是通过将输入数据 (比如注意力机制中的查询向量、键向量和值向 量) 乘以权重矩阵得到的。在代码中,可以通过实现一个简单的 MultiHeadAttentionWrapper 类来达到这一目标,MultiHeadAttentionWrapper 类堆叠了多个之前实现的 CausalAttention 模块实例,如代码清单 3-4 所示。

# 代码清单 3-4 一个实现多头注意力的封装类
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(
                d_in, d_out, context_length, dropout, qkv_bias
            )
             for _ in range(num_heads)]
        )
 
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

如果采用这个具有两个注意力头 (num_heads=2) 以及 CausalAttention 输出维度为 d_out=2 的 MultiHeadAttentionWrapper 类,那么我们就会得到一个四维的上下文向量 (d_out*num_heads=4),如图 3-25 所示。

使用 MultiHeadAttentionWrapper,我们指定了注意力头的数量 (num_heads)。如 果设置 num_heads=2,那么我们就会得到一个具有两组上下文向量矩阵的张量。在每 个上下文向量矩阵中,行表示对应于词元的上下文向量,列则对应于通过 d_out=4 指 定的嵌入维度。我们沿着列维度连接这些上下文向量矩阵。由于我们有两个注意力头 并且嵌入维度为 2,因此最终的嵌入维度是 2 × 2 = 4

可以像使用 CausalAttention 类一样使用 MultiHeadAttentionWrapper 类:

torch.manual_seed(123)
context_length = batch.shape[1]  # 词元数量
d_in, d_out = 3, 2
 
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
 
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

运行上述代码会产生以下上下文向量张量:

tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],
         [-0.5874, 0.0058, 0.5891, 0.3257],
         [-0.6300, -0.0632, 0.6202, 0.3860],
         [-0.5675, -0.0843, 0.5478, 0.3589],
         [-0.5526, -0.0981, 0.5321, 0.3428],
         [-0.5299, -0.1081, 0.5077, 0.3493]],

        [[-0.4519, 0.2216, 0.4772, 0.1063],
         [-0.5874, 0.0058, 0.5891, 0.3257],
         [-0.6300, -0.0632, 0.6202, 0.3860],
         [-0.5675, -0.0843, 0.5478, 0.3589],
         [-0.5526, -0.0981, 0.5321, 0.3428],
         [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)

context_vecs.shape: torch.Size([2, 6, 4])

输出维度解析

  • 第一维 (2):批次大小,即输入文本数量
  • 第二维 (6):每个输入文本中的词元数量
  • 第三维 (4):每个词元的最终嵌入维度

维度计算:最终嵌入维度 = 单个头输出维度 × 头的数量 = 2 × 2 = 4

当前实现中,多个注意力头是通过列表推导式 [head(x) for head in self.heads] 依次处理的。我们可以通过并行处理所有注意力头来改进实现,通常是通过一次矩阵乘法同时计算所有头的输出,这将在后续章节中详细介绍。

通过权重划分实现多头注意力

到目前为止,我们已经创建了一个 MultiHeadAttentionWrapper,通过叠加多个单头注意力模块来实现多头注意力。这是通过实例化并组合多个 CausalAttention 对象来完成的。

与其维护两个单独的类 MultiHeadAttentionWrapper 和 CausalAttention,不如将这两个概念合并成一个 MultiHeadAttention 类。此外,除了将 MultiHeadAttentionWrapper 与 CausalAttention 代码合并,我们还会进行一些其他调整,以更高效地实现多头注意力。

在 MultiHeadAttentionWrapper 中,多头机制通过创建 CausalAttention 对象的列表 (self.heads) 来实现,每个对象代表一个独立的注意力头。CausalAttention 类单独执行注意力机制,每个头的结果会被拼接。相比之下,下面的 MultiHeadAttention 类会将多头功能整合到一个类内。它通过重新调整投影后的查询张量、键张量和值张量的形状,将输入分为多个头,然后在计算注意力后合并这些头的结果。

在进一步讨论之前,先来看一下 MultiHeadAttention 类,如代码清单 3-5 所示。

import torch
import torch.nn as nn
from torch import Tensor
 
class MultiHeadAttention(nn.Module):
    """
    高效的多头注意力机制实现类,整合了因果掩码和Dropout功能
    
    该类通过调整张量形状实现多头并行计算,而非实例化多个单头注意力模块,
    显著提升计算效率,是Transformer架构的核心组件。
    """
    def __init__(
        self, 
        d_in: int,          # 输入特征维度
        d_out: int,         # 输出特征维度
        context_length: int, # 序列上下文长度(用于生成掩码)
        dropout: float,     # Dropout概率
        num_heads: int,     # 注意力头数量
        qkv_bias: bool = False  # Q、K、V线性层是否使用偏置
    ):
        super().__init__()
        # 确保输出维度可被头数量整除(每个头的维度需一致)
        assert d_out % num_heads == 0, "输出维度d_out必须能被注意力头数量num_heads整除"
        
        self.d_out = d_out               # 输出特征维度
        self.num_heads = num_heads       # 注意力头数量
        self.head_dim = d_out // num_heads  # 每个注意力头的维度
        
        # Q、K、V的线性投影层(将输入映射到输出维度空间)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # 多头输出的最终线性投影层(组合所有头的输出)
        self.out_proj = nn.Linear(d_out, d_out)
        
        # Dropout层(用于注意力权重,防止过拟合)
        self.dropout = nn.Dropout(dropout)
        
        # 注册因果掩码缓冲区(对角线以上为1,用于屏蔽未来信息)
        # 形状: [context_length, context_length]
        self.register_buffer(
            "mask", 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
 
    def forward(self, x: Tensor) -> Tensor:
        """
        前向传播计算
        
        参数:
            x: 输入张量,形状为 [batch_size, num_tokens, d_in]
            
        返回:
            context_vec: 输出上下文向量,形状为 [batch_size, num_tokens, d_out]
        """
        # 获取输入张量的维度信息
        batch_size, num_tokens, d_in = x.shape  # b: 批次大小,num_tokens: 序列长度
        
        # 1. 计算Q、K、V向量(通过线性投影)
        # 形状均为: [batch_size, num_tokens, d_out]
        keys: Tensor = self.W_key(x)     # 键向量
        queries: Tensor = self.W_query(x)  # 查询向量
        values: Tensor = self.W_value(x)   # 值向量
        
        # 2. 调整张量形状,分离出多个注意力头
        # 从 [b, num_tokens, d_out] 重塑为 [b, num_tokens, num_heads, head_dim]
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        
        # 3. 转置操作,将注意力头维度提前,便于并行计算
        # 形状变为: [b, num_heads, num_tokens, head_dim]
        keys = keys.transpose(1, 2)    # 交换序列长度和头数量维度
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        # 4. 计算注意力分数(点积)
        # keys.transpose(2, 3) 形状: [b, num_heads, head_dim, num_tokens]
        # 结果形状: [b, num_heads, num_tokens, num_tokens]
        attn_scores: Tensor = queries @ keys.transpose(2, 3)
        
        # 5. 应用因果掩码(屏蔽未来词元的分数)
        # 截取与当前序列长度匹配的掩码部分
        mask_bool: Tensor = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)  # 未来位置填充负无穷
        
        # 6. 计算注意力权重(缩放+Softmax+Dropout)
        attn_weights: Tensor = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,  # 除以√head_dim,防止梯度消失
            dim=-1  # 在最后一个维度(序列长度)上归一化
        )
        attn_weights = self.dropout(attn_weights)  # 应用Dropout
        
        # 7. 计算上下文向量(注意力权重 × 值向量)
        # 结果形状: [b, num_heads, num_tokens, head_dim]
        context_vec: Tensor = attn_weights @ values
        
        # 8. 合并所有注意力头的输出
        # 转置回 [b, num_tokens, num_heads, head_dim]
        context_vec = context_vec.transpose(1, 2)
        # 重塑为 [b, num_tokens, d_out](拼接所有头的输出)
        context_vec = context_vec.contiguous().view(
            batch_size, num_tokens, self.d_out
        )
        
        # 9. 最终线性投影(可选,进一步调整输出特征)
        context_vec = self.out_proj(context_vec)
        
        return context_vec

图 3-26 在具有两个注意力头的 MultiHeadAttentionWrapper 类中,我们初始化了两个权重 矩阵 ,并计算了两个查询矩阵 (上)。在 MultiHeadAttention 类 中,我们初始化了一个更大的权重矩阵 ,并只与输入矩阵进行一次矩阵乘法操作, 得到一个查询矩阵 ,然后将查询矩阵分割成了 (下)。对键矩阵和值矩阵的 操作与之类似,为了减少视觉混乱,这里没有展示

输入首先经过线性层进行变换 (针对查询矩阵、 键矩阵和值矩阵),然后被重塑为多个头。关键操作是将 d_out 维度分割为 num_headshead_dim,其中 head_dim = d_out / num_heads

此外,我们在 MultiHeadAttention 中添加了一个输出投影层 (self.out_proj),这是 在合并多个头之后的步骤,而 CausalAttention 类中并不存在这个层。这个输出投影层并不是 必需的 (更多细节参见附录 B),但它在许多大语言模型架构中被广泛使用,这就是我们在这里 添加它以保持完整性的原因。

尽管 MultiHeadAttention 类因额外的张量重塑和转置显得比 MultiHeadAttentionWrapper 更复杂,但它的效率更高。原因是我们只需进行一次矩阵乘法来计算键矩阵,例如, keys = self.W_key(x)(查询矩阵和值矩阵也是如此)。在 MultiHeadAttentionWrapper 中,我们需要对每个注意力头重复进行这种矩阵乘法,而矩阵乘法是计算资源消耗较大的操作之一。

MultiHeadAttention 类的用法与我们之前实现的 SelfAttention 类和 CausalAttention 类类似:

torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

结果显示,d_out 参数直接影响输出维度:

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])

现在,我们已经实现了将在实现和训练大语言模型时使用的 MultiHeadAttention 类。需要注意的是,尽管代码功能齐全,但为了保持输出的可读性,我们使用了相对较小的嵌入维度和注意力头数量。

相比之下,最小的 GPT-2 模型 (参数量为 1.17 亿) 有 12 个注意力头,上下文向量嵌入维度为 768,而最大的 GPT-2 模型 (参数量为 15 亿) 有 25 个注意力头,上下文向量嵌入维度为 1600。请注意,在 GPT 模型中,词元输入和上下文嵌入的嵌入维度是相同的 (d_in = d_out)。

小结

  • 注意力机制可以将输入元素转换为增强的上下文向量表示,这些表示涵盖了关于所有输 入的信息。
  • 自注意力机制通过对输入进行加权求和来计算上下文向量表示。
  • 在简化的注意力机制中,注意力权重是通过点积计算得出的。
  • 点积是两个向量的元素逐个相乘并将这些乘积相加的一种简洁计算方法。
  • 尽管矩阵乘法不是必需的,但它可以通过替代嵌套的 for 循环使计算更高效、更紧凑。
  • 在用于大语言模型的自注意力机制 (也被称为“缩放点积注意力”) 中,我们引入了可训 练的权重矩阵来计算输入的中间变换: 查询矩阵、值矩阵和键矩阵。
  • 在处理从左到右读取和生成文本的大语言模型时,我们会添加一个因果注意力掩码,以 防止模型访问未来的词元。
  • 除了使用因果注意力掩码将注意力权重置 0,还可以添加 dropout 掩码来减少大语言模型 中的过拟合。
  • 基于 Transformer 的大语言模型中的注意力模块涉及多个因果注意力实例,这被称为“多 头注意力”。
  • 可以通过堆叠多个因果注意力模块实例来创建多头注意力模块。
  • 创建多头注意力模块的一种更高效的方法是使用批量矩阵乘法。