90天学会GAN--Day3--从MNIST数据集开始
来源:哔哩哔哩 |
时间:2023-06-02 11:09:20
4. 模型训练
4.1 损失函数
GAN的训练优化目标其实就是如下函数:
可以看到,这里有两个loss:一个是训练鉴别器时使用的 D_loss, 另一个是训练生成器时使用的 G_loss。
(资料图片)
而这个模型的目标是要最小化 G_loss, 以及最大化 D_loss。
这里我们使用了Adam优化策略和BCE loss 来优化这两个。 于是可以写出:
4.2 模型迭代
在模型迭代的过程中,我们会做如下步骤:
我们会读取图像和标签(暂时没用)
然后生成一个随机的噪声z 并放入生成器生成一张假的图片,称为fake_img
之后将fake _ image 放入鉴别器得出 fake _ image 的评分
将这个评分与 1 比较得到 G_loss
再将输入的图像和fake_image 加上真假标签后放入鉴别器中得到D _ loss
循环以上过程 opt.epoch 次
由此,我们可以得到这部分的代码:
至此,模型已经训练完毕。
5.保存图片以及模型
这里我们使用 torchvision.utils 库中的 save_image函数来存储图片,用法如下:
我们使用torch.save来保存模型即其中的参数,实际上需要保存的其实就是 generator 和 discriminator 这两个东西,用法如下:
然后使用的时候就只需要load一下就行了:
之后就像之前一样使用generator和discriminator就可以了。
这样做的好处是:validate的时候就不需要重新跑一次所有的程序了,只需要把之前的模型 load 出来用就行了
关键词: