使用生成對抗網絡從隨機噪聲創建數據
已發表: 2022-03-11自從我發現了生成對抗網絡 (GAN) 後,我就對它們著迷了。 GAN 是一種能夠從頭開始生成新數據的神經網絡。 你可以給它一點隨機噪聲作為輸入,它可以生成臥室、鳥類或任何經過訓練生成的真實圖像。
所有科學家都同意的一件事是我們需要更多數據。
可用於在數據有限的情況下生成新數據的 GAN 可以證明是非常有用的。 數據有時可能難以生成、昂貴且耗時。 然而,為了有用,新數據必須足夠現實,以便我們從生成的數據中獲得的任何見解仍然適用於真實數據。 如果您正在訓練貓捕食老鼠,並且您使用的是假老鼠,則最好確保假老鼠實際上看起來像老鼠。
另一種思考方式是 GAN 正在發現數據中的結構,從而使它們能夠生成真實的數據。 如果我們自己看不到該結構或無法使用其他方法將其拉出,這將很有用。
在本文中,您將了解如何使用 GAN 生成新數據。 為了使本教程切合實際,我們將使用來自 Kaggle 的信用卡欺詐檢測數據集。
在我的實驗中,我嘗試使用這個數據集來看看我是否可以讓 GAN 創建足夠真實的數據來幫助我們檢測欺詐案件。 該數據集突出了有限的數據問題:在 285,000 筆交易中,只有 492 筆是欺詐。 492 個欺詐案例並不是一個可供訓練的大型數據集,尤其是在涉及機器學習任務時,人們喜歡擁有大幾個數量級的數據集。 儘管我的實驗結果並不令人驚訝,但我確實在此過程中學到了很多關於 GAN 的知識,我很樂意分享。
在你開始前
在我們深入研究 GAN 的這個領域之前,如果你想快速復習你的機器學習或深度學習技能,你可以看看這兩個相關的博客文章:
- 機器學習理論及其應用簡介:帶有示例的可視化教程
- 深度學習教程:從感知器到深度網絡
為什麼選擇 GAN?
生成對抗網絡 (GAN) 是一種神經網絡架構,與以前的生成方法(例如變分自動編碼器或受限 Bolzman 機器)相比,它已顯示出令人印象深刻的改進。 GAN 已經能夠生成更逼真的圖像(例如,DCGAN),支持圖像之間的風格轉移(參見此處和此處),從文本描述生成圖像(StackGAN),並通過半監督學習從較小的數據集中學習。 由於這些成就,他們在學術和商業領域都引起了極大的興趣。
Facebook 的 AI 研究總監 Yann LeCunn 甚至稱它們是過去十年機器學習領域最激動人心的發展。
基礎
想想你是如何學習的。 你嘗試一些東西,你會得到一些反饋。 你調整你的策略,然後再試一次。
反饋可能以批評、痛苦或利潤的形式出現。 這可能來自您對自己做得如何的判斷。 通常,最有用的反饋是來自另一個人的反饋,因為它不僅僅是一個數字或感覺,而是對您完成任務的好壞的智能評估。
當計算機被訓練完成一項任務時,人類通常以調整參數或算法的形式提供反饋。 當任務定義明確時,這很有效,例如學習將兩個數字相乘。 您可以輕鬆準確地告訴計算機它是如何出錯的。
對於更複雜的任務,例如創建狗的圖像,提供反饋變得更加困難。 圖像是否模糊,它看起來更像一隻貓,還是看起來像任何東西? 可以實現複雜的統計數據,但很難捕獲使圖像看起來真實的所有細節。
人類可以給出一些估計,因為我們在評估視覺輸入方面有很多經驗,但是我們相對較慢,而且我們的評估可能非常主觀。 相反,我們可以訓練一個神經網絡來學習區分真實圖像和生成圖像的任務。
然後,通過讓圖像生成器(也是一個神經網絡)和鑑別器輪流相互學習,它們可以隨著時間的推移而改進。 這兩個網絡,玩這個遊戲,是一個生成對抗網絡。
你可以聽到 GAN 的發明者 Ian Goodfellow 談到在酒吧里關於這個話題的爭論如何導致了一個狂熱的編碼之夜,從而產生了第一個 GAN。 是的,他確實承認他的論文中的酒吧。 您可以從 Ian Goodfellow 關於此主題的博客中了解有關 GAN 的更多信息。
使用 GAN 時存在許多挑戰。 由於涉及的選擇數量眾多,訓練單個神經網絡可能很困難:架構、激活函數、優化方法、學習率和輟學率,僅舉幾例。
GAN 將所有這些選擇加倍並增加了新的複雜性。 生成器和判別器都可能忘記他們之前在訓練中使用的技巧。 這可能導致兩個網絡陷入穩定的解決方案循環,並且不會隨著時間的推移而改善。 一個網絡可能會壓倒另一個網絡,以至於兩者都無法再學習。 或者,生成器可能不會探索很多可能的解決方案空間,僅足以找到現實的解決方案。 最後一種情況稱為模式崩潰。
模式崩潰是指生成器只學習可能的現實模式的一小部分。 例如,如果任務是生成狗的圖像,則生成器可以學習只創建小型棕色狗的圖像。 生成器會錯過由其他大小或顏色的狗組成的所有其他模式。
已經實施了許多策略來解決這個問題,包括批量標準化、在訓練數據中添加標籤,或者通過改變鑑別器判斷生成數據的方式。
人們已經註意到,為數據添加標籤——也就是說,將其分解為類別,幾乎總能提高 GAN 的性能。 例如,生成貓、狗、魚和雪貂的圖像應該更容易,而不是學習生成一般的寵物圖像。
也許 GAN 開發中最重要的突破來自於改變鑑別器評估數據的方式,所以讓我們仔細看看。
在 Goodfellow 等人在 2014 年提出的 GAN 的原始公式中,鑑別器生成給定圖像是真實的或生成的概率的估計值。 鑑別器將被提供一組由真實圖像和生成圖像組成的圖像,它將為每個輸入生成一個估計值。 然後,鑑別器輸出和實際標籤之間的誤差將通過交叉熵損失來衡量。 交叉熵損失可以等同於 Jensen-Shannon 距離度量,Arjovsky 等人在 2017 年初證明了這一點。 這個指標在某些情況下會失敗,而在其他情況下不會指向正確的方向。 該小組表明,Wasserstein 距離度量(也稱為推土機或 EM 距離)在更多情況下工作得更好。
交叉熵損失是鑑別器識別真實圖像和生成圖像的準確程度的度量。 Wasserstein 度量取而代之的是查看真實圖像和生成圖像中每個變量(即每個像素的每種顏色)的分佈,並確定真實數據和生成數據的分佈相距多遠。 Wasserstein 度量標準著眼於將生成的分佈推入真實分佈的形狀所需的質量乘以距離的多少努力,因此別名為“地球移動器距離”。 由於 Wasserstein 度量不再評估圖像是否真實,而是提供對生成的圖像與真實圖像的距離的批評,因此“鑑別器”網絡在 Wasserstein 中被稱為“批評者”網絡建築學。
為了對 GAN 進行更全面的探索,在本文中,我們將探索四種不同的架構:
- GAN:原始(“香草”)GAN
- CGAN:使用類標籤的原始 GAN 的條件版本
- WGAN:Wasserstein GAN(帶有梯度懲罰)
- WCGAN:Wasserstein GAN 的條件版本
但讓我們先看一下我們的數據集。
查看信用卡欺詐數據
我們將使用來自 Kaggle 的信用卡欺詐檢測數據集。
該數據集包含約 285,000 筆交易,其中只有 492 筆是欺詐性交易。 數據由 31 個特徵組成:“時間”、“數量”、“類別”和 28 個額外的匿名特徵。 類特徵是表示交易是否欺詐的標籤,0表示正常,1表示欺詐。 所有數據都是數字和連續的(標籤除外)。 數據集沒有缺失值。 數據集一開始就已經很好了,但我會做更多的清理工作,主要是將所有特徵的均值調整為 0,將標準差調整為 1。 我在這裡的筆記本中更多地描述了我的清潔過程。 現在我只展示最終結果:
人們可以很容易地發現這些分佈中正常數據和欺詐數據之間的差異,但也有很多重疊之處。 我們可以應用一種更快、更強大的機器學習算法來識別對識別欺詐最有用的特徵。 這個算法,xgboost,是一種梯度提升的決策樹算法。 我們將在 70% 的數據集上對其進行訓練,並在剩餘的 30% 上對其進行測試。 我們可以將算法設置為繼續運行,直到它不能提高測試數據集上的召回率(檢測到的欺詐樣本的比例)。 這在測試集上實現了 76% 的召回率,顯然還有改進的空間。 它確實達到了 94% 的精度,這意味著只有 6% 的預測欺詐案例實際上是正常交易。 從這個分析中,我們還得到了一個按其在檢測欺詐中的效用排序的特徵列表。 我們可以使用最重要的特徵來幫助稍後可視化我們的結果。
同樣,如果我們有更多的欺詐數據,我們可能能夠更好地檢測到它。 也就是說,我們可以實現更高的召回率。 我們現在將嘗試使用 GAN 生成新的、真實的欺詐數據,以幫助我們檢測實際的欺詐行為。

