使用生成对抗网络从随机噪声创建数据

已发表: 2022-03-11

自从我发现了生成对抗网络 (GAN) 后,我就对它们着迷了。 GAN 是一种能够从头开始生成新数据的神经网络。 你可以给它一点随机噪声作为输入,它可以生成卧室、鸟类或任何经过训练生成的真实图像。

所有科学家都同意的一件事是我们需要更多数据。

可用于在数据有限的情况下生成新数据的 GAN 可以证明是非常有用的。 数据有时可能难以生成、昂贵且耗时。 然而,为了有用,新数据必须足够现实,以便我们从生成的数据中获得的任何见解仍然适用于真实数据。 如果您正在训练猫捕食老鼠,并且您使用的是假老鼠,则最好确保假老鼠实际上看起来像老鼠。

另一种思考方式是 GAN 正在发现数据中的结构,从而使它们能够生成真实的数据。 如果我们自己看不到该结构或无法使用其他方法将其拉出,这将很有用。

生成对抗网络

在本文中,您将了解如何使用 GAN 生成新数据。 为了使本教程切合实际,我们将使用来自 Kaggle 的信用卡欺诈检测数据集。

在我的实验中,我尝试使用这个数据集来看看我是否可以让 GAN 创建足够真实的数据来帮助我们检测欺诈案件。 该数据集突出了有限的数据问题:在 285,000 笔交易中,只有 492 笔是欺诈。 492 个欺诈案例并不是一个可供训练的大型数据集,尤其是在涉及机器学习任务时,人们喜欢拥有大几个数量级的数据集。 尽管我的实验结果并不令人惊讶,但我确实在此过程中学到了很多关于 GAN 的知识,我很乐意分享。

在你开始前

在我们深入研究 GAN 的这个领域之前,如果你想快速复习你的机器学习或深度学习技能,你可以看看这两个相关的博客文章:

  • 机器学习理论及其应用简介:带有示例的可视化教程
  • 深度学习教程:从感知器到深度网络

为什么选择 GAN?

生成对抗网络 (GAN) 是一种神经网络架构,与以前的生成方法(例如变分自动编码器或受限 Bolzman 机器)相比,它已显示出令人印象深刻的改进。 GAN 已经能够生成更逼真的图像(例如,DCGAN),实现图像之间的风格转移(参见此处和此处),从文本描述生成图像(StackGAN),并通过半监督学习从较小的数据集中学习。 由于这些成就,他们在学术和商业领域都引起了极大的兴趣。

Facebook 的 AI 研究总监 Yann LeCunn 甚至称它们是过去十年机器学习领域最激动人心的发展。

基础

想想你是如何学习的。 你尝试一些东西,你会得到一些反馈。 你调整你的策略,然后再试一次。

反馈可能以批评、痛苦或利润的形式出现。 这可能来自您对自己做得如何的判断。 通常,最有用的反馈是来自另一个人的反馈,因为它不仅仅是一个数字或感觉,而是对您完成任务的好坏的智能评估。

当计算机被训练完成一项任务时,人类通常以调整参数或算法的形式提供反馈。 当任务定义明确时,这很有效,例如学习将两个数字相乘。 您可以轻松准确地告诉计算机它是如何出错的。

对于更复杂的任务,例如创建狗的图像,提供反馈变得更加困难。 图像是否模糊,它看起来更像一只猫,还是看起来像任何东西? 可以实现复杂的统计数据,但很难捕获使图像看起来真实的所有细节。

人类可以给出一些估计,因为我们在评估视觉输入方面有很多经验,但是我们相对较慢,而且我们的评估可能非常主观。 相反,我们可以训练一个神经网络来学习区分真实图像和生成图像的任务。

然后,通过让图像生成器(也是一个神经网络)和鉴别器轮流相互学习,它们可以随着时间的推移而改进。 这两个网络,玩这个游戏,是一个生成对抗网络。

你可以听到 GAN 的发明者 Ian Goodfellow 谈到在酒吧里关于这个话题的争论如何导致了一个狂热的编码之夜,从而产生了第一个 GAN。 是的,他确实承认他的论文中的酒吧。 您可以从 Ian Goodfellow 关于此主题的博客中了解有关 GAN 的更多信息。

GAN示意图

使用 GAN 时存在许多挑战。 由于涉及的选择数量众多,训练单个神经网络可能很困难:架构、激活函数、优化方法、学习率和辍学率,仅举几例。

