多头:
pyimport torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"
# 线性变换得到 Q, K, V
self.values = nn.Linear(embed_size, embed_size)
self.keys = nn.Linear(embed_size, embed_size)
self.queries = nn.Linear(embed_size, embed_size)
# 输出线性层
self.fc_out = nn.Linear(embed_size, embed_size)
def forward(self, x):
# x shape: (N, seq_len, embed_size)
N = x.shape[0]
seq_len = x.shape[1]
# 线性变换得到 Q, K, V
values = self.values(x) # (N, seq_len, embed_size)
keys = self.keys(x) # (N, seq_len, embed_size)
queries = self.queries(x) # (N, seq_len, embed_size)
# 分割多头
values = values.reshape(N, seq_len, self.heads, self.head_dim)
keys = keys.reshape(N, seq_len, self.heads, self.head_dim)
queries = queries.reshape(N, seq_len, self.heads, self.head_dim)
# 计算注意力分数
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# queries shape: (N, seq_len, heads, head_dim)
# keys shape: (N, seq_len, heads, head_dim)
# energy shape: (N, heads, seq_len, seq_len)
# 缩放点积注意力
attention = F.softmax(energy / (self.embed_size ** (1/2)), dim=3)
# 应用注意力到values上
out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
# attention shape: (N, heads, seq_len, seq_len)
# values shape: (N, seq_len, heads, head_dim)
# out shape: (N, seq_len, heads, head_dim)
# 合并多头
out = out.reshape(N, seq_len, self.embed_size)
# 输出线性变换
out = self.fc_out(out)
return out
简化为单头:
pyclass SimpleSelfAttention(nn.Module):
def __init__(self, embed_size):
super().__init__()
self.q = nn.Linear(embed_size, embed_size)
self.k = nn.Linear(embed_size, embed_size)
self.v = nn.Linear(embed_size, embed_size)
def forward(self, x):
# x shape: (N, seq_len, embed_size)
Q = self.q(x)
K = self.k(x)
V = self.v(x)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1) ** 0.5)
attention = F.softmax(scores, dim=-1)
# 应用注意力
out = torch.matmul(attention, V)
return out
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!