GSPO vs GRPO
2025-12-30
深度学习
00

目录

假设场景
输入数据(假设的数值)
步骤1:计算 Token 级别的 negativeapproxkl
步骤2:计算序列级别的平均 KL(这是 GSPO 的关键步骤)
步骤3:构建 Token 级别的重要性比率(两种方式)
方式A:gspo(纯序列级别)
方式B:gspo_token(序列级别 + Token 级别组合)
步骤3.1:准备序列级别的部分
步骤3.2:计算 Token 级别的调整项
步骤3.3:相加得到最终结果
完整例子对比
步骤1:计算 negativeapproxkl
步骤2:计算序列级别的平均
方式A:gspo 计算 logimportanceratio
方式B:gspotoken 计算 logimportance_ratio
步骤4:转换成 ratio
总结对比

用一个具体例子详细说明 GSPO 的两个变体的计算过程,特别是 token 级别的重要性比率是如何得到的。

用一个具体的数值例子说明 GSPO 的两种方式的 token 级别重要性比率计算:

假设场景

假设有一个 prompt,模型生成了一个响应:"机器学习很实用"

输入数据(假设的数值)

假设我们有以下的 log_probs 和 old_log_probs(都是 tensor,shape 为 (batch_size=1, seq_len=5)):

python
展开代码
# 响应包含5个token:["机器", "学习", "很", "实用", "<pad>"] # 但实际只有前4个token有效 log_probs = torch.tensor([ [-2.0, -1.8, -1.5, -1.2, 0.0] # 当前策略的对数概率 ]) old_log_probs = torch.tensor([ [-2.2, -2.0, -1.6, -1.3, 0.0] # 旧策略的对数概率 ]) response_mask = torch.tensor([ [1, 1, 1, 1, 0] # 前4个token有效,最后一个padding ])

步骤1:计算 Token 级别的 negative_approx_kl

python
展开代码
negative_approx_kl = log_probs - old_log_probs

计算结果:

Token位置log_probsold_log_probsnegative_approx_kl
0 ("机器")-2.0-2.2-2.0 - (-2.2) = 0.2
1 ("学习")-1.8-2.0-1.8 - (-2.0) = 0.2
2 ("很")-1.5-1.6-1.5 - (-1.6) = 0.1
3 ("实用")-1.2-1.3-1.2 - (-1.3) = 0.1
4 (padding)0.00.00.0 - 0.0 = 0.0

所以:

python
展开代码
negative_approx_kl = torch.tensor([ [0.2, 0.2, 0.1, 0.1, 0.0] ])

步骤2:计算序列级别的平均 KL(这是 GSPO 的关键步骤)

python
展开代码
negative_approx_kl_in_seq = VF.masked_mean(negative_approx_kl, response_mask, dim=-1)

VF.masked_mean 会:

  1. 只对 mask=1 的位置求平均(忽略 padding)
  2. 在最后一个维度(dim=-1)上求平均

计算过程: [ \text{negative_approx_kl_in_seq} = \frac{0.2 + 0.2 + 0.1 + 0.1}{4} = 0.15 ]

结果:

python
展开代码
negative_approx_kl_in_seq = torch.tensor([0.15]) # shape: (1,)

步骤3:构建 Token 级别的重要性比率(两种方式)

方式A:gspo(纯序列级别)

python
展开代码
log_importance_ratio = negative_approx_kl_in_seq * response_mask

计算过程:

  1. negative_approx_kl_in_seq 的值是 0.15(标量)
  2. response_mask[1, 1, 1, 1, 0]
  3. 执行广播(broadcast):
    • 0.15 * 1 = 0.15(对每个有效token)
    • 0.15 * 0 = 0.0(对padding)

结果:

Token位置negative_approx_kl_in_seqresponse_masklog_importance_ratio
00.1510.15
10.1510.15
20.1510.15
30.1510.15
40.1500.0
python
展开代码
log_importance_ratio = torch.tensor([ [0.15, 0.15, 0.15, 0.15, 0.0] ])

