CycleGAN包含4个网络:
其中:
CycleGAN使用ResNet架构的生成器:
pythonclass ResnetGenerator(nn.Module):
"""Resnet-based generator with 9个残差块"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
super(ResnetGenerator, self).__init__()
# ...
生成器结构包括:
初始层: 由反射填充、卷积、归一化和ReLU激活组成
pythonmodel = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
下采样层: 两个下采样模块,每个将特征图尺寸减小一半,通道数增加一倍
pythonn_downsampling = 2
for i in range(n_downsampling):
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
残差块: 多个残差块(默认9个),维持特征图尺寸不变
pythonmult = 2 ** n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
上采样层: 两个上采样模块,每个将特征图尺寸增大一倍,通道数减少一半
pythonfor i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
输出层: 生成目标域图像
pythonmodel += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
CycleGAN使用PatchGAN判别器,它对输入图像的重叠局部区域进行分类,判断是真实图像还是生成图像:
pythonclass NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
# ...
判别器结构包括:
CycleGAN使用多个损失函数组合:
对生成器和判别器使用对抗损失,默认使用最小二乘对抗损失(LSGAN):
对于生成器G_A:
对于生成器G_B:
代码实现:
pythonself.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss
# ...
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
确保转换的图像可以被转换回原始域:
代码实现:
pythonself.criterionCycle = torch.nn.L1Loss()
# ...
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
可选的损失,帮助保持颜色一致性:
代码实现:
pythonself.criterionIdt = torch.nn.L1Loss()
# ...
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
生成器总损失:
代码实现:
pythonself.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
判别器损失:
代码实现:
pythonloss_D = (loss_D_real + loss_D_fake) * 0.5
CycleGAN的训练过程交替更新生成器和判别器:
首先前向传播,生成假图像
pythonself.forward() # 生成fake_A、fake_B、rec_A、rec_B
更新生成器G_A和G_B
python# 固定判别器参数
self.set_requires_grad([self.netD_A, self.netD_B], False)
self.optimizer_G.zero_grad() # 将生成器梯度清零
self.backward_G() # 计算生成器梯度
self.optimizer_G.step() # 更新生成器权重
更新判别器D_A和D_B
python# 解除判别器参数固定
self.set_requires_grad([self.netD_A, self.netD_B], True)
self.optimizer_D.zero_grad() # 将判别器梯度清零
self.backward_D_A() # 计算D_A梯度
self.backward_D_B() # 计算D_B梯度
self.optimizer_D.step() # 更新判别器权重
图像缓冲池: 使用图像缓冲池存储之前生成的图像,减少模型震荡
pythonself.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size)
学习率调度: 在训练过程中逐渐减小学习率
pythonmodel.update_learning_rate() # 每个epoch开始时更新学习率
权重初始化: 使用特定初始化方法(例如正态分布初始化)
BatchNorm或InstanceNorm归一化: 帮助稳定训练
训练参数设置:
使用Adam优化器,参数如下:
pythonself.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
CycleGAN的创新点:
核心公式总结:
网络结构:
训练技巧:
通过这些设计,CycleGAN成功实现了无需配对数据的图像风格转换任务。
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!