Creați date din zgomot aleatoriu cu rețele generative adverse

Publicat: 2022-03-11

De când am aflat despre rețelele generative adversarial (GAN), am fost fascinat de ele. Un GAN este un tip de rețea neuronală care este capabil să genereze date noi de la zero. Îl puteți alimenta cu puțin zgomot aleatoriu ca intrare și poate produce imagini realiste ale dormitoarelor, ale păsărilor sau orice altceva este antrenat să genereze.

Un lucru asupra căruia toți oamenii de știință pot fi de acord este că avem nevoie de mai multe date.

GAN-urile, care pot fi folosite pentru a produce date noi în situații cu date limitate, se pot dovedi a fi cu adevărat utile. Datele pot fi uneori dificile și costisitoare și consumatoare de timp de a genera. Pentru a fi utile, totuși, noile date trebuie să fie suficient de realiste încât orice informații pe care le obținem din datele generate să se aplice în continuare datelor reale. Dacă antrenezi o pisică să vâneze șoareci și folosești șoareci falși, ar fi bine să te asiguri că șoarecii falși arată de fapt ca șoarecii.

Un alt mod de a gândi este că GAN-urile descoperă o structură în date care le permite să creeze date realiste. Acest lucru poate fi util dacă nu putem vedea acea structură singuri sau nu o putem scoate cu alte metode.

Rețele adversare generative

În acest articol, veți afla cum pot fi utilizate GAN-urile pentru a genera date noi. Pentru a menține acest tutorial realist, vom folosi setul de date de detectare a fraudei cu cardul de credit de la Kaggle.

În experimentele mele, am încercat să folosesc acest set de date pentru a vedea dacă pot obține un GAN pentru a crea date suficient de realiste pentru a ne ajuta să detectăm cazurile frauduloase. Acest set de date evidențiază problema limitată a datelor: din 285.000 de tranzacții, doar 492 sunt fraude. 492 de cazuri de fraudă nu este un set de date mare pe care să se antreneze, mai ales când vine vorba de sarcini de învățare automată în care oamenilor le place să aibă seturi de date cu câteva ordine de mărime mai mari. Deși rezultatele experimentului meu nu au fost uimitoare, am învățat multe despre GAN-uri pe parcurs, pe care sunt bucuros să le împărtășesc.

Inainte sa incepi

Înainte de a pătrunde în acest tărâm al GAN-urilor, dacă doriți să vă perfecționați rapid abilitățile de învățare automată sau de învățare profundă, puteți arunca o privire la aceste două postări legate de blog:

  • O introducere în teoria învățării automate și aplicarea acesteia: un tutorial vizual cu exemple
  • Un tutorial de învățare profundă: de la perceptroni la rețele profunde

De ce GAN-uri?

Rețelele adverse generative (GAN) sunt o arhitectură de rețea neuronală care a prezentat îmbunătățiri impresionante față de metodele generative anterioare, cum ar fi auto-encoderele variaționale sau mașinile Boltzman restricționate. GAN-urile au reușit să genereze imagini mai realiste (de exemplu, DCGAN), să permită transferul de stil între imagini (vezi aici și aici), să genereze imagini din descrieri de text (StackGAN) și să învețe din seturi de date mai mici prin învățarea semi-supravegheată. Datorită acestor realizări, ele generează mult interes atât în ​​sectorul academic, cât și în cel comercial.

Directorul de cercetare AI la Facebook, Yann LeCunn, le-a numit chiar cea mai interesantă dezvoltare în învățarea automată din ultimul deceniu.

Cele elementare

Gândește-te la modul în care înveți. Încerci ceva, primești feedback. Vă ajustați strategia și încercați din nou.

Feedback-ul poate veni sub formă de critică, durere sau profit. S-ar putea să vină din propria ta judecată despre cât de bine te-ai descurcat. Adesea, cel mai util feedback este feedback-ul care vine de la o altă persoană, pentru că nu este doar un număr sau o senzație, ci o evaluare inteligentă a cât de bine ai îndeplinit sarcina.

Când un computer este antrenat pentru o sarcină, omul oferă de obicei feedback-ul sub formă de parametri sau algoritmi ajustați. Acest lucru funcționează bine atunci când sarcina este bine definită, cum ar fi să înveți să înmulți două numere. Puteți spune cu ușurință și exact computerului cum a greșit.

Cu o sarcină mai complicată, cum ar fi crearea unei imagini a câinelui, devine mai dificil să oferi feedback. Imaginea este neclară, seamănă mai mult cu o pisică sau seamănă cu ceva? S-ar putea implementa statistici complexe, dar ar fi greu de surprins toate detaliile care fac ca o imagine să pară reală.