特点:所有有效 token 的 log_importance_ratio 相同(都是 0.15)。

方式B:gspo_token(序列级别 + Token 级别组合)

python
展开代码
log_importance_ratio = negative_approx_kl_in_seq.detach().unsqueeze(-1) + log_probs - log_probs.detach()

需要分步说明:

步骤3.1:准备序列级别的部分

python
展开代码
negative_approx_kl_in_seq.detach().unsqueeze(-1)
  • .detach():不参与梯度计算
  • .unsqueeze(-1):在最后一个维度扩展,(1,)(1, 1)

结果:

python
展开代码
# shape: (1, 1) → 广播到 (1, 5) torch.tensor([[0.15, 0.15, 0.15, 0.15, 0.15]]) # 每个token都是0.15

步骤3.2:计算 Token 级别的调整项

python
展开代码
log_probs - log_probs.detach()

计算过程:

Token位置log_probslog_probs.detach()差值
0-2.0-2.00.0
1-1.8-1.80.0
2-1.5-1.50.0
3-1.2-1.20.0
40.00.00.0

注意:如果 log_probs 是从前向传播直接得到的(没有额外的梯度操作),那么 log_probs - log_probs.detach() 在当前时刻通常为 0。这里的目的是在后续训练迭代中,log_probs 会变化,从而提供 token 级别的梯度信息。

步骤3.3:相加得到最终结果

python
展开代码
log_importance_ratio = [0.15, 0.15, 0.15, 0.15, 0.15] + [0.0, 0.0, 0.0, 0.0, 0.0] = [0.15, 0.15, 0.15, 0.15, 0.15]

更真实的场景:假设在训练过程中,log_probs 发生了变化:

python
展开代码
# 假设经过一次更新后,log_probs变成了: log_probs_new = torch.tensor([ [-1.9, -1.7, -1.4, -1.1, 0.0] # 稍微提高了 ]) # 但是之前detach的值还是旧的: log_probs_old_detached = torch.tensor([ [-2.0, -1.8, -1.5, -1.2, 0.0] ]) # 那么 token级别的调整项: token_adjustment = log_probs_new - log_probs_old_detached = [-1.9 - (-2.0), -1.7 - (-1.8), -1.4 - (-1.5), -1.1 - (-1.2), 0.0 - 0.0] = [0.1, 0.1, 0.1, 0.1, 0.0] # 序列级别的部分(从negative_approx_kl计算得到,假设还是0.15): seq_level = 0.15 # 最终: log_importance_ratio = seq_level + token_adjustment = [0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.1, 0.15 + 0.0] = [0.25, 0.25, 0.25, 0.25, 0.15]

但实际上,在代码中的使用场景,让我重新理解一下...

让我再仔细看看代码逻辑。实际上,log_probs 是在每次前向传播时计算的,所以 log_probs - log_probs.detach() 在当前前向传播中确实是0。这个设计的目的可能是为了在后续的梯度更新中,让 token 级别的变化能够影响梯度。

让我用更实际的例子来说明:

完整例子对比

假设我们有一个 batch,包含 2 个响应:

python
展开代码
# Batch size = 2 # Response 1: "机器学习" (4个token) # Response 2: "很有趣" (3个token) log_probs = torch.tensor([ [-2.0, -1.8, -1.5, -1.2, 0.0], # Response 1 [-1.9, -1.6, -1.3, 0.0, 0.0] # Response 2 ]) old_log_probs = torch.tensor([ [-2.2, -2.0, -1.6, -1.3, 0.0], # Response 1 旧策略 [-2.1, -1.8, -1.5, 0.0, 0.0] # Response 2 旧策略 ]) response_mask = torch.tensor([ [1, 1, 1, 1, 0], # Response 1: 4个有效token [1, 1, 1, 0, 0] # Response 2: 3个有效token ])

