编辑
2025-04-11
深度学习
00

目录

Stable Diffusion 采样器的数学原理
1. 扩散模型基础理论
1.1 正向过程(加噪)
1.2 逆向过程(去噪)
2. 常见采样器的数学原理与实现
2.1 Euler 采样器
数学原理
代码实现
2.2 Euler Ancestral 采样器
数学原理
代码实现
2.3 Heun 采样器
数学原理
代码实现
2.4 DPM++ 2M 采样器
数学原理
代码实现
2.5 LMS (线性多步法) 采样器
数学原理
代码实现
3. 采样器特性比较
3.1 Euler 和 Euler Ancestral
3.2 Heun
3.3 DPM 系列 (DPM++, DPM++ 2M, DPM++ SDE)
3.4 LMS (线性多步法)
4. 噪声调度 (Noise Schedulers)
4.1 Karras 调度
4.2 指数调度
5. 采样器对最终图像的影响
6. 实际应用建议

Stable Diffusion 采样器的数学原理

Stable Diffusion 采样器是扩散模型去噪过程的核心组件,不同采样器采用不同的数学方法来逐步将纯噪声转换为有意义的图像。下面我将详细介绍几种主要采样器的工作原理、数学公式和代码实现。

1. 扩散模型基础理论

在深入具体采样器前,先了解扩散模型的基本原理:

1.1 正向过程(加噪)

正向扩散过程将图像 x0x_0 逐步添加高斯噪声,直到变成纯噪声 xTx_T

xt=αtxt1+1αtϵt1x_t = \sqrt{\alpha_t} x_{t-1} + \sqrt{1-\alpha_t} \epsilon_{t-1}

其中 αt\alpha_t 是预定义的噪声调度,ϵt1\epsilon_{t-1} 是从标准正态分布采样的噪声。

1.2 逆向过程(去噪)

Stable Diffusion 的核心是训练一个神经网络 ϵθ\epsilon_\theta 来预测每一步中添加的噪声。采样器的任务是使用这个网络逐步从纯噪声 xTx_T 恢复到干净图像 x0x_0

Read file: repositories/k-diffusion/k_diffusion/sampling.py

现在让我详细分析几种主要采样器的实现和数学原理:

2. 常见采样器的数学原理与实现

2.1 Euler 采样器

Euler 采样器是最简单的采样器,使用欧拉法进行数值积分,它直接将扩散方程作为常微分方程(ODE)求解。

数学原理

欧拉法的基本公式为: xt+1=xt+dxdtΔtx_{t+1} = x_t + \frac{dx}{dt} \Delta t

在扩散模型的上下文中,转化为: xi+1=xi+d(xi,σi)(σi+1σi)x_{i+1} = x_i + d(x_i, \sigma_i) \cdot (\sigma_{i+1} - \sigma_i)

其中 d(xi,σi)d(x_i, \sigma_i) 是预测的噪声方向: d(xi,σi)=xidenoisedσid(x_i, \sigma_i) = \frac{x_i - \text{denoised}}{\sigma_i}

代码实现

python
@torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): # 可选的随机扰动增强 gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 # 模型预测去噪后的图像 denoised = model(x, sigma_hat * s_in, **extra_args) # 计算噪声方向 d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) # 计算时间步长 dt = sigmas[i + 1] - sigma_hat # 应用欧拉法迭代 x = x + d * dt return x

2.2 Euler Ancestral 采样器

Euler Ancestral 采样器加入了随机扰动,改善样本多样性。

数学原理

对标准欧拉法进行扩展,在每一步后添加经过缩放的随机噪声:

σdown,σup=get_ancestral_step(σi,σi+1,η)\sigma_{down}, \sigma_{up} = \text{get\_ancestral\_step}(\sigma_i, \sigma_{i+1}, \eta) xi+1/2=xi+d(xi,σi)(σdownσi)x_{i+1/2} = x_i + d(x_i, \sigma_i) \cdot (\sigma_{down} - \sigma_i) xi+1=xi+1/2+σupN(0,I)x_{i+1} = x_{i+1/2} + \sigma_{up} \cdot \mathcal{N}(0, I)

其中 η\eta 控制添加的噪声量,σdown\sigma_{down}σup\sigma_{up} 分别是新的噪声水平和需要添加的噪声量。

代码实现

python
@torch.no_grad() def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): # 模型预测去噪后的图像 denoised = model(x, sigmas[i] * s_in, **extra_args) # 计算ancestral步长参数 sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) # 计算噪声方向 d = to_d(x, sigmas[i], denoised) # 应用欧拉法更新 dt = sigma_down - sigmas[i] x = x + d * dt # 添加随机噪声 if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x

2.3 Heun 采样器

Heun 采样器是欧拉法的二阶改进版本,提供了更准确的数值积分。

数学原理

Heun法的数学公式为: k1=d(xi,σi)k_1 = d(x_i, \sigma_i) xi+1/2=xi+k1(σi+1σi)x_{i+1/2} = x_i + k_1 \cdot (\sigma_{i+1} - \sigma_i) k2=d(xi+1/2,σi+1)k_2 = d(x_{i+1/2}, \sigma_{i+1}) xi+1=xi+k1+k22(σi+1σi)x_{i+1} = x_i + \frac{k_1 + k_2}{2} \cdot (\sigma_{i+1} - \sigma_i)

这相当于预测-修正过程,提高了积分精度。

代码实现

python
@torch.no_grad() def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): # 可选的随机扰动增强 gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 # 模型预测去噪后的图像(第一步) denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # 对于最后一步,使用欧拉法(避免在σ=0处除法问题) x = x + d * dt else: # Heun法:预测-修正步骤 x_2 = x + d * dt # 预测步骤 denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) # 在预测点评估 d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 # 平均梯度 x = x + d_prime * dt # 修正步骤 return x

2.4 DPM++ 2M 采样器

DPM++ 2M 采样器是一种现代的高效采样器,基于DPM-Solver的思想,但优化了中点法规则。

数学原理

DPM++ 2M 使用改进的中点法求解ODE:

xi+1/2=xi+0.5d(xi,σi)(σi+1σi)x_{i+1/2} = x_i + 0.5 \cdot d(x_i, \sigma_i) \cdot (\sigma_{i+1} - \sigma_i) xi+1=xi+d(xi+1/2,σi+1/2)(σi+1σi)x_{i+1} = x_i + d(x_{i+1/2}, \sigma_{i+1/2}) \cdot (\sigma_{i+1} - \sigma_i)

其中 σi+1/2=0.5(σi+σi+1)\sigma_{i+1/2} = 0.5 \cdot (\sigma_i + \sigma_{i+1}) 是中点噪声水平。

代码实现

python
@torch.no_grad() def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): """Implements DPM++ (2M) sampler.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): # 当前噪声等级 sigma = sigmas[i] # 下一个噪声等级 sigma_next = sigmas[i + 1] # 计算混合系数 t, t_next = -sigma.log(), -sigma_next.log() h = t_next - t # 模型预测去噪后的图像 denoised = model(x, sigma * s_in, **extra_args) d = (x - denoised) / sigma if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) # 中点法的第一步:计算中点估计 sigma_mid = sigma * (-0.5 * h).exp() x_mid = x + d * (sigma_mid - sigma) # 在中点评估模型 denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args) d_mid = (x_mid - denoised_mid) / sigma_mid # 中点法的第二步:使用中点梯度更新 x = x + d_mid * (sigma_next - sigma) return x

2.5 LMS (线性多步法) 采样器

LMS 采样器使用线性多步法,根据多个历史步骤提高精度。

数学原理

线性多步法的一般形式为: xn+1=xn+Δtj=0k1βjf(tnj,xnj)x_{n+1} = x_n + \Delta t \sum_{j=0}^{k-1} \beta_j f(t_{n-j}, x_{n-j})

其中 βj\beta_j 是根据多项式插值计算的系数。LMS 通常使用最近的 4 个步骤来提高精度。

代码实现

python
@torch.no_grad() def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): """Linear multistep sampler for discrete-time DPMs.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # 存储历史导数 ds = [] for i in trange(len(sigmas) - 1, disable=disable): # 模型预测 denoised = model(x, sigmas[i] * s_in, **extra_args) d = to_d(x, sigmas[i], denoised) ds.append(d) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) if len(ds) > order: ds.pop(0) # 使用线性多步法计算下一步 step_size = sigmas[i + 1] - sigmas[i] x_t = x + step_size * sum( linear_multistep_coeff(len(ds), sigmas[i], sigmas[i + 1], j + 1) * d for j, d in enumerate(reversed(ds)) ) x = x_t return x