使用 GAN 生成新的信用卡數據
為了將各種 GAN 架構應用到這個數據集,我將使用 GAN-Sandbox,它使用 Python 使用 Keras 庫和 TensorFlow 後端實現了許多流行的 GAN 架構。 我的所有結果都可以在此處作為 Jupyter 筆記本獲得。 如果您需要簡單的設置,所有必要的庫都包含在 Kaggle/Python Docker 映像中。
GAN-Sandbox 中的示例是為圖像處理而設置的。 生成器為每個像素生成具有 3 個顏色通道的 2D 圖像,並且鑑別器/批評者被配置為評估此類數據。 在網絡的各層之間使用卷積變換來利用圖像數據的空間結構。 卷積層中的每個神經元僅與一小組輸入和輸出(例如,圖像中的相鄰像素)一起工作,以允許學習空間關係。 我們的信用卡數據集在變量之間缺乏任何空間結構,因此我將捲積網絡轉換為具有密集連接層的網絡。 密集連接層中的神經元連接到該層的每個輸入和輸出,允許網絡學習其自身特徵之間的關係。 我將為每個架構使用此設置。
我將評估的第一個 GAN 將生成器網絡與鑑別器網絡對比,利用鑑別器的交叉熵損失來訓練網絡。 這是原始的“香草”GAN 架構。 我將評估的第二個 GAN 以條件 GAN (CGAN) 的方式將類標籤添加到數據中。 這個 GAN 在數據中還有一個變量,即類標籤。 第三個 GAN 將使用 Wasserstein 距離度量來訓練網絡(WGAN),最後一個將使用類標籤和 Wasserstein 距離度量(WCGAN)。
我們將使用包含所有 492 個欺詐交易的訓練數據集來訓練各種 GAN。 我們可以向欺詐數據集添加類以促進條件 GAN 架構。 我在筆記本中探索了幾種不同的聚類方法,並使用了 KMeans 分類,將欺詐數據分為 2 個類。
我將對每個 GAN 進行 5000 輪訓練,並在此過程中檢查結果。 在圖 4 中,隨著訓練的進行,我們可以看到來自不同 GAN 架構的實際欺詐數據和生成的欺詐數據。 我們可以看到實際欺詐數據分為 2 個 KMeans 類,用最能區分這兩個類的 2 個維度(特徵 V10 和 V17 從 PCA 轉換特徵)繪製。 不使用類信息的兩個 GAN,GAN 和 WGAN,它們生成的輸出都作為一個類。 條件架構 CGAN 和 WCGAN 按類顯示它們生成的數據。 在步驟 0,所有生成的數據都顯示了饋送到生成器的隨機輸入的正態分佈。
我們可以看到,原始的 GAN 架構開始學習實際數據的形狀和範圍,但隨後向小分佈折疊。 這就是前面討論的模式崩潰。 生成器已經學習了鑑別器很難檢測為假的小範圍數據。 CGAN 架構做得更好,分散並接近每類欺詐數據的分佈,但隨後出現模式崩潰,如步驟 5000 所示。
WGAN 不會經歷 GAN 和 CGAN 架構所表現出的模式崩潰。 即使沒有類別信息,它也開始假設實際欺詐數據的非正態分佈。 WCGAN 架構的性能類似,並且能夠生成單獨的數據類別。
我們可以使用之前用於欺詐檢測的相同 xgboost 算法來評估數據的真實性。 它快速而強大,無需太多調整即可使用。 我們將使用一半的實際欺詐數據(246 個樣本)和相同數量的 GAN 生成的示例來訓練 xgboost 分類器。 然後我們將使用另一半實際欺詐數據和一組不同的 246 個 GAN 生成示例來測試 xgboost 分類器。 這種正交方法(在實驗意義上)將為我們提供一些關於生成器在生成真實數據方面的成功程度的指示。 對於完全真實的生成數據,xgboost 算法應該達到 0.50 (50%) 的準確度——換句話說,它並不比猜測好。
我們可以看到 GAN 生成數據的 xgboost 準確度首先下降,然後在訓練步驟 1000 後隨著模式崩潰的出現而增加。CGAN 架構在 2000 步後獲得了更真實的數據,但隨後該網絡的模式崩潰設置為好吧。 WGAN 和 WCGAN 架構更快地獲得更真實的數據,並隨著訓練的進行繼續學習。 WCGAN 似乎比 WGAN 沒有太多優勢,這表明這些創建的類可能對 Wasserstein GAN 架構沒有用處。
您可以從此處和此處了解有關 WGAN 架構的更多信息。
WGAN 和 WCGAN 架構中的批評者網絡正在學習計算給定數據集與實際欺詐數據之間的 Wasserstein(Earth-mover,EM)距離。 理想情況下,它將測量實際欺詐數據樣本的接近於零的距離。 然而,批評家正在學習如何進行這種計算。 只要它為生成的數據測量比真實數據更大的距離,網絡就可以改進。 我們可以觀察生成數據和真實數據的 Wasserstein 距離之間的差異如何在訓練過程中發生變化。 如果它停滯不前,那麼進一步的培訓可能無濟於事。 我們可以在圖 6 中看到,該數據集上的 WGAN 和 WCGAN 似乎都有進一步的改進。
我們學到了什麼?
現在我們可以測試我們是否能夠生成足夠真實的新欺詐數據來幫助我們檢測實際的欺詐數據。 我們可以採用經過訓練的獲得最低準確度分數的生成器並使用它來生成數據。 對於我們的基本訓練集,我們將使用 70% 的非欺詐數據(199,020 個案例)和 100 個欺詐數據案例(約 20% 的欺詐數據)。 然後,我們將嘗試將不同數量的真實或生成的欺詐數據添加到此訓練集中,最多 344 個案例(佔欺詐數據的 70%)。 對於測試集,我們將使用另外 30% 的非欺詐案例(85,295 例)和欺詐案例(148 例)。 我們可以嘗試添加未經訓練的 GAN 和經過最佳訓練的 GAN 生成的數據,以測試生成的數據是否比隨機噪聲更好。 從我們的測試來看,我們最好的架構似乎是訓練步驟 4800 的 WCGAN,它實現了 70% 的 xgboost 準確度(請記住,理想情況下,準確度為 50%)。 所以我們將使用這個架構來生成新的欺詐數據。
我們可以在圖 7 中看到,召回率(在測試集中準確識別的實際欺詐樣本的比例)沒有增加,因為我們使用更多生成的欺詐數據進行訓練。 xgboost 分類器能夠保留它用於從 100 個真實案例中識別欺詐的所有信息,並且不會被額外生成的數據弄糊塗,即使從數十萬個正常案例中挑選出來也是如此。 毫不奇怪,來自未經訓練的 WCGAN 生成的數據沒有幫助或傷害。 但訓練有素的 WCGAN 生成的數據也無濟於事。 看來數據不夠真實。 我們可以在圖 7 中看到,當使用實際欺詐數據來補充訓練集時,召回率顯著增加。 如果 WCGAN 剛剛學會復制訓練示例,完全沒有創意,它可能會獲得更高的召回率,就像我們在真實數據中看到的那樣。
超越無限
雖然我們無法生成足夠真實的信用卡欺詐數據來幫助我們檢測實際的欺詐行為,但我們對這些方法幾乎沒有觸及表面。 我們可以用更大的網絡訓練更長時間,並為我們在本文中嘗試的架構調整參數。 xgboost 準確性和鑑別器損失的趨勢表明,更多的訓練將有助於 WGAN 和 WCGAN 架構。 另一種選擇是重新審視我們執行的數據清理,也許設計一些新變量或改變我們是否以及如何解決特徵中的偏斜。 也許欺詐數據的不同分類方案會有所幫助。
我們也可以嘗試其他 GAN 架構。 DRAGAN 有理論和實驗證據表明它比 Wasserstein GAN 訓練更快、更穩定。 我們可以整合利用半監督學習的方法,這些方法在從有限的訓練集中學習方面顯示出了希望(參見“訓練 GAN 的改進技術”)。 我們可以嘗試一種為我們提供人類可理解模型的架構,這樣我們或許能夠更好地理解數據的結構(參見 InfoGAN)。
我們還應該關注該領域的新發展,最後但同樣重要的是,我們可以在這個快速發展的領域中創造自己的創新。
您可以在此 GitHub 存儲庫中找到本文的所有相關代碼。