GAN 将所有这些选择加倍并增加了新的复杂性。 生成器和判别器都可能忘记他们之前在训练中使用的技巧。 这可能导致两个网络陷入稳定的解决方案循环,并且不会随着时间的推移而改善。 一个网络可能会压倒另一个网络,以至于两者都无法再学习。 或者,生成器可能不会探索很多可能的解决方案空间,仅足以找到现实的解决方案。 最后一种情况称为模式崩溃。

模式崩溃是指生成器只学习可能的现实模式的一小部分。 例如,如果任务是生成狗的图像,则生成器可以学习只创建小型棕色狗的图像。 生成器会错过由其他大小或颜色的狗组成的所有其他模式。

已经实施了许多策略来解决这个问题,包括批量标准化、在训练数据中添加标签,或者通过改变鉴别器判断生成数据的方式。

人们已经注意到,为数据添加标签——也就是说,将其分解为类别,几乎总能提高 GAN 的性能。 例如,生成猫、狗、鱼和雪貂的图像应该更容易,而不是学习生成一般的宠物图像。

也许 GAN 开发中最重要的突破来自于改变鉴别器评估数据的方式,所以让我们仔细看看。

在 Goodfellow 等人在 2014 年提出的 GAN 的原始公式中,鉴别器生成给定图像是真实的或生成的概率的估计值。 鉴别器将被提供一组由真实图像和生成图像组成的图像,它将为每个输入生成一个估计值。 然后,鉴别器输出和实际标签之间的误差将通过交叉熵损失来衡量。 交叉熵损失可以等同于 Jensen-Shannon 距离度量,Arjovsky 等人在 2017 年初证明了这一点。 这个指标在某些情况下会失败,而在其他情况下不会指向正确的方向。 该小组表明,Wasserstein 距离度量(也称为推土机或 EM 距离)在更多情况下工作得更好。

交叉熵损失是鉴别器识别真实图像和生成图像的准确程度的度量。 Wasserstein 度量取而代之的是查看真实图像和生成图像中每个变量(即每个像素的每种颜色)的分布,并确定真实数据和生成数据的分布相距多远。 Wasserstein 度量以质量乘以距离的形式考察将生成的分布推入真实分布的形状需要付出多少努力,因此别名为“地球移动器距离”。 由于 Wasserstein 度量不再评估图像是否真实,而是提供对生成的图像与真实图像的距离的批评,因此“鉴别器”网络在 Wasserstein 中被称为“批评者”网络建筑学。

为了对 GAN 进行更全面的探索,在本文中,我们将探索四种不同的架构:

  • GAN:原始(“香草”)GAN
  • CGAN:使用类标签的原始 GAN 的条件版本
  • WGAN:Wasserstein GAN(带有梯度惩罚)
  • WCGAN:Wasserstein GAN 的条件版本

但让我们先看一下我们的数据集。

查看信用卡欺诈数据

我们将使用来自 Kaggle 的信用卡欺诈检测数据集。

该数据集包含约 285,000 笔交易,其中只有 492 笔是欺诈性交易。 数据由 31 个特征组成:“时间”、“数量”、“类别”和 28 个额外的匿名特征。 类特征是表示交易是否欺诈的标签,0表示正常,1表示欺诈。 所有数据都是数字和连续的(标签除外)。 数据集没有缺失值。 数据集一开始就已经很好了,但我会做更多的清理工作,主要是将所有特征的均值调整为 0,将标准差调整为 1。 我在这里的笔记本中更多地描述了我的清洁过程。 现在我只展示最终结果:

特征与类图

人们可以很容易地发现这些分布中正常数据和欺诈数据之间的差异,但也有很多重叠之处。 我们可以应用一种更快、更强大的机器学习算法来识别对识别欺诈最有用的特征。 这个算法,xgboost,是一种梯度提升的决策树算法。 我们将在 70% 的数据集上对其进行训练,并在剩余的 30% 上对其进行测试。 我们可以将算法设置为继续运行,直到它不能提高测试数据集上的召回率(检测到的欺诈样本的比例)。 这在测试集上实现了 76% 的召回率,显然还有改进的空间。 它确实达到了 94% 的精度,这意味着只有 6% 的预测欺诈案例实际上是正常交易。 从这个分析中,我们还得到了一个按其在检测欺诈中的效用排序的特征列表。 我们可以使用最重要的特征来帮助稍后可视化我们的结果。