3. 采样器特性比较

各种采样器在速度、质量和特性上有不同的权衡:

3.1 Euler 和 Euler Ancestral

  • Euler: 简单、稳定,但精度较低,需要更多步骤才能获得高质量结果
  • Euler Ancestral: 添加随机性,生成更多样的图像,更适合创意生成,但可能不如确定性方法精确

3.2 Heun

  • 二阶精度,比欧拉法更准确
  • 计算量是欧拉法的两倍(每步需要两次模型评估)
  • 对于相同质量的结果,通常可以使用更少的步骤

3.3 DPM 系列 (DPM++, DPM++ 2M, DPM++ SDE)

  • 最好的质量/速度比
  • DPM++ 2M: 高质量的确定性方法,适合高精度图像
  • DPM++ SDE: 添加随机性以增加多样性,适合创意探索
  • 可以在更少的步骤内获得优质结果

3.4 LMS (线性多步法)

  • 使用更多历史信息,可以获得高精度
  • 内存占用较高,需要存储多个历史步骤
  • 适合需要高精度的应用场景

4. 噪声调度 (Noise Schedulers)

采样过程质量很大程度上取决于噪声调度方法,即 σ\sigma 值如何随时间减小。常用调度包括:

4.1 Karras 调度

python
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return append_zero(sigmas).to(device)

Karras 调度提供了更好的质量/步骤权衡,尤其是在步数较少时。它的数学形式为: σ(t)=(σmax1/ρ+t(σmin1/ρσmax1/ρ))ρ\sigma(t) = \left( \sigma_{\max}^{1/\rho} + t \cdot (\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}) \right)^{\rho}

4.2 指数调度

python
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): """Constructs an exponential noise schedule.""" sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() return append_zero(sigmas)

指数调度在对数空间线性减小噪声,是一种简单但有效的调度。

5. 采样器对最终图像的影响

不同采样器会对生成的图像产生明显影响:

  1. 细节和精度:

    • 高阶采样器(如 DPM++ 2M、Heun)通常产生更精细的细节
    • 低阶采样器(如 Euler)在同样步数下可能导致细节丢失或模糊
  2. 创意性和多样性:

    • 带有随机成分的采样器(如 Euler Ancestral、DPM++ SDE)产生更多样的结果
    • 确定性采样器(如 DPM++ 2M、Heun)在多次运行时给出相同结果
  3. 速度和效率:

    • 简单采样器(如 Euler)每步计算量小,但需要更多步骤
    • 高阶采样器(如 DPM++ 系列)每步计算量大,但总体需要更少步骤

6. 实际应用建议

根据不同需求选择适合的采样器:

  1. 快速预览: Euler, 20-30 步
  2. 最高质量: DPM++ 2M, 40-50 步
  3. 创意探索: Euler Ancestral 或 DPM++ SDE,添加随机性
  4. 细节精制: Heun 或 LMS,更高的数值精度

如果你需要在保持高质量的同时减少步数,DPM++ 2M 和 DPM++ SDE 通常是最佳选择,它们能在 25-30 步内提供接近 50 步 Euler 的质量。

如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!