在前期实验中,我尝试使用mmgen库进行通用图像和线稿的处理,但效果不尽如人意。这让我开始怀疑mmgen库在此类任务中的适用性。经过评估后,我决定转向另一个成熟的解决方案。
最终选择的实现方案来自以下GitHub仓库:
bash[email protected]:junyanz/pytorch-CycleGAN-and-pix2pix.git
采用Docker容器搭建训练环境,具体配置如下:
bashdocker run --gpus all --shm-size=32g -it --net host \ -v ./pytorch-CycleGAN-and-pix2pix/:/pytorch-CycleGAN-and-pix2pix/ \ -v ./anime-sketch-colorization-pair/data:/data \ pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel bash pip install -r requirements.txt
为方便后续使用,已将配置好的环境保存为镜像:
bashkevinchina/deeplearning:2.5.1-cuda12.1-cudnn9-devel-pix2pix
直接使用命令:
bashdocker run --gpus all --shm-size=32g -it --net host \ -v ./pytorch-CycleGAN-and-pix2pix/:/pytorch-CycleGAN-and-pix2pix/ \ -v ./anime-sketch-colorization-pair/data:/data \ kevinchina/deeplearning:2.5.1-cuda12.1-cudnn9-devel-pix2pix bash
首先启动visdom可视化服务器:
bashpython -m visdom.server -p 8097
采用以下两种训练方式之一:
基础训练命令:
bashpython train.py --dataroot /data/ --name anime --gpu_ids 0,1,2,3 \ --model pix2pix --direction BtoA --batch_size 192 --lr_policy cosine \ --num_threads 32 --init_type xavier --norm instance --netG unet_256
带visdom监控的训练:
bash-display_server http://remote_ip --display_port 8097
通过浏览器访问8097端口可查看训练效果:
经过200轮训练后,效果仍有提升空间:
深入研究后发现更优的彩色化方法:输入单通道黑白图,输出为LAB色彩空间中的AB通道。
bashcd /ssd/xiedong/image_color
docker run --gpus all --shm-size=32g -it --net host \
-v ./pytorch-CycleGAN-and-pix2pix/:/pytorch-CycleGAN-and-pix2pix/ \
-v ./unpaired_self_datasets/:/data \
kevinchina/deeplearning:2.5.1-cuda12.1-cudnn9-devel-pix2pix bash
nohup python -m visdom.server -p 8097 &
pip install -U scikit-image
cd /pytorch-CycleGAN-and-pix2pix/
python train.py \
--dataroot /data/ \
--name tongyong_l2ab_4 \
--gpu_ids 0,1,2 \
--model colorization \
--direction AtoB \
--batch_size 196 \
--lr_policy cosine \
--num_threads 32 \
--init_type xavier \
--norm instance \
--netG unet_256 \
--dataset_mode colorization \
--input_nc 1 \
--output_nc 2 \
--phase trainA \
--n_epochs 100
训练后可以绘制损失变化图:
python plot_loss_paper.py
• 系统自动将RGB图像转换为Lab色彩空间 • 分离L通道(亮度)作为输入 • 提取ab通道(色彩)作为输出目标 • 转换过程在ColorizationDataset类中自动完成
验证集输入(valA):
验证集目标(valB):
生成效果对比:
训练完成的模型保存在:
pytorch-CycleGAN-and-pix2pix/checkpoints
参数 | 说明 | 示例值 |
---|---|---|
--dataroot | 数据根目录 | /data/ |
--name | 实验名称 | tongyong_l2ab_4 |
--gpu_ids | 使用的GPU编号 | 0,1,2 |
参数 | 说明 | 可选值 |
---|---|---|
--model | 模型类型 | colorization |
--netG | 生成器架构 | unet_256 |
--norm | 归一化方法 | instance |
参数 | 说明 | 典型值 |
---|---|---|
--lr_policy | 学习率策略 | cosine |
--batch_size | 批次大小 | 196 |
--n_epochs | 训练周期数 | 100 |
典型训练日志示例:
(epoch: 39, iters: 115800, time: 0.007, data: 0.044) G_GAN: 2.019 G_L1: 8.129 D_real: 0.409 D_fake: 0.279
• 生成器总损失: • 判别器总损失:
采用CIE Lab色彩空间,其中:
U-Net结构包含:
采用PatchGAN架构:
要测试您训练好的模型(第35轮)使用您在"/data/testA"的测试数据,您应该在服务器上运行以下命令:
python test.py --dataroot /data/ --name tongyong_l2ab_4 --model colorization --direction AtoB --netG unet_256 --dataset_mode colorization --input_nc 1 --output_nc 2 --phase testA --epoch 35 --eval --num_test 500 --norm instance
如果想处理数据集中的所有图像,您需要将num_test设置为一个足够大的数字,确保它大于数据集中的图像总数。
结果会被保存到:results/tongyong_l2ab_4/testA_35/images/
此命令将:
关键参数说明:
--dataroot /data/
:数据目录的路径--name tongyong_l2ab_4
:您的实验名称(与训练名称匹配)--model colorization
:使用着色模型--netG unet_256
:使用与训练相同的生成器架构--input_nc 1 --output_nc 2
:与您的训练设置相同(L到ab颜色空间)--phase testA
:使用testA目录进行测试--epoch 35
:从第35轮加载模型--eval
:在评估模式下运行(禁用dropout)--num_test 500
:测试多达500张图像(您可以调整这个数字)报错了,需要改几句代码:
对于 PyTorch 2.5 版本,您需要修改服务器上的代码。服务器上的 load_networks
函数中有两个不兼容的参数:
weights_only=True
- 这是较新版本的 PyTorch 引入的参数strict=False
- 这个参数应该传递给 load_state_dict
而不是 torch.load
因此,正确的修改方案是:
models/base_model.py
文件load_networks
函数中的这一行:pythonstate_dict = torch.load(load_path, weights_only=True, map_location=self.device, strict=False)
pythonstate_dict = torch.load(load_path, map_location=str(self.device))
strict=False
的功能,则需要修改后面的 load_state_dict
调用:pythonnet.load_state_dict(state_dict, strict=False)
这样修改后应该可以与 PyTorch 2.5 兼容,同时保持原有功能。
weights_only=True 参数
map_location=self.device vs map_location=str(self.device)
strict=False 参数
testA目录 里应该是彩色的图片,ColorizationDataset 类里会进行这样的处理:
pythondef __getitem__(self, index):
path = self.AB_paths[index]
im = Image.open(path).convert('RGB') # <-- 打开图像并转换为RGB
im = self.transform(im)
im = np.array(im)
lab = color.rgb2lab(im).astype(np.float32) # <-- 将RGB转换为Lab
lab_t = transforms.ToTensor()(lab)
A = lab_t[[0], ...] / 50.0 - 1.0 # <-- 提取L通道(灰度)
B = lab_t[[1, 2], ...] / 110.0 # <-- 提取ab通道(色彩)
return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}
通过定期验证集测试监控训练进展,关键步骤:
在第七章节后,我们获取到了结果图:./results/tongyong_l2ab_4/testA_35/images/
每三个图是一组,比如:
bash-rw-r--r-- 1 root root 134573 Apr 10 02:17 002_trafficScene_youxian1_000020_fake_B_rgb.png -rw-r--r-- 1 root root 132559 Apr 10 02:17 002_trafficScene_youxian1_000020_real_B_rgb.png -rw-r--r-- 1 root root 86333 Apr 10 02:17 002_trafficScene_youxian1_000020_real_A.png
real_A 是黑白图,real_B_rgb 是真实的RGB图,fake_B_rgb 是模型根据real_A进行上色得到的图。
指标就是为了评估fake_B_rgb是否能接近real_B_rgb。
为了得到评估指标,需要给容器里安装点环境:
apt update && apt install -y libgl1 pip install -r evaluation_requirements.txt # 在我自己的github项目里
运行评估脚本:
bash# 基本版本
python evaluate_basic.py --results_dir ./results/tongyong_l2ab_4/testA_35/images --output_dir ./evaluation_basic_results
# 完整版本
python evaluate_colorization.py --results_dir ./results/tongyong_l2ab_4/testA_35/images --output_dir ./evaluation_results
# 完整版本 带FID指标
python evaluate_colorization.py --results_dir ./results/tongyong_l2ab_4/testA_35/images --output_dir ./evaluation_results --use_fid
得到:
Metric | Mean | Std | Min | Max |
---|---|---|---|---|
SSIM | 0.892293 | 0.087055 | 0.496115 | 0.991259 |
PSNR | 22.315565 | 4.723317 | 9.698197 | 36.125977 |
MSE | 682.852914 | 889.681447 | 15.866557 | 6970.447413 |
MAE | 15.264377 | 9.815629 | 1.971761 | 66.205879 |
Color Error | 11.004399 | 7.5067 | 1.517296 | 51.797058 |
LPIPS | 0.193544 | 0.092814 | 0.009559 | 0.553977 |
FID | 49.714922 | - | - | - |
模型表现不错:
启动web服务:
python colorization_app.py
网页前端使用:
我们的彩色化模型采用U-Net[4]作为生成器的骨干网络。U-Net是一种编码器-解码器结构,增加了跳跃连接(skip connection)以保留细节信息。对于彩色化任务,我们采用unet_256变体,具有8个下采样层。
U-Net生成器的结构可以表示为:
具体而言,它由以下部分组成:
具体的数学表示为:
其中 表示下采样操作, 表示上采样与特征连接操作, 是输入的L通道, 是输出的ab通道预测。
我们使用PatchGAN作为判别器,它仅关注局部图像块(patch)的真实性而非整体图像,这有助于保持色彩的局部一致性和纹理细节。PatchGAN判别器可以表示为:
其中输入是L通道与ab通道(真实或生成)的拼接,输出是一个 的评分图,每个位置对应一个感受野为 的图像区域。判别器包含5个卷积层,中间使用实例归一化(Instance Normalization)和LeakyReLU激活函数。
在训练过程中,网络同时优化两个损失函数:
对于判别器 ,我们使用标准的GAN判别器损失:
对于生成器 ,我们结合GAN损失和L1损失:
其中,我们设置 来强调色彩的准确性。
镜像环境安装了一些别的,新的镜像名称为:kevinchina/deeplearning:2.5.1-cuda12.1-cudnn9-devel-pix2pix-webui
代码仓库:https://github.com/xxddccaa/pytorch-CycleGAN-and-pix2pix-color#
webui启动:
bashcd /ssd/xiedong/image_color
docker run --gpus device=2 --shm-size=32g -it --net host \
-v ./pytorch-CycleGAN-and-pix2pix/:/pytorch-CycleGAN-and-pix2pix/ \
-v ./unpaired_self_datasets/:/data \
kevinchina/deeplearning:2.5.1-cuda12.1-cudnn9-devel-pix2pix-webui bash
cd /pytorch-CycleGAN-and-pix2pix
python colorization_app.py
本文作者:Dong
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!