2024-10-13
深度学习
00

目录

为什么在进行 Softmax 之前需要对 Attention 进行 Scaling?
Attention 机制概述
为什么要进行 Scaling?
1. 缓解点积值随维度增大的影响
2. 公式推导与解释
代码实现示例
总结

为什么在进行 Softmax 之前需要对 Attention 进行 Scaling?

在 Transformer 模型中,Attention 机制通过计算 Query、Key 和 Value 的相似性来生成注意力权重。在进行 Softmax 之前,我们会对 Attention 分数进行 Scaling(缩放),也就是除以 dk\sqrt{d_k}。那么,为什么需要这样做呢?本文将通过公式推导和代码示例,详细讲解这一操作的意义和作用。

Attention 机制概述

在 Transformer 中,Attention 机制的核心是计算 Query 和 Key 的点积,进而得出不同 Token 之间的相似性。以 Self-Attention 为例,假设输入是矩阵 XRn×dX \in \mathbb{R}^{n \times d},其中 nn 表示序列长度,dd 表示每个 Token 的维度。Attention 的计算过程如下:

  1. 计算 Query、Key 和 Value 矩阵

    • Query: Q=XWQQ = XW_Q
    • Key: K=XWKK = XW_K
    • Value: V=XWVV = XW_V

    其中,WQ,WK,WVRd×dkW_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} 是学习得到的参数矩阵。

  2. 计算 Attention 分数

    Attention(Q,K,V)=Softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中,QKTdk\frac{QK^T}{\sqrt{d_k}} 就是我们要探讨的 Scaling 操作。

为什么要进行 Scaling?

1. 缓解点积值随维度增大的影响

在点积注意力机制中,QQKK 的维度 dkd_k 较大时,QKTQK^T 的值会变大。考虑 QQKK 的每个元素在独立同分布下的期望为 0、方差为 1 的情况:

  • QKTQK^T 的每个元素是 dkd_k 个独立随机变量的和,因此期望为 0,方差为 dkd_k
  • 这意味着 QKTQK^T 的元素会随着 dkd_k 的增大而变得更大,容易导致 Softmax 计算中的梯度消失或过饱和现象,使模型难以训练。

通过将 QKTQK^T 除以 dk\sqrt{d_k},可以将其方差稳定在 1 左右,避免值过大对 Softmax 的影响,提升训练稳定性。

2. 公式推导与解释

假设 QQKK 的每个分量都是服从正态分布 N(0,1)N(0, 1) 的独立同分布随机变量。则 QKTQK^T 的每个元素是 dkd_k 个随机变量的和。由于独立随机变量的和的方差等于方差的和:

Var(QKT)=dk\text{Var}(QK^T) = d_k

对于 Softmax 函数 Softmax(xi)=exiexj\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum e^{x_j}},当 xix_i 值较大时,exie^{x_i} 会增长极快,导致数值过大,影响梯度计算。为了缓解这种情况,将 xix_i 缩小 dk\sqrt{d_k} 倍:

Scaled Attention(Q,K,V)=Softmax(QKTdk)V\text{Scaled Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

此时,QKT/dkQK^T / \sqrt{d_k} 的方差约为 1,减少了数值溢出和梯度消失的风险。

代码实现示例

在实际代码实现中,通常会在计算 Attention 权重时将其除以 dk\sqrt{d_k}

python
import torch import torch.nn.functional as F # 假设 Query, Key, Value 和 d_k Q = torch.randn(1, 8, 64) # (batch_size, seq_len, d_k) K = torch.randn(1, 8, 64) # (batch_size, seq_len, d_k) V = torch.randn(1, 8, 64) # (batch_size, seq_len, d_k) d_k = Q.size(-1) # 计算未缩放的 Attention 分数 scores = torch.bmm(Q, K.transpose(1, 2)) # 缩放 Attention 分数 scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) # 使用 Softmax 计算权重 attention_weights = F.softmax(scaled_scores, dim=-1) # 最终的 Attention 输出 output = torch.bmm(attention_weights, V) print("Attention输出:", output)

在这个代码中,scores 是未缩放的点积结果。通过除以 dk\sqrt{d_k} 进行缩放,再输入 Softmax 函数以计算最终的 Attention 权重。

总结

将 Attention 分数除以 dk\sqrt{d_k} 是为了稳定 Softmax 输入的数值范围,防止随着维度增大,点积值过大导致的数值不稳定性,从而提升模型训练效果。这个操作是 Transformer 结构中一个重要的细节,对 Attention 机制的数值稳定性起到了关键作用。

如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!