步骤1:计算 negative_approx_kl

python
展开代码
negative_approx_kl = log_probs - old_log_probs

Response 1:

  • Token 0: -2.0 - (-2.2) = 0.2
  • Token 1: -1.8 - (-2.0) = 0.2
  • Token 2: -1.5 - (-1.6) = 0.1
  • Token 3: -1.2 - (-1.3) = 0.1
  • Token 4: 0.0 - 0.0 = 0.0

Response 2:

  • Token 0: -1.9 - (-2.1) = 0.2
  • Token 1: -1.6 - (-1.8) = 0.2
  • Token 2: -1.3 - (-1.5) = 0.2
  • Token 3: 0.0 - 0.0 = 0.0
  • Token 4: 0.0 - 0.0 = 0.0
python
展开代码
negative_approx_kl = torch.tensor([ [0.2, 0.2, 0.1, 0.1, 0.0], # Response 1 [0.2, 0.2, 0.2, 0.0, 0.0] # Response 2 ])

步骤2:计算序列级别的平均

python
展开代码
negative_approx_kl_in_seq = VF.masked_mean(negative_approx_kl, response_mask, dim=-1)

Response 1: [ \frac{0.2 + 0.2 + 0.1 + 0.1}{4} = 0.15 ]

Response 2: [ \frac{0.2 + 0.2 + 0.2}{3} = 0.2 ]

python
展开代码
negative_approx_kl_in_seq = torch.tensor([0.15, 0.2]) # shape: (2,)

方式A:gspo 计算 log_importance_ratio

python
展开代码
log_importance_ratio = negative_approx_kl_in_seq * response_mask

计算过程(广播):

Response 1:

  • 0.15 * [1, 1, 1, 1, 0] = [0.15, 0.15, 0.15, 0.15, 0.0]

Response 2:

  • 0.2 * [1, 1, 1, 0, 0] = [0.2, 0.2, 0.2, 0.0, 0.0]

最终结果:

响应Token 0Token 1Token 2Token 3Token 4
Response 10.150.150.150.150.0
Response 20.20.20.20.00.0

特点:每个响应内部的所有 token 都有相同的 log_importance_ratio

方式B:gspo_token 计算 log_importance_ratio

python
展开代码
log_importance_ratio = negative_approx_kl_in_seq.detach().unsqueeze(-1) + log_probs - log_probs.detach()

由于在当前前向传播中 log_probs - log_probs.detach() 为 0,所以:

Response 1:

  • [0.15, 0.15, 0.15, 0.15, 0.15] + [0, 0, 0, 0, 0] = [0.15, 0.15, 0.15, 0.15, 0.15]

Response 2:

  • [0.2, 0.2, 0.2, 0.2, 0.2] + [0, 0, 0, 0, 0] = [0.2, 0.2, 0.2, 0.2, 0.2]

但在实际训练中,当 log_probs 在梯度更新后变化时,log_probs - log_probs.detach() 会提供 token 级别的差异,从而让每个 token 的 log_importance_ratio 可以不同。

步骤4:转换成 ratio

python
展开代码
ratio = torch.exp(log_importance_ratio)

对于 gspo 方式(Response 1):

  • Token 0: e^0.15 ≈ 1.162
  • Token 1: e^0.15 ≈ 1.162
  • Token 2: e^0.15 ≈ 1.162
  • Token 3: e^0.15 ≈ 1.162

总结对比

方法每个响应的 token 级别 log_importance_ratio特点
gspo所有 token 相同(都是序列级别平均值)纯序列级别优化,稳定性高
gspo_token序列级别基础值 + token 级别调整项结合序列和 token 级别信息,更灵活

关键在于理解:gspo 直接将序列级别的平均值扩展到所有 token,而 gspo_token 在序列级别基础上增加了 token 级别的调整项(虽然在当前前向传播中可能为0,但在梯度更新后会体现差异)。

如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!