Un om poate da o anumită estimare, deoarece avem multă experiență în evaluarea inputurilor vizuale, dar suntem relativ lente și evaluările noastre pot fi foarte subiective. În schimb, am putea antrena o rețea neuronală pentru a învăța sarcina de a discrimina între imaginile reale și cele generate.

Apoi, lăsând generatorul de imagini (de asemenea, o rețea neuronală) și discriminatorul să învețe pe rând unul de la celălalt, se pot îmbunătăți în timp. Aceste două rețele, jucând acest joc, sunt o rețea adversară generativă.

Îl poți auzi pe inventatorul GAN-urilor, Ian Goodfellow, vorbind despre modul în care o ceartă la un bar pe această temă a dus la o noapte febrilă de codare care a dus la primul GAN. Și da, el recunoaște bara în ziarul său. Puteți afla mai multe despre GAN-uri de pe blogul lui Ian Goodfellow pe acest subiect.

Diagrama GAN

Există o serie de provocări atunci când lucrați cu GAN-uri. Antrenarea unei singure rețele neuronale poate fi dificilă din cauza numărului de opțiuni implicate: arhitectură, funcții de activare, metodă de optimizare, rata de învățare și rata abandonului, pentru a numi doar câteva.

GAN-urile dublează toate aceste opțiuni și adaugă noi complexități. Atât generatorul, cât și discriminatorul pot uita trucurile pe care le-au folosit mai devreme în antrenament. Acest lucru poate duce la ca cele două rețele să fie prinse într-un ciclu stabil de soluții care nu se îmbunătățesc în timp. O rețea poate depăși cealaltă rețea, astfel încât niciuna nu mai poate învăța. Sau, generatorul poate să nu exploreze o mare parte din spațiul posibil de soluții, doar suficient pentru a găsi soluții realiste. Această ultimă situație este cunoscută sub numele de colapsul modului.

Colapsul modului este atunci când generatorul învață doar un mic subset din posibilele moduri realiste. De exemplu, dacă sarcina este de a genera imagini cu câini, generatorul ar putea învăța să creeze doar imagini cu câini mici și maro. Generatorul ar fi ratat toate celelalte moduri constând din câini de alte dimensiuni sau culori.

Au fost implementate multe strategii pentru a aborda acest lucru, inclusiv normalizarea loturilor, adăugarea de etichete în datele de antrenament sau prin schimbarea modului în care discriminatorul judecă datele generate.

Oamenii au observat că adăugarea etichetelor la date, adică împărțirea în categorii, aproape întotdeauna îmbunătățește performanța GAN-urilor. În loc să învățați să generați imagini cu animale de companie în general, ar trebui să fie mai ușor să generați imagini cu pisici, câini, pești și dihori, de exemplu.

Poate că cele mai semnificative progrese în dezvoltarea GAN au venit în ceea ce privește schimbarea modului în care discriminatorul evaluează datele, așa că haideți să aruncăm o privire mai atentă la asta.

În formularea originală a GAN-urilor din 2014 de către Goodfellow și colab., discriminatorul generează o estimare a probabilității ca o anumită imagine să fie reală sau generată. Discriminatorului i se va furniza un set de imagini care consta atât din imagini reale, cât și din imagini generate și ar genera o estimare pentru fiecare dintre aceste intrări. Eroarea dintre ieșirea discriminatorului și etichetele reale ar fi apoi măsurată prin pierderea de entropie încrucișată. Pierderea de entropie încrucișată poate fi echivalată cu metrica distanței Jensen-Shannon și a fost demonstrată la începutul anului 2017 de către Arjovsky și colab. că această măsurătoare ar eșua în unele cazuri și nu ar indica direcția corectă în alte cazuri. Acest grup a arătat că metrica distanței Wasserstein (cunoscută și sub denumirea de dispozitivul de mișcare a pământului sau distanța EM) a funcționat și a funcționat mai bine în multe mai multe cazuri.

Pierderea de entropie încrucișată este o măsură a cât de precis a identificat discriminatorul imaginile reale și generate. În schimb, metrica Wasserstein analizează distribuția fiecărei variabile (adică fiecare culoare a fiecărui pixel) în imaginile reale și generate și determină cât de departe sunt distribuțiile pentru datele reale și cele generate. Metrica Wasserstein analizează cât de mult efort, în termeni de masă ori distanță, ar fi necesar pentru a împinge distribuția generată în forma distribuției reale, de unde și numele alternativ „distanța de mișcare a pământului”. Deoarece metrica Wasserstein nu mai evaluează dacă o imagine este reală sau nu, ci oferă în schimb o critică cu privire la cât de departe sunt imaginile generate de imaginile reale, rețeaua „discriminatoare” este denumită rețea „critică” în Wasserstein. arhitectură.

