2025-05-15
Algorithm Practice

Multi-head self-attention:

```py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        # Linear projections for Q, K, V
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        # Output projection
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x shape: (N, seq_len, embed_size)
        N = x.shape[0]
        seq_len = x.shape[1]

        # Linear projections for Q, K, V
        values = self.values(x)    # (N, seq_len, embed_size)
        keys = self.keys(x)        # (N, seq_len, embed_size)
        queries = self.queries(x)  # (N, seq_len, embed_size)

        # Split into heads
        values = values.reshape(N, seq_len, self.heads, self.head_dim)
        keys = keys.reshape(N, seq_len, self.heads, self.head_dim)
        queries = queries.reshape(N, seq_len, self.heads, self.head_dim)

        # Attention scores
        # queries shape: (N, seq_len, heads, head_dim)
        # keys shape:    (N, seq_len, heads, head_dim)
        # energy shape:  (N, heads, seq_len, seq_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Scaled dot-product attention: scale by sqrt(head_dim), the per-head dimension
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Apply the attention weights to the values
        # attention shape: (N, heads, seq_len, seq_len)
        # values shape:    (N, seq_len, heads, head_dim)
        # out shape:       (N, seq_len, heads, head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])

        # Merge the heads back into one embedding dimension
        out = out.reshape(N, seq_len, self.embed_size)

        # Output projection
        out = self.fc_out(out)
        return out
```
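
A quick sanity check on shapes (the batch size, sequence length, `embed_size`, and `heads` values below are made up for illustration): the output has the same shape as the input.

```py
# Hypothetical sizes for a quick shape check
N, seq_len, embed_size, heads = 2, 10, 256, 8

attn = SelfAttention(embed_size, heads)
x = torch.randn(N, seq_len, embed_size)
out = attn(x)
print(out.shape)  # torch.Size([2, 10, 256]) -- same shape as the input
```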

Simplified to a single head:

```py
# Relies on torch, nn, and F imported in the block above
class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x shape: (N, seq_len, embed_size)
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)

        # Attention scores, scaled by sqrt(embed_size)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1) ** 0.5)
        attention = F.softmax(scores, dim=-1)

        # Apply the attention weights to the values
        out = torch.matmul(attention, V)
        return out
```
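
The same kind of check for the single-head version (again with made-up sizes):

```py
# Hypothetical sizes, just to verify the output shape
N, seq_len, embed_size = 2, 10, 256

attn = SimpleSelfAttention(embed_size)
x = torch.randn(N, seq_len, embed_size)
out = attn(x)
print(out.shape)  # torch.Size([2, 10, 256])
```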

Author: Dong

Link to this post:

Copyright notice: Unless otherwise stated, all posts on this blog are licensed under CC BY-NC (Creative Commons Attribution-NonCommercial 4.0 International License). You may freely repost and adapt them for non-commercial purposes, provided you credit the original author and link back to the source. Please cite the source when reposting!