用一个具体例子详细说明 GSPO 的两个变体的计算过程,特别是 token 级别的重要性比率是如何得到的。
用一个具体的数值例子说明 GSPO 的两种方式的 token 级别重要性比率计算:
假设有一个 prompt,模型生成了一个响应:"机器学习很实用"
假设我们有以下的 log_probs 和 old_log_probs(都是 tensor,shape 为 (batch_size=1, seq_len=5)):
python展开代码# 响应包含5个token:["机器", "学习", "很", "实用", "<pad>"]
# 但实际只有前4个token有效
log_probs = torch.tensor([
[-2.0, -1.8, -1.5, -1.2, 0.0] # 当前策略的对数概率
])
old_log_probs = torch.tensor([
[-2.2, -2.0, -1.6, -1.3, 0.0] # 旧策略的对数概率
])
response_mask = torch.tensor([
[1, 1, 1, 1, 0] # 前4个token有效,最后一个padding
])
python展开代码negative_approx_kl = log_probs - old_log_probs
计算结果:
| Token位置 | log_probs | old_log_probs | negative_approx_kl |
|---|---|---|---|
| 0 ("机器") | -2.0 | -2.2 | -2.0 - (-2.2) = 0.2 |
| 1 ("学习") | -1.8 | -2.0 | -1.8 - (-2.0) = 0.2 |
| 2 ("很") | -1.5 | -1.6 | -1.5 - (-1.6) = 0.1 |
| 3 ("实用") | -1.2 | -1.3 | -1.2 - (-1.3) = 0.1 |
| 4 (padding) | 0.0 | 0.0 | 0.0 - 0.0 = 0.0 |
所以:
python展开代码negative_approx_kl = torch.tensor([
[0.2, 0.2, 0.1, 0.1, 0.0]
])
python展开代码negative_approx_kl_in_seq = VF.masked_mean(negative_approx_kl, response_mask, dim=-1)
VF.masked_mean 会:
dim=-1)上求平均计算过程: [ \text{negative_approx_kl_in_seq} = \frac{0.2 + 0.2 + 0.1 + 0.1}{4} = 0.15 ]
结果:
python展开代码negative_approx_kl_in_seq = torch.tensor([0.15]) # shape: (1,)
gspo(纯序列级别)python展开代码log_importance_ratio = negative_approx_kl_in_seq * response_mask
计算过程:
negative_approx_kl_in_seq 的值是 0.15(标量)response_mask 是 [1, 1, 1, 1, 0]0.15 * 1 = 0.15(对每个有效token)0.15 * 0 = 0.0(对padding)结果:
| Token位置 | negative_approx_kl_in_seq | response_mask | log_importance_ratio |
|---|---|---|---|
| 0 | 0.15 | 1 | 0.15 |
| 1 | 0.15 | 1 | 0.15 |
| 2 | 0.15 | 1 | 0.15 |
| 3 | 0.15 | 1 | 0.15 |
| 4 | 0.15 | 0 | 0.0 |
python展开代码log_importance_ratio = torch.tensor([
[0.15, 0.15, 0.15, 0.15, 0.0]
])
特点:所有有效 token 的 log_importance_ratio 相同(都是 0.15)。
gspo_token(序列级别 + Token 级别组合)python展开代码log_importance_ratio = negative_approx_kl_in_seq.detach().unsqueeze(-1) + log_probs - log_probs.detach()
需要分步说明:
python展开代码negative_approx_kl_in_seq.detach().unsqueeze(-1)
.detach():不参与梯度计算.unsqueeze(-1):在最后一个维度扩展,(1,) → (1, 1)结果:
python展开代码# shape: (1, 1) → 广播到 (1, 5)
torch.tensor([[0.15, 0.15, 0.15, 0.15, 0.15]]) # 每个token都是0.15
python展开代码log_probs - log_probs.detach()
计算过程:
| Token位置 | log_probs | log_probs.detach() | 差值 |
|---|---|---|---|
| 0 | -2.0 | -2.0 | 0.0 |
| 1 | -1.8 | -1.8 | 0.0 |
| 2 | -1.5 | -1.5 | 0.0 |
| 3 | -1.2 | -1.2 | 0.0 |
| 4 | 0.0 | 0.0 | 0.0 |
注意:如果 log_probs 是从前向传播直接得到的(没有额外的梯度操作),那么 log_probs - log_probs.detach() 在当前时刻通常为 0。这里的目的是在后续训练迭代中,log_probs 会变化,从而提供 token 级别的梯度信息。
python展开代码log_importance_ratio = [0.15, 0.15, 0.15, 0.15, 0.15] + [0.0, 0.0, 0.0, 0.0, 0.0]
= [0.15, 0.15, 0.15, 0.15, 0.15]
更真实的场景:假设在训练过程中,log_probs 发生了变化:
python展开代码# 假设经过一次更新后,log_probs变成了:
log_probs_new = torch.tensor([
[-1.9, -1.7, -1.4, -1.1, 0.0] # 稍微提高了
])
# 但是之前detach的值还是旧的:
log_probs_old_detached = torch.tensor([
[-2.0, -1.8, -1.5, -1.2, 0.0]
])
# 那么 token级别的调整项:
token_adjustment = log_probs_new - log_probs_old_detached
= [-1.9 - (-2.0), -1.7 - (-1.8), -1.4 - (-1.5), -1.1 - (-1.2), 0.0 - 0.0]
= [0.1, 0.1, 0.1, 0.1, 0.0]
# 序列级别的部分(从negative_approx_kl计算得到,假设还是0.15):
seq_level = 0.15
# 最终:
log_importance_ratio = seq_level + token_adjustment
= [0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.0]
= [0.25, 0.25, 0.25, 0.25, 0.15]
但实际上,在代码中的使用场景,让我重新理解一下...
让我再仔细看看代码逻辑。实际上,log_probs 是在每次前向传播时计算的,所以 log_probs - log_probs.detach() 在当前前向传播中确实是0。这个设计的目的可能是为了在后续的梯度更新中,让 token 级别的变化能够影响梯度。
让我用更实际的例子来说明:
假设我们有一个 batch,包含 2 个响应:
python展开代码# Batch size = 2
# Response 1: "机器学习" (4个token)
# Response 2: "很有趣" (3个token)
log_probs = torch.tensor([
[-2.0, -1.8, -1.5, -1.2, 0.0], # Response 1
[-1.9, -1.6, -1.3, 0.0, 0.0] # Response 2
])
old_log_probs = torch.tensor([
[-2.2, -2.0, -1.6, -1.3, 0.0], # Response 1 旧策略
[-2.1, -1.8, -1.5, 0.0, 0.0] # Response 2 旧策略
])
response_mask = torch.tensor([
[1, 1, 1, 1, 0], # Response 1: 4个有效token
[1, 1, 1, 0, 0] # Response 2: 3个有效token
])
python展开代码negative_approx_kl = log_probs - old_log_probs
Response 1:
Response 2:
python展开代码negative_approx_kl = torch.tensor([
[0.2, 0.2, 0.1, 0.1, 0.0], # Response 1
[0.2, 0.2, 0.2, 0.0, 0.0] # Response 2
])
python展开代码negative_approx_kl_in_seq = VF.masked_mean(negative_approx_kl, response_mask, dim=-1)
Response 1: [ \frac{0.2 + 0.2 + 0.1 + 0.1}{4} = 0.15 ]
Response 2: [ \frac{0.2 + 0.2 + 0.2}{3} = 0.2 ]
python展开代码negative_approx_kl_in_seq = torch.tensor([0.15, 0.2]) # shape: (2,)
gspo 计算 log_importance_ratiopython展开代码log_importance_ratio = negative_approx_kl_in_seq * response_mask
计算过程(广播):
Response 1:
Response 2:
最终结果:
| 响应 | Token 0 | Token 1 | Token 2 | Token 3 | Token 4 |
|---|---|---|---|---|---|
| Response 1 | 0.15 | 0.15 | 0.15 | 0.15 | 0.0 |
| Response 2 | 0.2 | 0.2 | 0.2 | 0.0 | 0.0 |
特点:每个响应内部的所有 token 都有相同的 log_importance_ratio。
gspo_token 计算 log_importance_ratiopython展开代码log_importance_ratio = negative_approx_kl_in_seq.detach().unsqueeze(-1) + log_probs - log_probs.detach()
由于在当前前向传播中 log_probs - log_probs.detach() 为 0,所以:
Response 1:
Response 2:
但在实际训练中,当 log_probs 在梯度更新后变化时,log_probs - log_probs.detach() 会提供 token 级别的差异,从而让每个 token 的 log_importance_ratio 可以不同。
python展开代码ratio = torch.exp(log_importance_ratio)
对于 gspo 方式(Response 1):
| 方法 | 每个响应的 token 级别 log_importance_ratio | 特点 |
|---|---|---|
gspo | 所有 token 相同(都是序列级别平均值) | 纯序列级别优化,稳定性高 |
gspo_token | 序列级别基础值 + token 级别调整项 | 结合序列和 token 级别信息,更灵活 |
关键在于理解:gspo 直接将序列级别的平均值扩展到所有 token,而 gspo_token 在序列级别基础上增加了 token 级别的调整项(虽然在当前前向传播中可能为0,但在梯度更新后会体现差异)。


本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!