Pentru o explorare puțin mai cuprinzătoare a GAN-urilor, în acest articol, vom explora patru arhitecturi diferite:

  • GAN: GAN original („vanilie”)
  • CGAN: O versiune condiționată a GAN original care utilizează etichete de clasă
  • WGAN: Wasserstein GAN (cu gradient-penalizare)
  • WCGAN: O versiune condiționată a Wasserstein GAN

Dar să aruncăm o privire mai întâi la setul nostru de date.

O privire asupra datelor privind fraudele cu cardul de credit

Vom lucra cu setul de date de detectare a fraudelor cu cardul de credit de la Kaggle.

Setul de date este format din ~285.000 de tranzacții, dintre care doar 492 sunt frauduloase. Datele constau din 31 de caracteristici: „timp”, „cantitate”, „clasă” și 28 de caracteristici suplimentare, anonimizate. Caracteristica de clasă este eticheta care indică dacă o tranzacție este frauduloasă sau nu, cu 0 indicând normal și 1 indicând fraudă. Toate datele sunt numerice și continue (cu excepția etichetei). Setul de date nu are valori lipsă. Setul de date este deja într-o formă destul de bună pentru început, dar voi face puțin mai multă curățare, mai ales ajustând mijloacele tuturor caracteristicilor la zero și abaterile standard la unu. Am descris mai mult procesul meu de curățare în caietul de note de aici. Deocamdată voi arăta doar rezultatul final:

Caracteristici vs. Grafice de clasă

Se pot observa cu ușurință diferențe între datele normale și cele frauduloase din aceste distribuții, dar există și o mulțime de suprapunere. Putem aplica unul dintre algoritmii de învățare automată mai rapid și mai puternic pentru a identifica cele mai utile caracteristici pentru identificarea fraudelor. Acest algoritm, xgboost, este un algoritm de arbore de decizie cu gradient. Îl vom antrena pe 70% din setul de date și îl vom testa pe restul de 30%. Putem configura algoritmul să continue până când nu se îmbunătățește reamintirea (fracțiunea de eșantioane de fraudă detectate) pe setul de date de testare. Acest lucru realizează o reamintire de 76% pe setul de testare, ceea ce lasă în mod clar loc de îmbunătățire. Realizează o precizie de 94%, ceea ce înseamnă că doar 6% dintre cazurile de fraudă prezise au fost de fapt tranzacții normale. Din această analiză, obținem și o listă de funcții sortate în funcție de utilitatea lor în detectarea fraudelor. Putem folosi cele mai importante funcții pentru a ne ajuta să ne vizualizăm rezultatele mai târziu.

Din nou, dacă am avea mai multe date despre fraudă, s-ar putea să le detectăm mai bine. Adică, am putea obține o reamintire mai mare. Vom încerca acum să generăm date noi, realiste privind fraudele, folosind GAN-uri pentru a ne ajuta să detectăm frauda reală.

Generarea de noi date de card de credit cu GAN-uri

Pentru a aplica diferite arhitecturi GAN la acest set de date, voi folosi GAN-Sandbox, care are o serie de arhitecturi GAN populare implementate în Python folosind biblioteca Keras și un back-end TensorFlow. Toate rezultatele mele sunt disponibile ca notebook Jupyter aici. Toate bibliotecile necesare sunt incluse în imaginea Kaggle/Python Docker, dacă aveți nevoie de o configurare ușoară.

Exemplele din GAN-Sandbox sunt configurate pentru procesarea imaginilor. Generatorul produce o imagine 2D cu 3 canale de culoare pentru fiecare pixel, iar discriminatorul/criticul este configurat pentru a evalua astfel de date. Transformările convoluționale sunt utilizate între straturi ale rețelelor pentru a profita de structura spațială a datelor de imagine. Fiecare neuron dintr-un strat convoluțional funcționează doar cu un grup mic de intrări și ieșiri (de exemplu, pixeli adiacenți dintr-o imagine) pentru a permite învățarea relațiilor spațiale. Setul nostru de date pentru cardul de credit nu are nicio structură spațială între variabile, așa că am convertit rețelele convoluționale în rețele cu straturi dens conectate. Neuronii din straturi dens conectate sunt conectați la fiecare intrare și ieșire a stratului, permițând rețelei să învețe propriile relații între caracteristici. Voi folosi această configurație pentru fiecare dintre arhitecturi.

