Crea dati dal rumore casuale con le reti contraddittorie generative
Pubblicato: 2022-03-11Da quando ho scoperto le reti generative contraddittorio (GAN), ne sono rimasto affascinato. Un GAN è un tipo di rete neurale in grado di generare nuovi dati da zero. Puoi alimentarlo con un po' di rumore casuale come input e può produrre immagini realistiche di camere da letto, uccelli o qualsiasi cosa sia stato addestrato a generare.
Una cosa su cui tutti gli scienziati possono essere d'accordo è che abbiamo bisogno di più dati.
I GAN, che possono essere utilizzati per produrre nuovi dati in situazioni di dati limitati, possono rivelarsi davvero utili. La generazione dei dati a volte può essere difficile, costosa e dispendiosa in termini di tempo. Per essere utili, tuttavia, i nuovi dati devono essere sufficientemente realistici in modo che qualsiasi intuizione che otteniamo dai dati generati si applichi ancora ai dati reali. Se stai addestrando un gatto a cacciare topi e stai usando topi falsi, faresti meglio a assicurarti che i topi falsi assomiglino davvero.
Un altro modo di pensarci è che i GAN stanno scoprendo una struttura nei dati che consente loro di creare dati realistici. Questo può essere utile se non possiamo vedere quella struttura da soli o non possiamo estrarla con altri metodi.
In questo articolo imparerai come utilizzare i GAN per generare nuovi dati. Per mantenere questo tutorial realistico, utilizzeremo il set di dati di rilevamento delle frodi con carta di credito di Kaggle.
Nei miei esperimenti, ho provato a utilizzare questo set di dati per vedere se riesco a ottenere un GAN per creare dati sufficientemente realistici da aiutarci a rilevare casi fraudolenti. Questo set di dati evidenzia il problema limitato dei dati: su 285.000 transazioni, solo 492 sono frode. 492 casi di frode non sono un set di dati di grandi dimensioni su cui allenarsi, soprattutto quando si tratta di attività di apprendimento automatico in cui alle persone piace avere set di dati di diversi ordini di grandezza più grandi. Sebbene i risultati del mio esperimento non siano stati sorprendenti, ho imparato molto sui GAN lungo il percorso che sono felice di condividere.
Prima che inizi
Prima di addentrarci in questo regno dei GAN, se vuoi rispolverare rapidamente le tue capacità di machine learning o deep learning, puoi dare un'occhiata a questi due post correlati sul blog:
- Un'introduzione alla teoria dell'apprendimento automatico e alla sua applicazione: un tutorial visivo con esempi
- Un tutorial di deep learning: dai perceptron alle reti profonde
Perché i GAN?
Le reti generative contraddittorio (GAN) sono un'architettura di rete neurale che ha mostrato miglioramenti impressionanti rispetto ai metodi generativi precedenti, come codificatori automatici variazionali o macchine boltzman limitate. I GAN sono stati in grado di generare immagini più realistiche (ad esempio, DCGAN), abilitare il trasferimento di stile tra immagini (vedi qui e qui), generare immagini da descrizioni di testo (StackGAN) e imparare da set di dati più piccoli tramite l'apprendimento semi-supervisionato. A causa di questi risultati, stanno generando molto interesse sia nel settore accademico che in quello commerciale.
Il direttore della ricerca sull'intelligenza artificiale di Facebook, Yann LeCunn, li ha persino definiti lo sviluppo più entusiasmante dell'apprendimento automatico nell'ultimo decennio.
Le basi
Pensa a come impari. Se provi qualcosa, ottieni dei feedback. Regoli la tua strategia e riprovi.
Il feedback può arrivare sotto forma di critica, dolore o profitto. Potrebbe derivare dal tuo stesso giudizio su quanto bene hai fatto. Spesso, il feedback più utile è il feedback che arriva da un'altra persona, perché non è solo un numero o una sensazione, ma una valutazione intelligente di quanto bene hai svolto il compito.
Quando un computer viene addestrato per un'attività, l'essere umano di solito fornisce il feedback sotto forma di parametri o algoritmi regolati. Funziona bene quando l'attività è ben definita, come imparare a moltiplicare due numeri. Puoi facilmente ed esattamente dire al computer come si è sbagliato.
Con un compito più complicato, come creare un'immagine di cane, diventa più difficile fornire un feedback. L'immagine è sfocata, assomiglia più a un gatto o assomiglia a qualcosa? Potrebbero essere implementate statistiche complesse, ma sarebbe difficile catturare tutti i dettagli che fanno sembrare reale un'immagine.
Un essere umano può dare qualche stima, perché abbiamo molta esperienza nel valutare l'input visivo, ma siamo relativamente lenti e le nostre valutazioni possono essere altamente soggettive. Potremmo invece addestrare una rete neurale per imparare il compito di discriminare tra immagini reali e generate.
Quindi, lasciando che il generatore di immagini (anch'esso una rete neurale) e il discriminatore si alternino imparino l'uno dall'altro, possono migliorare nel tempo. Queste due reti, giocando a questo gioco, sono una rete contraddittoria generativa.
Puoi ascoltare l'inventore dei GAN, Ian Goodfellow, parlare di come una discussione in un bar su questo argomento abbia portato a una notte febbrile di codifica che ha portato al primo GAN. E sì, riconosce la sbarra nel suo giornale. Puoi saperne di più sui GAN dal blog di Ian Goodfellow su questo argomento.
Ci sono una serie di sfide quando si lavora con i GAN. L'addestramento di una singola rete neurale può essere difficile a causa del numero di scelte coinvolte: architettura, funzioni di attivazione, metodo di ottimizzazione, tasso di apprendimento e tasso di abbandono, solo per citarne alcuni.
I GAN raddoppiano tutte queste scelte e aggiungono nuove complessità. Sia il generatore che il discriminatore possono dimenticare i trucchi che hanno usato in precedenza nel loro addestramento. Ciò può portare le due reti a rimanere intrappolate in un ciclo stabile di soluzioni che non migliorano nel tempo. Una rete può sopraffare l'altra rete, in modo tale che nessuna delle due possa più imparare. Oppure, il generatore potrebbe non esplorare gran parte del possibile spazio di soluzione, ma solo quanto basta per trovare soluzioni realistiche. Quest'ultima situazione è nota come collasso modale.
Il collasso della modalità si verifica quando il generatore apprende solo un piccolo sottoinsieme delle possibili modalità realistiche. Ad esempio, se il compito è generare immagini di cani, il generatore potrebbe imparare a creare solo immagini di piccoli cani marroni. Il generatore avrebbe perso tutte le altre modalità composte da cani di altre taglie o colori.
Sono state implementate molte strategie per affrontare questo problema, inclusa la normalizzazione batch, l'aggiunta di etichette nei dati di addestramento o la modifica del modo in cui il discriminatore giudica i dati generati.
Le persone hanno notato che l'aggiunta di etichette ai dati, ovvero la suddivisione in categorie, migliora quasi sempre le prestazioni dei GAN. Invece di imparare a generare immagini di animali domestici in generale, dovrebbe essere più facile generare immagini di gatti, cani, pesci e furetti, ad esempio.
Forse le scoperte più significative nello sviluppo del GAN sono arrivate in termini di cambiamento del modo in cui il discriminatore valuta i dati, quindi diamo un'occhiata più da vicino.
Nella formulazione originale dei GAN nel 2014 di Goodfellow et al., il discriminatore genera una stima della probabilità che una data immagine fosse reale o generata. Al discriminatore verrebbe fornito un insieme di immagini che consisteva sia in immagini reali che generate e genererebbe una stima per ciascuno di questi input. L'errore tra l'uscita del discriminatore e le etichette effettive verrebbe quindi misurato dalla perdita di entropia incrociata. La perdita di entropia incrociata può essere equiparata alla metrica della distanza di Jensen-Shannon ed è stata mostrata all'inizio del 2017 da Arjovsky et al. che questa metrica fallirebbe in alcuni casi e non punterebbe nella giusta direzione in altri casi. Questo gruppo ha mostrato che la metrica della distanza di Wasserstein (nota anche come movimento terra o distanza EM) funzionava e funzionava meglio in molti altri casi.
La perdita di entropia incrociata è una misura della precisione con cui il discriminatore ha identificato le immagini reali e generate. La metrica di Wasserstein invece esamina la distribuzione di ogni variabile (cioè, ogni colore di ogni pixel) nelle immagini reali e generate, e determina quanto sono distanti le distribuzioni per i dati reali e generati. La metrica di Wasserstein esamina quanto sforzo, in termini di massa moltiplicata per la distanza, sarebbe necessario per portare la distribuzione generata nella forma della distribuzione reale, da cui il nome alternativo "distanza del motore della terra". Poiché la metrica di Wasserstein non valuta più se un'immagine è reale o meno, ma fornisce invece critiche su quanto siano lontane le immagini generate dalle immagini reali, la rete "discriminatrice" è indicata come la rete "critica" nel Wasserstein architettura.
Per un'esplorazione leggermente più completa dei GAN, in questo articolo esploreremo quattro diverse architetture:
- GAN: L'originale ("vaniglia") GAN
- CGAN: una versione condizionale del GAN originale che fa uso di etichette di classe
- WGAN: The Wasserstein GAN (con gradiente di penalità)
- WCGAN: una versione condizionale del GAN di Wasserstein
Ma diamo prima un'occhiata al nostro set di dati.
Uno sguardo ai dati sulle frodi con carta di credito
Lavoreremo con il set di dati di rilevamento delle frodi con carta di credito di Kaggle.
Il set di dati è costituito da circa 285.000 transazioni, di cui solo 492 fraudolente. I dati sono costituiti da 31 funzionalità: "tempo", "importo", "classe" e 28 funzionalità aggiuntive anonime. Il privilegio di classe è l'etichetta che indica se una transazione è fraudolenta o meno, con 0 che indica normale e 1 che indica frode. Tutti i dati sono numerici e continui (tranne l'etichetta). Il set di dati non ha valori mancanti. Il set di dati è già in buone condizioni per cominciare, ma farò un po' più di pulizia, per lo più semplicemente regolando la media di tutte le funzionalità su zero e le deviazioni standard su uno. Ho descritto di più il mio processo di pulizia nel taccuino qui. Per ora vi mostro solo il risultato finale:
Si possono facilmente individuare differenze tra i dati normali e di frode in queste distribuzioni, ma c'è anche molta sovrapposizione. Possiamo applicare uno degli algoritmi di machine learning più veloci e potenti per identificare le funzionalità più utili per identificare le frodi. Questo algoritmo, xgboost, è un algoritmo dell'albero decisionale con gradiente. Lo addestreremo sul 70% del set di dati e lo testeremo sul restante 30%. Possiamo impostare l'algoritmo per continuare fino a quando non migliora il richiamo (la frazione di campioni di frode rilevati) sul set di dati di test. Ciò consente di ottenere un richiamo del 76% sul set di test, il che lascia chiaramente margini di miglioramento. Raggiunge una precisione del 94%, il che significa che solo il 6% dei casi di frode previsti erano in realtà transazioni normali. Da questa analisi, otteniamo anche un elenco di funzionalità ordinate in base alla loro utilità nel rilevamento delle frodi. Possiamo utilizzare le funzionalità più importanti per visualizzare i nostri risultati in un secondo momento.

