90天学会GAN--Day3--从MNIST数据集开始

4. 模型训练

4.1 损失函数

GAN的训练优化目标其实就是如下函数:

可以看到,这里有两个loss:一个是训练鉴别器时使用的 D_loss, 另一个是训练生成器时使用的 G_loss。


(资料图片)

而这个模型的目标是要最小化 G_loss, 以及最大化 D_loss。 

这里我们使用了Adam优化策略和BCE loss 来优化这两个。 于是可以写出:

4.2 模型迭代

在模型迭代的过程中,我们会做如下步骤: 

我们会读取图像和标签(暂时没用)

然后生成一个随机的噪声z 并放入生成器生成一张假的图片,称为fake_img

之后将fake _ image 放入鉴别器得出 fake _ image 的评分

将这个评分与 1 比较得到 G_loss

再将输入的图像和fake_image 加上真假标签后放入鉴别器中得到D _ loss

循环以上过程 opt.epoch 次

由此,我们可以得到这部分的代码:

至此,模型已经训练完毕。

5.保存图片以及模型

这里我们使用 torchvision.utils 库中的 save_image函数来存储图片,用法如下:

我们使用torch.save来保存模型即其中的参数,实际上需要保存的其实就是 generator 和 discriminator 这两个东西,用法如下: 

然后使用的时候就只需要load一下就行了:

之后就像之前一样使用generator和discriminator就可以了。

这样做的好处是:validate的时候就不需要重新跑一次所有的程序了,只需要把之前的模型 load 出来用就行了

关键词: