Twórz dane z losowego szumu za pomocą generatywnych sieci przeciwstawnych
Opublikowany: 2022-03-11Odkąd dowiedziałem się o generatywnych sieciach kontradyktoryjnych (GAN), byłem nimi zafascynowany. GAN to rodzaj sieci neuronowej, która jest w stanie generować nowe dane od podstaw. Możesz podać mu trochę losowego szumu jako dane wejściowe i może generować realistyczne obrazy sypialni, ptaków lub czegokolwiek, do czego jest wyszkolony.
Wszyscy naukowcy mogą się zgodzić, że potrzebujemy więcej danych.
Sieci GAN, które można wykorzystać do tworzenia nowych danych w sytuacjach, w których dane są ograniczone, mogą okazać się naprawdę przydatne. Generowanie danych może być czasami trudne, kosztowne i czasochłonne. Aby jednak były użyteczne, nowe dane muszą być na tyle realistyczne, aby wszelkie spostrzeżenia, jakie uzyskamy z wygenerowanych danych, nadal mają zastosowanie do danych rzeczywistych. Jeśli szkolisz kota do polowania na myszy i używasz fałszywych myszy, lepiej upewnij się, że fałszywe myszy rzeczywiście wyglądają jak myszy.
Innym sposobem myślenia o tym jest to, że GAN odkrywają strukturę danych, która pozwala im tworzyć realistyczne dane. Może to być przydatne, jeśli sami nie możemy zobaczyć tej struktury lub nie możemy jej wyciągnąć innymi metodami.
W tym artykule dowiesz się, jak można wykorzystać GAN do generowania nowych danych. Aby ten samouczek był realistyczny, użyjemy zestawu danych wykrywania oszustw związanych z kartami kredytowymi firmy Kaggle.
W moich eksperymentach próbowałem użyć tego zestawu danych, aby sprawdzić, czy mogę uzyskać GAN do tworzenia danych wystarczająco realistycznych, aby pomóc nam wykrywać przypadki oszustwa. Ten zestaw danych podkreśla ograniczony problem z danymi: z 285 000 transakcji tylko 492 to oszustwa. 492 przypadki oszustw nie są dużym zbiorem danych do przeszkolenia, zwłaszcza jeśli chodzi o zadania uczenia maszynowego, w których ludzie lubią mieć zbiory danych większe o kilka rzędów wielkości. Chociaż wyniki mojego eksperymentu nie były zdumiewające, po drodze dowiedziałem się wiele o GAN, którym chętnie się podzielę.
Zanim zaczniesz
Zanim zagłębimy się w tę dziedzinę GAN, jeśli chcesz szybko odświeżyć swoje umiejętności uczenia maszynowego lub głębokiego uczenia się, możesz zapoznać się z tymi dwoma powiązanymi postami na blogu:
- Wprowadzenie do teorii uczenia maszynowego i jej zastosowania: wizualny samouczek z przykładami
- Samouczek dotyczący głębokiego uczenia się: od perceptronów do głębokich sieci
Dlaczego GAN?
Generacyjne sieci kontradyktoryjne (GAN) to architektura sieci neuronowej, która wykazała imponującą poprawę w stosunku do poprzednich metod generatywnych, takich jak wariacyjne autokodery lub ograniczone maszyny typu boltzman. Sieci GAN były w stanie generować bardziej realistyczne obrazy (np. DCGAN), umożliwiały przenoszenie stylów między obrazami (patrz tutaj i tutaj), generowały obrazy z opisów tekstowych (StackGAN) i uczyły się z mniejszych zestawów danych poprzez uczenie częściowo nadzorowane. Dzięki tym osiągnięciom wzbudzają duże zainteresowanie zarówno w sektorze akademickim, jak i komercyjnym.
Dyrektor ds. badań nad sztuczną inteligencją w Facebooku, Yann LeCunn, nazwał je nawet najbardziej ekscytującym osiągnięciem w dziedzinie uczenia maszynowego w ciągu ostatniej dekady.
Podstawy
Pomyśl o tym, jak się uczysz. Próbujesz czegoś, otrzymujesz informację zwrotną. Dostosowujesz swoją strategię i próbujesz ponownie.
Informacja zwrotna może mieć formę krytyki, bólu lub zysku. Może to wynikać z twojego własnego osądu, jak dobrze sobie poradziłeś. Często najbardziej przydatną informacją zwrotną jest informacja zwrotna, która pochodzi od innej osoby, ponieważ nie jest to tylko liczba lub wrażenie, ale inteligentna ocena tego, jak dobrze wykonałeś zadanie.
Kiedy komputer jest szkolony do wykonywania zadania, człowiek zwykle dostarcza informacji zwrotnej w postaci dostosowanych parametrów lub algorytmów. Działa to dobrze, gdy zadanie jest dobrze zdefiniowane, na przykład uczenie się mnożenia dwóch liczb. Możesz łatwo i dokładnie powiedzieć komputerowi, jak się pomylił.
Przy bardziej skomplikowanym zadaniu, takim jak stworzenie wizerunku psa, trudniej jest przekazać informację zwrotną. Czy obraz jest rozmazany, bardziej przypomina kota, czy w ogóle coś przypomina? Można by zaimplementować złożone statystyki, ale trudno byłoby uchwycić wszystkie szczegóły, które sprawiają, że obraz wydaje się prawdziwy.
Człowiek może coś oszacować, ponieważ mamy duże doświadczenie w ocenie danych wizualnych, ale jesteśmy stosunkowo powolni, a nasze oceny mogą być bardzo subiektywne. Zamiast tego moglibyśmy wytrenować sieć neuronową, aby nauczyć się zadania rozróżniania obrazów rzeczywistych i generowanych.
Następnie, pozwalając generatorowi obrazów (również sieci neuronowej) i dyskryminatorowi na zmianę uczyć się od siebie, mogą one z czasem ulec poprawie. Te dwie sieci, grające w tę grę, są generatywną siecią adwersarzy.
Można usłyszeć, jak wynalazca GAN, Ian Goodfellow, opowiada o tym, jak kłótnia w barze na ten temat doprowadziła do gorączkowej nocy kodowania, która zaowocowała pierwszym GAN. I tak, potwierdza pasek w swojej gazecie. Możesz dowiedzieć się więcej o GAN z bloga Iana Goodfellowa na ten temat.
Praca z sieciami GAN wiąże się z wieloma wyzwaniami. Uczenie pojedynczej sieci neuronowej może być trudne ze względu na liczbę dostępnych opcji: architektura, funkcje aktywacji, metoda optymalizacji, współczynnik uczenia się i współczynnik porzucania, żeby wymienić tylko kilka.
GAN podwajają wszystkie te opcje i dodają nowe komplikacje. Zarówno generator, jak i dyskryminator mogą zapomnieć sztuczki, których używali wcześniej podczas treningu. Może to prowadzić do tego, że dwie sieci zostaną złapane w stabilny cykl rozwiązań, które nie ulegają poprawie w czasie. Jedna sieć może obezwładnić drugą, tak że żadna z nich nie może się już uczyć. Lub generator może nie zbadać dużej części możliwej przestrzeni rozwiązań, a jedynie tyle, aby znaleźć realistyczne rozwiązania. Ta ostatnia sytuacja nazywa się załamaniem trybu.
Załamanie trybu ma miejsce, gdy generator uczy się tylko niewielkiego podzbioru możliwych realistycznych trybów. Na przykład, jeśli zadaniem jest generowanie obrazów psów, generator może nauczyć się tworzyć tylko obrazy małych brązowych psów. Generator ominąłby wszystkie inne tryby składające się z psów o innych rozmiarach lub kolorach.
Wdrożono wiele strategii, aby rozwiązać ten problem, w tym normalizację wsadową, dodawanie etykiet do danych uczących lub zmianę sposobu, w jaki dyskryminator ocenia wygenerowane dane.
Ludzie zauważyli, że dodanie etykiet do danych — to znaczy rozbicie ich na kategorie, prawie zawsze poprawia wydajność sieci GAN. Zamiast uczyć się ogólnie generować obrazy zwierząt domowych, generowanie obrazów kotów, psów, ryb i fretek powinno być łatwiejsze.
Być może najbardziej znaczący przełom w rozwoju GAN nastąpił w zakresie zmiany sposobu, w jaki dyskryminator ocenia dane, więc przyjrzyjmy się temu bliżej.
W pierwotnym sformułowaniu GAN z 2014 r. Goodfellow et al. dyskryminator generuje oszacowanie prawdopodobieństwa, że dany obraz był rzeczywisty lub wygenerowany. Dyskryminator otrzymałby zestaw obrazów, który składałby się zarówno z obrazów rzeczywistych, jak i wygenerowanych, i generowałby oszacowanie dla każdego z tych sygnałów wejściowych. Błąd między wyjściem dyskryminatora a rzeczywistymi etykietami byłby wtedy mierzony przez stratę entropii krzyżowej. Stratę entropii krzyżowej można przyrównać do miernika odległości Jensena-Shannona, co zostało wykazane na początku 2017 r. przez Arjovsky i in. że ta metryka w niektórych przypadkach zawiedzie, a w innych nie będzie wskazywać właściwego kierunku. Ta grupa wykazała, że metryka odległości Wassersteina (znana również jako odległość do robót ziemnych lub odległość EM) działała i działała lepiej w wielu innych przypadkach.
Strata entropii krzyżowej jest miarą tego, jak dokładnie dyskryminator zidentyfikował obrazy rzeczywiste i wygenerowane. Miernik Wassersteina zamiast tego analizuje rozkład każdej zmiennej (tj. każdego koloru każdego piksela) w rzeczywistych i wygenerowanych obrazach i określa, jak daleko od siebie są rozkłady dla danych rzeczywistych i wygenerowanych. Miernik Wassersteina pokazuje, ile wysiłku, pod względem masy i odległości, wymagałoby przekształcenie wygenerowanego rozkładu w kształt rozkładu rzeczywistego, stąd alternatywna nazwa „odległość do poruszania się po ziemi”. Ponieważ metryka Wassersteina nie ocenia już, czy obraz jest prawdziwy, czy nie, ale zamiast tego dostarcza krytyki, jak daleko generowane obrazy są od rzeczywistych obrazów, sieć „dyskryminatora” jest określana jako sieć „krytyczna” w Wasserstein architektura.
W celu nieco bardziej kompleksowej eksploracji sieci GAN, w tym artykule omówimy cztery różne architektury:
- GAN: Oryginalny („waniliowy”) GAN
- CGAN: Warunkowa wersja oryginalnego GAN, która wykorzystuje etykiety klas
- WGAN: Wasserstein GAN (z karą gradientową)
- WCGAN: Warunkowa wersja Wasserstein GAN
Ale najpierw spójrzmy na nasz zbiór danych.
Spojrzenie na dane dotyczące oszustw związanych z kartami kredytowymi
Będziemy pracować z zestawem danych do wykrywania oszustw związanych z kartami kredytowymi firmy Kaggle.
Zestaw danych składa się z ok. 285 000 transakcji, z których tylko 492 to fałszywe. Dane składają się z 31 cech: „czas”, „ilość”, „klasa” i 28 dodatkowych, anonimowych cech. Cechą klasy jest etykieta wskazująca, czy transakcja jest fałszywa, czy nie, przy czym 0 oznacza normalne, a 1 oznacza oszustwo. Wszystkie dane są numeryczne i ciągłe (z wyjątkiem etykiety). W zestawie danych nie ma braków danych. Zestaw danych jest już w całkiem dobrym stanie na początek, ale zrobię trochę więcej czyszczenia, głównie po prostu dostosowując średnie wszystkich funkcji do zera i odchylenia standardowe do jednego. Tutaj opisałem mój proces czyszczenia w zeszycie. Na razie pokażę tylko efekt końcowy:
W tych dystrybucjach można łatwo zauważyć różnice między danymi normalnymi a danymi o oszustwach, ale istnieje również wiele nakładających się danych. Możemy zastosować jeden z szybszych i potężniejszych algorytmów uczenia maszynowego, aby zidentyfikować najbardziej przydatne funkcje do wykrywania oszustw. Ten algorytm, xgboost, jest algorytmem drzewa decyzyjnego ze wzmocnieniem gradientowym. Wytrenujemy go na 70% zbioru danych i przetestujemy na pozostałych 30%. Możemy skonfigurować algorytm tak, aby działał, dopóki nie poprawi pamięci (część wykrytych próbek oszustw) w zestawie danych testowych. Osiąga to 76% przypomnienia na zestawie testowym, co wyraźnie pozostawia pole do poprawy. Osiąga precyzję 94%, co oznacza, że tylko 6% przewidywanych przypadków oszustw to w rzeczywistości normalne transakcje. Z tej analizy otrzymujemy również listę funkcji posortowanych według ich użyteczności w wykrywaniu oszustw. Możemy użyć najważniejszych funkcji, aby później zwizualizować nasze wyniki.