Ancora una volta, se avessimo più dati sulle frodi, potremmo essere in grado di rilevarli meglio. Cioè, potremmo ottenere un richiamo più elevato. Ora cercheremo di generare nuovi e realistici dati sulle frodi utilizzando i GAN per aiutarci a rilevare le frodi effettive.
Generazione di nuovi dati di carte di credito con GAN
Per applicare varie architetture GAN a questo set di dati, utilizzerò GAN-Sandbox, che ha una serie di architetture GAN popolari implementate in Python utilizzando la libreria Keras e un back-end TensorFlow. Tutti i miei risultati sono disponibili come taccuino Jupyter qui. Tutte le librerie necessarie sono incluse nell'immagine Kaggle/Python Docker, se hai bisogno di una facile configurazione.
Gli esempi in GAN-Sandbox sono impostati per l'elaborazione delle immagini. Il generatore produce un'immagine 2D con 3 canali di colore per ogni pixel e il discriminatore/critico è configurato per valutare tali dati. Le trasformazioni convoluzionali vengono utilizzate tra gli strati delle reti per sfruttare la struttura spaziale dei dati dell'immagine. Ciascun neurone in uno strato convoluzionale funziona solo con un piccolo gruppo di input e output (ad esempio pixel adiacenti in un'immagine) per consentire l'apprendimento delle relazioni spaziali. Il nostro set di dati delle carte di credito manca di qualsiasi struttura spaziale tra le variabili, quindi ho convertito le reti convoluzionali in reti con livelli densamente connessi. I neuroni in strati densamente connessi sono collegati a ogni input e output dello strato, consentendo alla rete di apprendere le proprie relazioni tra le caratteristiche. Userò questa configurazione per ciascuna delle architetture.
Il primo GAN che valuterò contrappone la rete del generatore alla rete del discriminatore, sfruttando la perdita di entropia incrociata dal discriminatore per addestrare le reti. Questa è l'architettura GAN originale, "vanilla". Il secondo GAN che valuterò aggiunge etichette di classe ai dati nel modo di un GAN condizionale (CGAN). Questo GAN ha un'altra variabile nei dati, l'etichetta della classe. Il terzo GAN utilizzerà la metrica della distanza di Wasserstein per addestrare le reti (WGAN) e l'ultimo utilizzerà le etichette di classe e la metrica della distanza di Wasserstein (WCGAN).
Addestreremo i vari GAN utilizzando un set di dati di addestramento composto da tutte le 492 transazioni fraudolente. Possiamo aggiungere classi al dataset di frode per facilitare le architetture GAN condizionali. Ho esplorato alcuni diversi metodi di clustering nel notebook e ho optato per una classificazione di KMeans che ordina i dati sulle frodi in 2 classi.
Allenerò ogni GAN per 5000 round ed esaminerò i risultati lungo il percorso. Nella Figura 4, possiamo vedere i dati di frode effettivi e i dati di frode generati dalle diverse architetture GAN man mano che la formazione procede. Possiamo vedere i dati di frode effettivi divisi nelle 2 classi KMeans, tracciati con le 2 dimensioni che meglio discriminano queste due classi (caratteristiche V10 e V17 dalle caratteristiche trasformate dalla PCA). I due GAN che non utilizzano le informazioni sulla classe, GAN e WGAN, hanno l'output generato come un'unica classe. Le architetture condizionali, CGAN e WCGAN, mostrano i dati generati per classe. Al passaggio 0, tutti i dati generati mostrano la distribuzione normale dell'input casuale alimentato ai generatori.
Possiamo vedere che l'architettura GAN originale inizia ad apprendere la forma e l'intervallo dei dati effettivi, ma poi crolla verso una piccola distribuzione. Questo è il collasso della modalità discusso in precedenza. Il generatore ha appreso una piccola gamma di dati che il discriminatore ha difficoltà a rilevare come falsi. L'architettura CGAN fa un po' meglio, diffondendo e avvicinandosi alle distribuzioni di ciascuna classe di dati sulle frodi, ma poi si verifica il collasso della modalità, come si può vedere al passaggio 5000.
Il WGAN non subisce il collasso della modalità mostrato dalle architetture GAN e CGAN. Anche senza le informazioni sulla classe, inizia a presumere la distribuzione non normale dei dati effettivi sulle frodi. L'architettura WCGAN funziona in modo simile ed è in grado di generare classi di dati separate.
Possiamo valutare quanto siano realistici i dati utilizzando lo stesso algoritmo xgboost utilizzato in precedenza per il rilevamento delle frodi. È veloce e potente e funziona immediatamente senza troppa messa a punto. Addestreremo il classificatore xgboost utilizzando metà dei dati di frode effettivi (246 campioni) e un numero uguale di esempi generati da GAN. Quindi testeremo il classificatore xgboost utilizzando l'altra metà dei dati di frode effettivi e un diverso insieme di 246 esempi generati GAN. Questo metodo ortogonale (in senso sperimentale) ci darà qualche indicazione del successo del generatore nel produrre dati realistici. Con dati generati perfettamente realistici, l'algoritmo xgboost dovrebbe raggiungere una precisione di 0,50 (50%), in altre parole, non è meglio che tirare a indovinare.
Possiamo vedere che l'accuratezza di xgboost sui dati generati da GAN diminuisce all'inizio, e poi aumenta dopo il passaggio di training 1000 quando inizia il collasso della modalità. L'architettura CGAN ottiene dati un po' più realistici dopo 2000 passaggi, ma poi il collasso della modalità si attiva per questa rete come bene. Le architetture WGAN e WCGAN ottengono dati più realistici più velocemente e continuano ad apprendere con il progredire della formazione. Il WCGAN non sembra avere molto vantaggio rispetto al WGAN, suggerendo che queste classi create potrebbero non essere utili per le architetture GAN di Wasserstein.
Puoi saperne di più sull'architettura WGAN da qui e qui.
La rete critica nelle architetture WGAN e WCGAN sta imparando a calcolare la distanza di Wasserstein (Earth-mover, EM) tra un dato set di dati e i dati di frode effettivi. Idealmente, misurerà una distanza prossima allo zero per un campione di dati di frode effettivi. Il critico, tuttavia, sta imparando come eseguire questo calcolo. Finché misura una distanza maggiore per i dati generati rispetto ai dati reali, la rete può migliorare. Possiamo osservare come la differenza tra le distanze di Wasserstein per i dati generati e quelli reali cambia nel corso dell'allenamento. Se si stabilizza, l'ulteriore formazione potrebbe non aiutare. Possiamo vedere nella figura 6 che sembra esserci un ulteriore miglioramento sia per il WGAN che per il WCGAN su questo set di dati.
Cosa abbiamo imparato?
Ora possiamo verificare se siamo in grado di generare nuovi dati sulle frodi sufficientemente realistici da aiutarci a rilevare i dati sulle frodi effettive. Possiamo prendere il generatore addestrato che ha ottenuto il punteggio di precisione più basso e usarlo per generare dati. Per il nostro set di formazione di base, utilizzeremo il 70% dei dati non sulle frodi (199.020 casi) e 100 casi dei dati sulle frodi (~20% dei dati sulle frodi). Quindi proveremo ad aggiungere diverse quantità di dati sulle frodi reali o generate a questo set di formazione, fino a 344 casi (70% dei dati sulle frodi). Per il set di test, utilizzeremo il restante 30% dei casi non di frode (85.295 casi) e dei casi di frode (148 casi). Possiamo provare ad aggiungere i dati generati da un GAN non addestrato e dal GAN meglio addestrato per verificare se i dati generati sono migliori del rumore casuale. Dai nostri test, sembra che la nostra migliore architettura sia stata la WCGAN nella fase di addestramento 4800, dove ha raggiunto una precisione xgboost del 70% (ricorda, idealmente, la precisione sarebbe del 50%). Quindi utilizzeremo questa architettura per generare nuovi dati sulle frodi.
Possiamo vedere nella figura 7 che il richiamo (la frazione di campioni di frode effettivi accuratamente identificati nel set di test) non aumenta poiché utilizziamo più dati di frode generati per la formazione. Il classificatore xgboost è in grado di conservare tutte le informazioni utilizzate per identificare le frodi dai 100 casi reali e non essere confuso dai dati aggiuntivi generati, anche quando li seleziona tra centinaia di migliaia di casi normali. I dati generati dal WCGAN non addestrato non aiutano o danneggiano, non sorprende. Ma anche i dati generati dal WCGAN addestrato non aiutano. Sembra che i dati non siano abbastanza realistici. Possiamo vedere nella figura 7 che quando i dati sulle frodi effettivi vengono utilizzati per integrare il set di formazione, il richiamo aumenta in modo significativo. Se il WCGAN avesse appena imparato a duplicare gli esempi di addestramento, senza diventare affatto creativo, avrebbe potuto ottenere tassi di richiamo più elevati come vediamo con i dati reali.
Verso l'infinito e oltre
Sebbene non siamo stati in grado di generare dati sulle frodi con carta di credito sufficientemente realistici da aiutarci a rilevare le frodi effettive, abbiamo a malapena scalfito la superficie con questi metodi. Potremmo allenarci più a lungo, con reti più grandi, e ottimizzare i parametri per le architetture che abbiamo provato in questo articolo. Le tendenze nell'accuratezza di xgboost e nella perdita del discriminatore suggeriscono che una maggiore formazione aiuterà le architetture WGAN e WCGAN. Un'altra opzione è quella di rivisitare la pulizia dei dati che abbiamo eseguito, magari progettare alcune nuove variabili o cambiare se e come affrontiamo l'asimmetria nelle funzionalità. Forse sarebbero utili diversi schemi di classificazione dei dati sulle frodi.
Potremmo anche provare altre architetture GAN. Il DRAGAN ha prove teoriche e sperimentali che dimostrano che si allena più velocemente e in modo più stabile rispetto ai GAN di Wasserstein. Potremmo integrare metodi che fanno uso dell'apprendimento semi-supervisionato, che hanno mostrato risultati promettenti nell'apprendimento da set di formazione limitati (vedi "Tecniche migliorate per i GAN di formazione"). Potremmo provare un'architettura che ci fornisca modelli comprensibili dall'uomo, così potremmo essere in grado di comprendere meglio la struttura dei dati (vedi InfoGAN).
Dovremmo anche tenere d'occhio i nuovi sviluppi nel campo e, ultimo ma non meno importante, possiamo lavorare per creare le nostre innovazioni in questo spazio in rapido sviluppo.
Puoi trovare tutto il codice rilevante per questo articolo in questo repository GitHub.