同样,如果我们有更多的欺诈数据,我们可能能够更好地检测到它。 也就是说,我们可以实现更高的召回率。 我们现在将尝试使用 GAN 生成新的、真实的欺诈数据,以帮助我们检测实际的欺诈行为。

使用 GAN 生成新的信用卡数据

为了将各种 GAN 架构应用到这个数据集,我将使用 GAN-Sandbox,它使用 Python 使用 Keras 库和 TensorFlow 后端实现了许多流行的 GAN 架构。 我的所有结果都可以在此处作为 Jupyter 笔记本获得。 如果您需要简单的设置,所有必要的库都包含在 Kaggle/Python Docker 映像中。

GAN-Sandbox 中的示例是为图像处理而设置的。 生成器为每个像素生成具有 3 个颜色通道的 2D 图像,并且鉴别器/批评者被配置为评估此类数据。 在网络的各层之间使用卷积变换来利用图像数据的空间结构。 卷积层中的每个神经元仅与一小组输入和输出(例如,图像中的相邻像素)一起工作,以允许学习空间关系。 我们的信用卡数据集在变量之间缺乏任何空间结构,因此我将卷积网络转换为具有密集连接层的网络。 密集连接层中的神经元连接到该层的每个输入和输出,允许网络学习其自身特征之间的关系。 我将为每个架构使用此设置。

我将评估的第一个 GAN 将生成器网络与鉴别器网络对比,利用鉴别器的交叉熵损失来训练网络。 这是原始的“香草”GAN 架构。 我将评估的第二个 GAN 以条件 GAN (CGAN) 的方式将类标签添加到数据中。 这个 GAN 在数据中还有一个变量,即类标签。 第三个 GAN 将使用 Wasserstein 距离度量来训练网络(WGAN),最后一个将使用类标签和 Wasserstein 距离度量(WCGAN)。

GAN 架构

我们将使用包含所有 492 个欺诈交易的训练数据集来训练各种 GAN。 我们可以向欺诈数据集添加类以促进条件 GAN 架构。 我在笔记本中探索了几种不同的聚类方法,并使用了 KMeans 分类,将欺诈数据分为 2 个类。

我将对每个 GAN 进行 5000 轮训练,并在此过程中检查结果。 在图 4 中,随着训练的进行,我们可以看到来自不同 GAN 架构的实际欺诈数据和生成的欺诈数据。 我们可以看到实际欺诈数据分为 2 个 KMeans 类,用最能区分这两个类的 2 个维度(特征 V10 和 V17 从 PCA 转换特征)绘制。 不使用类信息的两个 GAN,GAN 和 WGAN,它们生成的输出都作为一个类。 条件架构 CGAN 和 WCGAN 按类显示它们生成的数据。 在步骤 0,所有生成的数据都显示了馈送到生成器的随机输入的正态分布。

GAN输出比较

我们可以看到,原始的 GAN 架构开始学习实际数据的形状和范围,但随后向小分布折叠。 这就是前面讨论的模式崩溃。 生成器已经学习了鉴别器很难检测为假的小范围数据。 CGAN 架构做得更好,分散并接近每类欺诈数据的分布,但随后出现模式崩溃,如步骤 5000 所示。

WGAN 不会经历 GAN 和 CGAN 架构所表现出的模式崩溃。 即使没有类别信息,它也开始假设实际欺诈数据的非正态分布。 WCGAN 架构的性能类似,并且能够生成单独的数据类别。

我们可以使用之前用于欺诈检测的相同 xgboost 算法来评估数据的真实性。 它快速而强大,无需太多调整即可使用。 我们将使用一半的实际欺诈数据(246 个样本)和相同数量的 GAN 生成的示例来训练 xgboost 分类器。 然后我们将使用另一半实际欺诈数据和一组不同的 246 个 GAN 生成示例来测试 xgboost 分类器。 这种正交方法(在实验意义上)将为我们提供一些关于生成器在生成真实数据方面的成功程度的指示。 对于完全真实的生成数据,xgboost 算法应该达到 0.50 (50%) 的准确度——换句话说,它并不比猜测好。

准确性

