多头自注意力(Multi-Head Self-Attention):
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Multi-head self-attention applied to a single input sequence.

    The input is projected to queries, keys and values, split into
    ``heads`` equal slices of width ``head_dim = embed_size // heads``,
    attended per head with scaled dot-product attention, then merged and
    passed through a final linear layer.

    Args:
        embed_size: Width of the last input dimension.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        # Linear projections producing V, K and Q from the input.
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        # Output projection applied after the heads are merged.
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(N, seq_len, embed_size)``.

        Returns:
            Tensor of shape ``(N, seq_len, embed_size)``.
        """
        N = x.shape[0]
        seq_len = x.shape[1]

        # Project the input; each result is (N, seq_len, embed_size).
        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        # Split the embedding into heads: (N, seq_len, heads, head_dim).
        values = values.reshape(N, seq_len, self.heads, self.head_dim)
        keys = keys.reshape(N, seq_len, self.heads, self.head_dim)
        queries = queries.reshape(N, seq_len, self.heads, self.head_dim)

        # Per-head dot product of every query with every key.
        # energy: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scaled dot-product attention. Each dot product sums over head_dim
        # features, so the scale is sqrt(head_dim), not sqrt(embed_size).
        # The softmax normalises over the key axis.
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of the values: (N, seq_len, heads, head_dim).
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)

        # Merge the heads back into one embedding, then project.
        out = out.reshape(N, seq_len, self.embed_size)
        out = self.fc_out(out)
        return out
https://leetcode.cn/problems/add-two-numbers
给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。
请你将两个数相加,并以相同形式返回一个表示和的链表。
你可以假设除了数字 0 之外,这两个数都不会以 0 开头。
输入:l1 = [2,4,3], l2 = [5,6,4]
输出:[7,0,8]
解释:342 + 465 = 807.
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
# NOTE(review): indentation was lost in this paste, and part of the method
# body is elided at the "…" marker below. `head` and `v_dig` are not defined
# anywhere in the visible code; presumably the elided part sets up a cursor on
# the dummy node and computes each digit plus a carry — TODO confirm against
# the original solution.
def addTwoNumbers(
self, l1: Optional[ListNode], l2: Optional[ListNode]
) -> Optional[ListNode]:
dammy = ListNode() # Dummy node; the node after it is what is finally returned
… head.next = ListNode(
v_dig
) # Round one: e.g. 7+6=13, so this digit is 3. Round two: 2+5+1=8, so the list becomes dammy>>3>>8, starting from the ones place
head = head.next
# Advance each input list only while it still has nodes left.
if l1:
l1 = l1.next
if l2:
l2 = l2.next
# Skip the dummy node and return the node that follows it.
return dammy.next