Primul GAN ​​pe care îl voi evalua confruntă rețeaua generatorului cu rețeaua discriminatorului, utilizând pierderea de entropie încrucișată de la discriminator pentru a antrena rețelele. Aceasta este arhitectura GAN originală, „vanilie”. Al doilea GAN pe care îl voi evalua adaugă etichete de clasă la date în maniera unui GAN condiționat (CGAN). Acest GAN mai are o variabilă în date, eticheta clasei. Al treilea GAN va folosi metrica distanței Wasserstein pentru a antrena rețelele (WGAN), iar ultimul va folosi etichetele de clasă și metrica distanței Wasserstein (WCGAN).

Arhitecturi GAN

Vom instrui diferitele GAN folosind un set de date de instruire care constă din toate cele 492 de tranzacții frauduloase. Putem adăuga clase la setul de date privind fraudele pentru a facilita arhitecturile GAN condiționate. Am explorat câteva metode diferite de grupare în blocnotes și am urmat o clasificare KMeans care sortează datele despre fraudă în 2 clase.

Voi antrena fiecare GAN pentru 5000 de runde și voi examina rezultatele pe parcurs. În Figura 4, putem vedea datele reale despre fraudă și datele generate de fraudă din diferitele arhitecturi GAN pe măsură ce antrenamentul progresează. Putem vedea datele efective de fraudă împărțite în cele 2 clase KMeans, reprezentate grafic cu cele 2 dimensiuni care discriminează cel mai bine aceste două clase (caracteristicile V10 și V17 din caracteristicile transformate PCA). Cele două GAN-uri care nu folosesc informații despre clasă, GAN și WGAN, au rezultatul lor generat ca o singură clasă. Arhitecturile condiționate, CGAN și WCGAN, arată datele lor generate pe clasă. La pasul 0, toate datele generate arată distribuția normală a intrării aleatorii alimentate la generatoare.

Comparație de ieșire GAN

Putem vedea că arhitectura GAN originală începe să învețe forma și gama datelor reale, dar apoi se prăbușește către o distribuție mică. Acesta este colapsul modului discutat mai devreme. Generatorul a învățat o gamă mică de date pe care discriminatorul le detectează cu greu ca fiind false. Arhitectura CGAN se descurcă puțin mai bine, extinzându-se și apropiindu-se de distribuțiile fiecărei clase de date frauduloase, dar apoi se instalează colapsul modului, așa cum se poate vedea la pasul 5000.

WGAN nu experimentează colapsul modului prezentat de arhitecturile GAN și CGAN. Chiar și fără informații de clasă, începe să presupună distribuția nenormală a datelor efective de fraudă. Arhitectura WCGAN funcționează similar și este capabilă să genereze clase separate de date.

Putem evalua cât de realiste arată datele folosind același algoritm xgboost folosit anterior pentru detectarea fraudelor. Este rapid și puternic și funcționează de la raft fără prea multe reglaje. Vom antrena clasificatorul xgboost folosind jumătate din datele efective de fraudă (246 de mostre) și un număr egal de exemple generate de GAN. Apoi vom testa clasificatorul xgboost folosind cealaltă jumătate a datelor de fraudă reale și un set diferit de 246 de exemple generate de GAN. Această metodă ortogonală (în sens experimental) ne va oferi unele indicații despre cât de mult succes are generatorul în producerea de date realiste. Cu date generate perfect realiste, algoritmul xgboost ar trebui să atingă o acuratețe de 0,50 (50%) - cu alte cuvinte, nu este mai bine decât ghicitul.

Precizie

Putem vedea că acuratețea xgboost a datelor generate de GAN scade la început și apoi crește după pasul de antrenament 1000, pe măsură ce se instalează colapsul modului. Arhitectura CGAN realizează date ceva mai realiste după 2000 de pași, dar apoi colapsul modului se instalează pentru această rețea ca bine. Arhitecturile WGAN și WCGAN obțin date mai realiste mai rapid și continuă să învețe pe măsură ce antrenamentul progresează. WCGAN nu pare să aibă un avantaj față de WGAN, sugerând că aceste clase create ar putea să nu fie utile pentru arhitecturile GAN Wasserstein.

Puteți afla mai multe despre arhitectura WGAN de aici și aici.

Rețeaua critică din arhitecturile WGAN și WCGAN învață să calculeze distanța Wasserstein (Earth-mover, EM) dintre un anumit set de date și datele reale de fraudă. În mod ideal, va măsura o distanță aproape de zero pentru un eșantion de date reale despre fraudă. Criticul, însă, este în proces de a învăța cum să efectueze acest calcul. Atâta timp cât măsoară o distanță mai mare pentru datele generate decât pentru datele reale, rețeaua se poate îmbunătăți. Putem urmări cum se modifică diferența dintre distanțele Wasserstein pentru datele generate și cele reale pe parcursul antrenamentului. Dacă se stabilește, atunci antrenamentul suplimentar ar putea să nu ajute. Putem vedea în figura 6 că pare să existe o îmbunătățire suplimentară atât pentru WGAN, cât și pentru WCGAN pentru acest set de date.

Estimarea distanței EM

Ce am învățat?

Acum putem testa dacă suntem capabili să generăm noi date de fraudă suficient de realiste pentru a ne ajuta să detectăm datele reale de fraudă. Putem lua generatorul instruit care a obținut cel mai mic scor de precizie și îl putem folosi pentru a genera date. Pentru setul nostru de instruire de bază, vom folosi 70% din datele non-fraudă (199.020 de cazuri) și 100 de cazuri de date despre fraudă (~20% din datele despre fraudă). Apoi vom încerca să adăugăm diferite cantități de date reale sau generate de fraudă la acest set de instruire, până la 344 de cazuri (70% din datele de fraudă). Pentru setul de testare, vom folosi celelalte 30% din cazurile non-fraudă (85.295 cazuri) și cazurile de fraudă (148 cazuri). Putem încerca să adăugăm date generate de la un GAN neantrenat și de la cel mai bine antrenat GAN pentru a testa dacă datele generate sunt mai bune decât zgomotul aleatoriu. Din testele noastre, se pare că cea mai bună arhitectură a noastră a fost WCGAN la pasul de antrenament 4800, unde a atins o precizie xgboost de 70% (rețineți că, în mod ideal, precizia ar fi de 50%). Așa că vom folosi această arhitectură pentru a genera noi date privind fraudele.

Putem vedea în figura 7 că retragerea (fracțiunea de eșantioane de fraudă reale identificate cu precizie în setul de testare) nu crește pe măsură ce folosim mai multe date generate de fraudă pentru instruire. Clasificatorul xgboost este capabil să rețină toate informațiile pe care le-a folosit pentru a identifica frauda din cele 100 de cazuri reale și să nu se încurce de datele suplimentare generate, chiar și atunci când le alege din sute de mii de cazuri normale. Datele generate de WCGAN neinstruit nu ajută și nu rănesc, fără a fi surprinzător. Dar nici datele generate de WCGAN instruit nu ajută. Se pare că datele nu sunt suficient de realiste. Putem vedea în figura 7 că atunci când datele reale privind fraudele sunt utilizate pentru a suplimenta setul de instruire, reamintirea crește semnificativ. Dacă WCGAN tocmai ar fi învățat să dubleze exemplele de antrenament, fără a fi deloc creativ, ar fi putut obține rate mai mari de reamintire, așa cum vedem cu datele reale.

Efectul datelor suplimentare

La infinit și dincolo de

Deși nu am reușit să generăm date despre frauda cu cardul de credit suficient de realiste pentru a ne ajuta să detectăm frauda reală, abia am zgâriat suprafața cu aceste metode. Ne-am putea antrena mai mult, cu rețele mai mari și am putea regla parametrii pentru arhitecturile pe care le-am încercat în acest articol. Tendințele în ceea ce privește acuratețea xgboost și pierderea discriminatorului sugerează că mai multă pregătire va ajuta arhitecturile WGAN și WCGAN. O altă opțiune este să revizuim curățarea datelor pe care am efectuat-o, poate să proiectăm câteva variabile noi sau să schimbăm dacă și cum abordăm asimetria caracteristicilor. Poate că diferite scheme de clasificare a datelor despre fraudă ar ajuta.

Am putea încerca și alte arhitecturi GAN. DRAGAN are dovezi teoretice și experimentale care arată că se antrenează mai rapid și mai stabil decât GAN-urile Wasserstein. Am putea integra metode care folosesc învățarea semi-supravegheată, care s-au dovedit promițătoare în învățarea din seturi limitate de antrenament (a se vedea „Tehnici îmbunătățite pentru formarea GAN-urilor”). Am putea încerca o arhitectură care să ne ofere modele ușor de înțeles de către om, astfel încât să putem înțelege mai bine structura datelor (vezi InfoGAN).

De asemenea, ar trebui să fim atenți la noile evoluții în domeniu și, nu în ultimul rând, putem lucra la crearea propriilor noastre inovații în acest spațiu în dezvoltare rapidă.

Puteți găsi tot codul relevant pentru acest articol în acest depozit GitHub.

Înrudit : numeroasele aplicații ale coborârii gradientului în TensorFlow