我们可以看到 GAN 生成数据的 xgboost 准确度首先下降,然后在训练步骤 1000 后随着模式崩溃的出现而增加。CGAN 架构在 2000 步后获得了更真实的数据,但随后该网络的模式崩溃设置为好吧。 WGAN 和 WCGAN 架构更快地获得更真实的数据,并随着训练的进行继续学习。 WCGAN 似乎比 WGAN 没有太多优势,这表明这些创建的类可能对 Wasserstein GAN 架构没有用处。

您可以从此处和此处了解有关 WGAN 架构的更多信息。

WGAN 和 WCGAN 架构中的批评者网络正在学习计算给定数据集与实际欺诈数据之间的 Wasserstein(Earth-mover,EM)距离。 理想情况下,它将测量实际欺诈数据样本的接近于零的距离。 然而,批评家正在学习如何进行这种计算。 只要它为生成的数据测量比真实数据更大的距离,网络就可以改进。 我们可以观察生成数据和真实数据的 Wasserstein 距离之间的差异如何在训练过程中发生变化。 如果它停滞不前,那么进一步的培训可能无济于事。 我们可以在图 6 中看到,该数据集上的 WGAN 和 WCGAN 似乎都有进一步的改进。

电磁距离估计

我们学到了什么?

现在我们可以测试我们是否能够生成足够真实的新欺诈数据来帮助我们检测实际的欺诈数据。 我们可以采用经过训练的获得最低准确度分数的生成器并使用它来生成数据。 对于我们的基本训练集,我们将使用 70% 的非欺诈数据(199,020 个案例)和 100 个欺诈数据案例(约 20% 的欺诈数据)。 然后,我们将尝试将不同数量的真实或生成的欺诈数据添加到此训练集中,最多 344 个案例(占欺诈数据的 70%)。 对于测试集,我们将使用另外 30% 的非欺诈案例(85,295 例)和欺诈案例(148 例)。 我们可以尝试添加未经训练的 GAN 和经过最佳训练的 GAN 生成的数据,以测试生成的数据是否比随机噪声更好。 从我们的测试来看,我们最好的架构似乎是训练步骤 4800 的 WCGAN,它实现了 70% 的 xgboost 准确度(请记住,理想情况下,准确度为 50%)。 所以我们将使用这个架构来生成新的欺诈数据。

我们可以在图 7 中看到,召回率(在测试集中准确识别的实际欺诈样本的比例)没有增加,因为我们使用更多生成的欺诈数据进行训练。 xgboost 分类器能够保留它用于从 100 个真实案例中识别欺诈的所有信息,并且不会被额外生成的数据弄糊涂,即使从数十万个正常案例中挑选出来也是如此。 毫不奇怪,来自未经训练的 WCGAN 生成的数据没有帮助或伤害。 但训练有素的 WCGAN 生成的数据也无济于事。 看来数据不够真实。 我们可以在图 7 中看到,当使用实际欺诈数据来补充训练集时,召回率显着增加。 如果 WCGAN 刚刚学会复制训练示例,完全没有创意,它可能会获得更高的召回率,就像我们在真实数据中看到的那样。

附加数据的影响

超越无限

虽然我们无法生成足够真实的信用卡欺诈数据来帮助我们检测实际的欺诈行为,但我们对这些方法几乎没有触及表面。 我们可以用更大的网络训练更长时间,并为我们在本文中尝试的架构调整参数。 xgboost 准确性和鉴别器损失的趋势表明,更多的训练将有助于 WGAN 和 WCGAN 架构。 另一种选择是重新审视我们执行的数据清理,也许设计一些新变量或改变我们是否以及如何解决特征中的偏斜。 也许欺诈数据的不同分类方案会有所帮助。

我们也可以尝试其他 GAN 架构。 DRAGAN 有理论和实验证据表明它比 Wasserstein GAN 训练更快、更稳定。 我们可以整合利用半监督学习的方法,这些方法在从有限的训练集中学习方面显示出了希望(参见“训练 GAN 的改进技术”)。 我们可以尝试一种为我们提供人类可理解模型的架构,这样我们或许能够更好地理解数据的结构(参见 InfoGAN)。

我们还应该关注该领域的新发展,最后但同样重要的是,我们可以在这个快速发展的领域中创造自己的创新。

您可以在此 GitHub 存储库中找到本文的所有相关代码。

相关: TensorFlow 中梯度下降的许多应用