直接偏好优化(Direct Preference Optimization, DPO)是一种用于语言模型对齐的算法,由Rafailov等人在2023年提出,作为强化学习人类反馈(RLHF)的替代方案。DPO的目标与RLHF相同:使语言模型的输出更好地符合人类偏好,但DPO通过简化流程,直接从人类偏好数据中优化模型,无需单独的奖励模型和复杂的强化学习过程。
为什么需要DPO?
核心区别:
适用场景:
DPO的理论基础来自于Bradley-Terry模型,该模型用于描述对两个选项的偏好概率。
当我们有两个回答y₁和y₂时,人类偏好y₁而非y₂的概率可以建模为:
p*(y₁ ≻ y₂ | x) = σ(r*(x, y₁) - r*(x, y₂))
其中:
DPO的关键洞见是,我们可以将奖励函数r表示为最优策略π和参考策略πref的函数:
r*(x, y) = β log(π*(y|x)/πref(y|x)) + β log Z(x)
其中:
将这个公式代入Bradley-Terry模型并化简,我们得到:
p*(y₁ ≻ y₂ | x) = σ(β log(π*(y₁|x)/πref(y₁|x)) - β log(π*(y₂|x)/πref(y₂|x)))
这个公式中的Z(x)项被消去了,使计算变得可行。
基于上述模型,DPO的损失函数为:
L_DPO(πθ; πref) = -E_(yw,yl,x)~D[log(σ(β(log(πθ(yw|x)/πθ(yl|x)) - log(πref(yw|x)/πref(yl|x)))))]
其中:
补充解释:
以下是DPO损失函数的PyTorch实现:
pythonimport torch
import torch.nn.functional as F
def dpo_loss(
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
"""
计算DPO损失
参数:
policy_chosen_logps: 策略模型对偏好回答的对数概率
policy_rejected_logps: 策略模型对非偏好回答的对数概率
reference_chosen_logps: 参考模型对偏好回答的对数概率
reference_rejected_logps: 参考模型对非偏好回答的对数概率
beta: 正则化参数,控制KL散度的强度
返回:
dpo_loss: DPO损失值
"""
# 计算策略模型和参考模型之间的对数概率比率
policy_chosen_logps_ratio = policy_chosen_logps - reference_chosen_logps
policy_rejected_logps_ratio = policy_rejected_logps - reference_rejected_logps
# 计算DPO损失
logits = beta * (policy_chosen_logps_ratio - policy_rejected_logps_ratio)
losses = -F.logsigmoid(logits)
return losses.mean()
# 在训练过程中的使用
def train_step(
policy_model,
reference_model,
batch,
optimizer,
beta=0.1
):
# 冻结参考模型
for param in reference_model.parameters():
param.requires_grad = False
# 获取输入和输出
prompts = batch["prompts"]
chosen_responses = batch["chosen_responses"]
rejected_responses = batch["rejected_responses"]
# 计算策略模型的对数概率
policy_chosen_logps = compute_logprobs(policy_model, prompts, chosen_responses)
policy_rejected_logps = compute_logprobs(policy_model, prompts, rejected_responses)
# 计算参考模型的对数概率
with torch.no_grad():
reference_chosen_logps = compute_logprobs(reference_model, prompts, chosen_responses)
reference_rejected_logps = compute_logprobs(reference_model, prompts, rejected_responses)
# 计算DPO损失
loss = dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
beta
)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def compute_logprobs(model, prompts, responses):
"""计算给定提示和回复的对数概率"""
# 这里的具体实现取决于模型架构
# 对于自回归语言模型,通常需要计算每个token的对数概率并求和
# ...
return logprobs
代码解读:
compute_logprobs
函数的具体实现需要根据模型架构调整,例如对于GPT类模型,可以通过softmax和交叉熵计算每个token的对数概率。DPO算法已被用于多个开源语言模型的训练中,包括:
实际案例:
DPO作为RLHF的替代方案,通过直接从人类偏好数据中学习,避免了奖励建模和强化学习的复杂性。它的主要优势包括:
未来展望:
DPO的出现标志着语言模型对齐技术的重要发展,为更高效地训练符合人类偏好的语言模型提供了新的方向。
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!