Stable Diffusion 采样器是扩散模型去噪过程的核心组件,不同采样器采用不同的数学方法来逐步将纯噪声转换为有意义的图像。下面我将详细介绍几种主要采样器的工作原理、数学公式和代码实现。
在深入具体采样器前,先了解扩散模型的基本原理:
正向扩散过程将图像 逐步添加高斯噪声,直到变成纯噪声 :
其中 是预定义的噪声调度, 是从标准正态分布采样的噪声。
Stable Diffusion 的核心是训练一个神经网络 来预测每一步中添加的噪声。采样器的任务是使用这个网络逐步从纯噪声 恢复到干净图像 。
Read file: repositories/k-diffusion/k_diffusion/sampling.py
现在让我详细分析几种主要采样器的实现和数学原理:
Euler 采样器是最简单的采样器,使用欧拉法进行数值积分,它直接将扩散方程作为常微分方程(ODE)求解。
欧拉法的基本公式为:
在扩散模型的上下文中,转化为:
其中 是预测的噪声方向:
python@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
# 可选的随机扰动增强
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
# 模型预测去噪后的图像
denoised = model(x, sigma_hat * s_in, **extra_args)
# 计算噪声方向
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# 计算时间步长
dt = sigmas[i + 1] - sigma_hat
# 应用欧拉法迭代
x = x + d * dt
return x
Euler Ancestral 采样器加入了随机扰动,改善样本多样性。
对标准欧拉法进行扩展,在每一步后添加经过缩放的随机噪声:
其中 控制添加的噪声量, 和 分别是新的噪声水平和需要添加的噪声量。
python@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
# 模型预测去噪后的图像
denoised = model(x, sigmas[i] * s_in, **extra_args)
# 计算ancestral步长参数
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# 计算噪声方向
d = to_d(x, sigmas[i], denoised)
# 应用欧拉法更新
dt = sigma_down - sigmas[i]
x = x + d * dt
# 添加随机噪声
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
Heun 采样器是欧拉法的二阶改进版本,提供了更准确的数值积分。
Heun法的数学公式为:
这相当于预测-修正过程,提高了积分精度。
python@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
# 可选的随机扰动增强
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
# 模型预测去噪后的图像(第一步)
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# 对于最后一步,使用欧拉法(避免在σ=0处除法问题)
x = x + d * dt
else:
# Heun法:预测-修正步骤
x_2 = x + d * dt # 预测步骤
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) # 在预测点评估
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2 # 平均梯度
x = x + d_prime * dt # 修正步骤
return x
DPM++ 2M 采样器是一种现代的高效采样器,基于DPM-Solver的思想,但优化了中点法规则。
DPM++ 2M 使用改进的中点法求解ODE:
其中 是中点噪声水平。
python@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Implements DPM++ (2M) sampler."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
# 当前噪声等级
sigma = sigmas[i]
# 下一个噪声等级
sigma_next = sigmas[i + 1]
# 计算混合系数
t, t_next = -sigma.log(), -sigma_next.log()
h = t_next - t
# 模型预测去噪后的图像
denoised = model(x, sigma * s_in, **extra_args)
d = (x - denoised) / sigma
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
# 中点法的第一步:计算中点估计
sigma_mid = sigma * (-0.5 * h).exp()
x_mid = x + d * (sigma_mid - sigma)
# 在中点评估模型
denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
d_mid = (x_mid - denoised_mid) / sigma_mid
# 中点法的第二步:使用中点梯度更新
x = x + d_mid * (sigma_next - sigma)
return x
LMS 采样器使用线性多步法,根据多个历史步骤提高精度。
线性多步法的一般形式为:
其中 是根据多项式插值计算的系数。LMS 通常使用最近的 4 个步骤来提高精度。
python@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
"""Linear multistep sampler for discrete-time DPMs."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# 存储历史导数
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
# 模型预测
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
if len(ds) > order:
ds.pop(0)
# 使用线性多步法计算下一步
step_size = sigmas[i + 1] - sigmas[i]
x_t = x + step_size * sum(
linear_multistep_coeff(len(ds), sigmas[i], sigmas[i + 1], j + 1) * d
for j, d in enumerate(reversed(ds))
)
x = x_t
return x
各种采样器在速度、质量和特性上有不同的权衡:
采样过程质量很大程度上取决于噪声调度方法,即 值如何随时间减小。常用调度包括:
pythondef get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
Karras 调度提供了更好的质量/步骤权衡,尤其是在步数较少时。它的数学形式为:
pythondef get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return append_zero(sigmas)
指数调度在对数空间线性减小噪声,是一种简单但有效的调度。
不同采样器会对生成的图像产生明显影响:
细节和精度:
创意性和多样性:
速度和效率:
根据不同需求选择适合的采样器:
如果你需要在保持高质量的同时减少步数,DPM++ 2M 和 DPM++ SDE 通常是最佳选择,它们能在 25-30 步内提供接近 50 步 Euler 的质量。
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!