Ponownie, gdybyśmy mieli więcej danych o oszustwach, moglibyśmy je lepiej wykryć. Oznacza to, że moglibyśmy osiągnąć wyższy poziom wycofania. Spróbujemy teraz wygenerować nowe, realistyczne dane o oszustwach za pomocą GAN, aby pomóc nam wykryć rzeczywiste oszustwo.
Generowanie nowych danych karty kredytowej za pomocą GAN
Aby zastosować różne architektury GAN do tego zbioru danych, użyję GAN-Sandbox, który ma wiele popularnych architektur GAN zaimplementowanych w Pythonie przy użyciu biblioteki Keras i zaplecza TensorFlow. Wszystkie moje wyniki są dostępne jako notatnik Jupyter tutaj. Wszystkie niezbędne biblioteki są zawarte w obrazie Docker Kaggle/Python, jeśli potrzebujesz łatwej konfiguracji.
Przykłady w GAN-Sandbox są skonfigurowane do przetwarzania obrazu. Generator wytwarza obraz 2D z 3 kanałami kolorów dla każdego piksela, a dyskryminator/krytyczny jest skonfigurowany do oceny takich danych. Przekształcenia splotowe są wykorzystywane między warstwami sieci w celu wykorzystania przestrzennej struktury danych obrazu. Każdy neuron w warstwie splotowej działa tylko z niewielką grupą wejść i wyjść (np. sąsiednie piksele na obrazie), aby umożliwić poznanie relacji przestrzennych. W naszym zbiorze danych dotyczących kart kredytowych brakuje jakiejkolwiek struktury przestrzennej wśród zmiennych, więc przekształciłem sieci splotowe w sieci z gęsto połączonymi warstwami. Neurony w gęsto połączonych warstwach są połączone z każdym wejściem i wyjściem warstwy, dzięki czemu sieć uczy się własnych relacji między cechami. Użyję tej konfiguracji dla każdej architektury.
Pierwszy GAN, który ocenię, porównuje sieć generatora z siecią dyskryminatora, wykorzystując straty w entropii krzyżowej z dyskryminatora do trenowania sieci. To oryginalna, „waniliowa” architektura GAN. Drugi GAN, który ocenię, dodaje etykiety klas do danych w sposób warunkowy GAN (CGAN). Ten GAN ma jeszcze jedną zmienną w danych, etykietę klasy. Trzeci GAN użyje metryki odległości Wassersteina do trenowania sieci (WGAN), a ostatni użyje etykiet klas i metryki odległości Wassersteina (WCGAN).
Przeszkolimy różne sieci GAN przy użyciu zestawu danych szkoleniowych, który składa się ze wszystkich 492 fałszywych transakcji. Możemy dodać klasy do zbioru danych o oszustwach, aby ułatwić warunkowe architektury GAN. Zbadałem kilka różnych metod klastrowania w notatniku i zastosowałem klasyfikację KMeans, która sortuje dane o oszustwach na 2 klasy.
Wytrenuję każdy GAN przez 5000 rund i po drodze będę sprawdzał wyniki. Na rysunku 4 możemy zobaczyć rzeczywiste dane o oszustwach i wygenerowane dane o oszustwach z różnych architektur GAN w miarę postępu szkolenia. Możemy zobaczyć rzeczywiste dane dotyczące oszustw podzielone na 2 klasy KMeans, wykreślone z 2 wymiarami, które najlepiej odróżniają te dwie klasy (funkcje V10 i V17 z funkcji przekształconych w PCA). Dwa GAN, które nie wykorzystują informacji o klasach, GAN i WGAN, mają wygenerowane dane wyjściowe jako jedna klasa. Architektury warunkowe CGAN i WCGAN pokazują wygenerowane dane według klas. W kroku 0 wszystkie wygenerowane dane pokazują normalny rozkład losowych danych wejściowych podawanych do generatorów.
Widzimy, że oryginalna architektura GAN zaczyna uczyć się kształtu i zakresu rzeczywistych danych, ale potem rozpada się w kierunku małej dystrybucji. Jest to omawiany wcześniej tryb upadku. Generator nauczył się niewielkiego zakresu danych, które dyskryminatorowi trudno jest wykryć jako fałszywe. Architektura CGAN radzi sobie trochę lepiej, rozprzestrzeniając się i zbliżając do dystrybucji każdej klasy danych o oszustwach, ale potem następuje załamanie trybu, jak widać w kroku 5000.
WGAN nie doświadcza załamania trybu wykazywanego przez architektury GAN i CGAN. Nawet bez informacji o klasie, zaczyna zakładać nienormalny rozkład rzeczywistych danych oszustwa. Architektura WCGAN działa podobnie i jest w stanie generować oddzielne klasy danych.
Możemy ocenić, jak realistycznie wyglądają dane, korzystając z tego samego algorytmu xgboost, który był wcześniej używany do wykrywania oszustw. Jest szybki, mocny i działa od ręki bez zbytniego dostrajania. Wyszkolimy klasyfikator xgboost, korzystając z połowy rzeczywistych danych o oszustwach (246 próbek) i takiej samej liczby przykładów wygenerowanych przez GAN. Następnie przetestujemy klasyfikator xgboost, korzystając z drugiej połowy rzeczywistych danych o oszustwach i innego zestawu 246 przykładów wygenerowanych przez GAN. Ta ortogonalna metoda (w sensie eksperymentalnym) da nam pewne wskazówki, jak skutecznie generator generuje realistyczne dane. Przy idealnie realistycznych wygenerowanych danych algorytm xgboost powinien osiągnąć dokładność 0,50 (50%) — innymi słowy, nie jest to lepsze niż zgadywanie.
Widzimy, że dokładność xgboost na danych generowanych przez GAN początkowo maleje, a następnie rośnie po kroku 1000 szkolenia, gdy rozpoczyna się zwijanie trybu. Architektura CGAN uzyskuje nieco bardziej realistyczne dane po 2000 krokach, ale potem dla tej sieci następuje zwijanie trybu, gdy dobrze. Architektury WGAN i WCGAN szybciej uzyskują bardziej realistyczne dane i kontynuują naukę w miarę postępu szkolenia. WCGAN nie wydaje się mieć większej przewagi nad WGAN, co sugeruje, że te stworzone klasy mogą nie być przydatne dla architektur Wasserstein GAN.
Możesz dowiedzieć się więcej o architekturze WGAN tutaj i tutaj.
Sieć krytyków w architekturach WGAN i WCGAN uczy się obliczania odległości Wasserstein (Earth-mover, EM) między danym zbiorem danych a rzeczywistymi danymi oszustwa. Najlepiej byłoby, gdyby zmierzył odległość bliską zeru dla próbki rzeczywistych danych o oszustwach. Krytyk jest jednak w trakcie uczenia się, jak wykonać tę kalkulację. Dopóki mierzy większą odległość dla danych generowanych niż dla danych rzeczywistych, sieć może się poprawić. Możemy obserwować, jak w trakcie treningu zmienia się różnica między odległościami Wassersteina dla danych generowanych i rzeczywistych. Jeśli ustabilizuje się, dalsze szkolenie może nie pomóc. Widzimy na rysunku 6, że wydaje się, że w tym zbiorze danych nastąpi dalsza poprawa zarówno dla WGAN, jak i WCGAN.
Czego się nauczyliśmy?
Teraz możemy przetestować, czy jesteśmy w stanie wygenerować nowe dane o oszustwach na tyle realistyczne, aby pomóc nam wykryć rzeczywiste dane oszustwa. Możemy wziąć wyszkolony generator, który osiągnął najniższy wynik dokładności i użyć go do wygenerowania danych. W naszym podstawowym zestawie szkoleniowym użyjemy 70% danych niezwiązanych z oszustwami (199 020 przypadków) i 100 przypadków danych dotyczących oszustw (~20% danych dotyczących oszustw). Następnie spróbujemy dodać różne ilości rzeczywistych lub wygenerowanych danych dotyczących oszustw do tego zestawu szkoleniowego, do 344 przypadków (70% danych dotyczących oszustw). W zestawie testowym wykorzystamy pozostałe 30% przypadków bez nadużyć finansowych (85 295 przypadków) i przypadków nadużyć (148 przypadków). Możemy spróbować dodać wygenerowane dane z niewytrenowanego GAN i najlepiej wytrenowanego GAN, aby sprawdzić, czy wygenerowane dane są lepsze niż losowy szum. Z naszych testów wynika, że naszą najlepszą architekturą był WCGAN na etapie szkolenia 4800, gdzie osiągnął dokładność xgboost na poziomie 70% (pamiętaj, że idealnie byłoby, gdyby dokładność wynosiła 50%). Wykorzystamy więc tę architekturę do generowania nowych danych o oszustwach.
Na wykresie 7 widać, że wycofanie (ułamek rzeczywistych próbek oszustw dokładnie zidentyfikowanych w zestawie testowym) nie zwiększa się, ponieważ wykorzystujemy do szkolenia więcej wygenerowanych danych dotyczących oszustw. Klasyfikator xgboost jest w stanie zachować wszystkie informacje, których użył do zidentyfikowania oszustw ze 100 rzeczywistych przypadków i nie dać się zmylić dodatkowymi wygenerowanymi danymi, nawet jeśli wybiera je spośród setek tysięcy normalnych przypadków. Nic dziwnego, że wygenerowane dane z nieprzeszkolonego WCGAN nie pomagają ani nie bolą. Ale wygenerowane dane z przeszkolonego WCGAN też nie pomagają. Wygląda na to, że dane nie są wystarczająco realistyczne. Na wykresie 7 widać, że gdy do uzupełnienia zestawu szkoleniowego wykorzystywane są rzeczywiste dane o oszustwach, przypomnienie znacznie wzrasta. Gdyby WCGAN właśnie nauczył się powielać przykłady treningowe, nie wykazując się wcale kreatywnością, mógłby osiągnąć wyższy wskaźnik przypominania, jak widzimy w przypadku rzeczywistych danych.
Do nieskończoności i poza nią
Chociaż nie byliśmy w stanie wygenerować wystarczająco realistycznych danych dotyczących oszustw związanych z kartami kredytowymi, aby pomóc nam wykryć rzeczywiste oszustwa, tymi metodami ledwo zarysowaliśmy powierzchnię. Moglibyśmy trenować dłużej, z większymi sieciami i dostrajać parametry do architektur, które wypróbowaliśmy w tym artykule. Trendy w dokładności xgboost i utracie dyskryminatora sugerują, że więcej szkoleń pomoże architekturom WGAN i WCGAN. Inną opcją jest ponowne przyjrzenie się procesowi czyszczenia danych, który przeprowadziliśmy, być może zaprojektowanie nowych zmiennych lub zmiana, czy iw jaki sposób zajmiemy się skośnością funkcji. Być może pomocne byłyby różne schematy klasyfikacji danych dotyczących oszustw.
Moglibyśmy również wypróbować inne architektury GAN. DRAGAN posiada teoretyczne i eksperymentalne dowody pokazujące, że trenuje szybciej i stabilniej niż GAN Wassersteina. Moglibyśmy zintegrować metody wykorzystujące częściowo nadzorowane uczenie się, które okazały się obiecujące w uczeniu się z ograniczonych zestawów szkoleniowych (patrz „Ulepszone techniki szkolenia GAN”). Moglibyśmy wypróbować architekturę, która da nam modele zrozumiałe dla człowieka, abyśmy mogli lepiej zrozumieć strukturę danych (patrz InfoGAN).
Powinniśmy również zwracać uwagę na nowe osiągnięcia w tej dziedzinie, a na koniec, co nie mniej ważne, możemy pracować nad tworzeniem własnych innowacji w tej szybko rozwijającej się przestrzeni.
Cały odpowiedni kod dla tego artykułu można znaleźć w tym repozytorium GitHub.