Создайте данные из случайного шума с помощью генеративно-состязательных сетей

Опубликовано: 2022-03-11

С тех пор как я узнал о генеративно-состязательных сетях (GAN), я был очарован ими. GAN — это тип нейронной сети, способный генерировать новые данные с нуля. Вы можете подать ему на вход немного случайного шума, и он сможет создавать реалистичные изображения спален, птиц или чего-то еще, что он обучен генерировать.

Одна вещь, с которой согласны все ученые, заключается в том, что нам нужно больше данных.

Сети GAN, которые можно использовать для получения новых данных в ситуациях с ограниченными данными, могут оказаться действительно полезными. Иногда данные могут быть сложными и дорогими, а их создание требует много времени. Однако, чтобы быть полезными, новые данные должны быть достаточно реалистичными, чтобы любые идеи, которые мы получаем из сгенерированных данных, по-прежнему применялись к реальным данным. Если вы обучаете кошку охотиться на мышей и используете поддельных мышей, вам лучше убедиться, что поддельные мыши действительно выглядят как мыши.

Другой способ думать об этом — GAN обнаруживают структуру данных, которая позволяет им создавать реалистичные данные. Это может быть полезно, если мы не можем увидеть эту структуру самостоятельно или не можем извлечь ее другими методами.

Генеративно-состязательные сети

В этой статье вы узнаете, как можно использовать GAN для генерации новых данных. Чтобы сделать этот урок реалистичным, мы будем использовать набор данных для обнаружения мошенничества с кредитными картами от Kaggle.

В своих экспериментах я пытался использовать этот набор данных, чтобы посмотреть, смогу ли я заставить GAN создавать данные, достаточно реалистичные, чтобы помочь нам обнаруживать случаи мошенничества. Этот набор данных подчеркивает проблему ограниченности данных: из 285 000 транзакций только 492 являются мошенничеством. 492 случая мошенничества — это небольшой набор данных для обучения, особенно когда речь идет о задачах машинного обучения, когда людям нравится иметь наборы данных на несколько порядков больше. Хотя результаты моего эксперимента не были удивительными, я многое узнал о GAN по пути, которым я рад поделиться.

Прежде чем ты начнешь

Прежде чем мы углубимся в эту область GAN, если вы хотите быстро освежить свои навыки машинного обучения или глубокого обучения, вы можете взглянуть на эти два связанных сообщения в блоге:

  • Введение в теорию машинного обучения и ее применение: наглядное пособие с примерами
  • Учебное пособие по глубокому обучению: от персептронов к глубоким сетям

Почему ГАН?

Генеративно-состязательные сети (GAN) — это архитектура нейронной сети, которая продемонстрировала впечатляющие улучшения по сравнению с предыдущими генеративными методами, такими как вариационные автокодировщики или ограниченные машины Больцмана. GAN могут генерировать более реалистичные изображения (например, DCGAN), обеспечивать передачу стилей между изображениями (см. здесь и здесь), генерировать изображения из текстовых описаний (StackGAN) и учиться на небольших наборах данных с помощью частично контролируемого обучения. Благодаря этим достижениям они вызывают большой интерес как в академическом, так и в коммерческом секторах.

Директор по исследованиям искусственного интеллекта в Facebook Янн ЛеКанн даже назвал их самой захватывающей разработкой в ​​области машинного обучения за последнее десятилетие.

Основы

Подумайте о том, как вы учитесь. Вы пробуете что-то, вы получаете обратную связь. Вы корректируете свою стратегию и пробуете снова.

Обратная связь может прийти в форме критики, боли или выгоды. Это может исходить из вашего собственного суждения о том, насколько хорошо вы справились. Часто самая полезная обратная связь — это обратная связь, исходящая от другого человека, потому что это не просто число или ощущение, а разумная оценка того, насколько хорошо вы справились с заданием.

Когда компьютер обучается выполнению задачи, человек обычно обеспечивает обратную связь в виде скорректированных параметров или алгоритмов. Это хорошо работает, когда задача четко определена, например, научиться умножать два числа. Вы можете легко и точно сказать компьютеру, в чем он был неправ.

С более сложной задачей, такой как создание изображения собаки, становится сложнее обеспечить обратную связь. Изображение размыто, больше похоже на кошку или вообще на что-то похоже? Можно было бы реализовать сложную статистику, но было бы сложно зафиксировать все детали, которые делают изображение реальным.

Человек может дать некоторую оценку, потому что у нас есть большой опыт оценки визуального ввода, но мы относительно медлительны, и наши оценки могут быть очень субъективными. Вместо этого мы могли бы обучить нейронную сеть различать реальные и сгенерированные изображения.

Затем, позволяя генератору изображений (также нейронной сети) и дискриминатору по очереди учиться друг у друга, они могут со временем улучшаться. Эти две сети, играющие в эту игру, представляют собой генеративно-состязательную сеть.

Вы можете услышать, как изобретатель GAN Ян Гудфеллоу рассказывает о том, как спор в баре на эту тему привел к лихорадочной ночи кодирования, результатом которой стал первый GAN. И да, он признает бар в своей статье. Вы можете узнать больше о GAN из блога Яна Гудфеллоу по этой теме.

Схема ГАН

При работе с GAN возникает ряд проблем. Обучение одной нейронной сети может быть затруднено из-за множества вариантов выбора: архитектура, функции активации, метод оптимизации, скорость обучения и процент отсева, и это лишь некоторые из них.

GAN удваивают все эти варианты и добавляют новые сложности. И генератор, и дискриминатор могут забыть приемы, которые они использовали ранее при обучении. Это может привести к тому, что две сети попадут в стабильный цикл решений, которые со временем не улучшатся. Одна сеть может пересилить другую сеть, так что ни одна из них больше не сможет учиться. Или генератор может не исследовать большую часть возможного пространства решений, а только достаточно, чтобы найти реалистичные решения. Эта последняя ситуация известна как коллапс моды.

Коллапс режима — это когда генератор изучает только небольшое подмножество возможных реалистичных режимов. Например, если задача состоит в том, чтобы генерировать изображения собак, генератор может научиться создавать только изображения маленьких коричневых собак. Генератор пропустил бы все другие режимы, состоящие из собак других размеров или цветов.

Для решения этой проблемы было реализовано множество стратегий, в том числе нормализация пакетов, добавление меток к обучающим данным или изменение способа, которым дискриминатор оценивает сгенерированные данные.

Люди отмечают, что добавление меток к данным, то есть разбиение их на категории, почти всегда повышает производительность GAN. Вместо того, чтобы учиться генерировать изображения домашних животных в целом, должно быть проще создавать, например, изображения кошек, собак, рыб и хорьков.

Возможно, самые значительные прорывы в разработке GAN произошли с точки зрения изменения того, как дискриминатор оценивает данные, поэтому давайте рассмотрим это подробнее.

В исходной формулировке GAN в 2014 году Гудфеллоу и др. Дискриминатор генерирует оценку вероятности того, что данное изображение было реальным или сгенерировано. Дискриминатору будет предоставлен набор изображений, состоящих как из реальных, так и из сгенерированных изображений, и он будет генерировать оценку для каждого из этих входных данных. Затем ошибка между выходом дискриминатора и фактическими метками будет измеряться кросс-энтропийной потерей. Кросс-энтропийную потерю можно приравнять к метрике расстояния Дженсена-Шеннона, и она была показана в начале 2017 года Аржовски и др. что эта метрика в некоторых случаях не сработает, а в других случаях не укажет в правильном направлении. Эта группа показала, что метрика расстояния Вассерштейна (также известная как землеройная машина или электромагнитное расстояние) работала и работала лучше во многих других случаях.

Перекрестная энтропийная потеря является мерой того, насколько точно дискриминатор идентифицировал реальные и сгенерированные изображения. Вместо этого метрика Вассерштейна рассматривает распределение каждой переменной (т. е. каждого цвета каждого пикселя) в реальных и сгенерированных изображениях и определяет, насколько далеко друг от друга распределения для реальных и сгенерированных данных. Метрика Вассерштейна показывает, сколько усилий, с точки зрения массы, умноженной на расстояние, потребуется, чтобы привести сгенерированное распределение в форму реального распределения, отсюда и альтернативное название «расстояние землеройной машины». Поскольку метрика Вассерштейна больше не оценивает, является ли изображение реальным или нет, а вместо этого обеспечивает критику того, насколько далеки сгенерированные изображения от реальных изображений, сеть «дискриминатора» называется сетью «критика» в метрике Вассерштейна. архитектура.

Для более полного изучения GAN в этой статье мы рассмотрим четыре разные архитектуры:

  • ГАН: Оригинальный («ванильный») ГАН
  • CGAN: условная версия оригинальной GAN, в которой используются метки классов.
  • WGAN: ГАН Вассерштейна (со штрафом за уклон)
  • WCGAN: условная версия ГАН Вассерштейна.

Но давайте сначала взглянем на наш набор данных.

Взгляд на данные о мошенничестве с кредитными картами

Мы будем работать с набором данных обнаружения мошенничества с кредитными картами от Kaggle.

Набор данных состоит примерно из 285 000 транзакций, из которых только 492 являются мошенническими. Данные состоят из 31 функции: «время», «количество», «класс» и 28 дополнительных анонимных функций. Признак класса — это метка, указывающая, является ли транзакция мошеннической или нет, где 0 указывает на нормальную операцию, а 1 — на мошенничество. Все данные являются числовыми и непрерывными (кроме метки). В наборе данных нет пропущенных значений. Набор данных уже находится в довольно хорошей форме для начала, но я сделаю еще немного очистки, в основном просто установив средние значения всех признаков на ноль, а стандартные отклонения на единицу. Я описал свой процесс очистки больше в блокноте здесь. Пока просто покажу конечный результат:

Особенности и графики классов

В этих распределениях можно легко обнаружить различия между обычными данными и данными о мошенничестве, но также есть много совпадений. Мы можем применить один из самых быстрых и мощных алгоритмов машинного обучения, чтобы определить наиболее полезные функции для выявления мошенничества. Этот алгоритм, xgboost, представляет собой алгоритм дерева решений с градиентным усилением. Мы обучим его на 70% набора данных и протестируем на оставшихся 30%. Мы можем настроить алгоритм так, чтобы он продолжал работать до тех пор, пока он не улучшит отзыв (долю обнаруженных образцов мошенничества) в тестовом наборе данных. Это обеспечивает 76% отзыва на тестовом наборе, что явно оставляет место для улучшения. Он достигает точности 94%, что означает, что только 6% предсказанных случаев мошенничества на самом деле были обычными транзакциями. Из этого анализа мы также получаем список функций, отсортированных по их полезности в обнаружении мошенничества. Мы можем использовать самые важные функции, чтобы визуализировать наши результаты позже.

Опять же, если бы у нас было больше данных о мошенничестве, мы могли бы лучше его обнаружить. То есть мы могли добиться более высокого отзыва. Теперь мы попытаемся сгенерировать новые, реалистичные данные о мошенничестве, используя GAN, чтобы помочь нам обнаружить фактическое мошенничество.

Создание новых данных кредитной карты с помощью GAN

Чтобы применить различные архитектуры GAN к этому набору данных, я собираюсь использовать GAN-Sandbox, в которой есть ряд популярных архитектур GAN, реализованных в Python с использованием библиотеки Keras и серверной части TensorFlow. Все мои результаты доступны в виде блокнота Jupyter здесь. Все необходимые библиотеки включены в образ Kaggle/Python Docker, если вам нужна простая установка.

Примеры в GAN-Sandbox настроены для обработки изображений. Генератор выдает 2D-изображение с 3 цветовыми каналами для каждого пикселя, а дискриминатор/критик настроен на оценку таких данных. Сверточные преобразования используются между уровнями сетей, чтобы использовать преимущества пространственной структуры данных изображения. Каждый нейрон в сверточном слое работает только с небольшой группой входных и выходных данных (например, смежные пиксели в изображении), что позволяет изучать пространственные отношения. В нашем наборе данных кредитных карт отсутствует какая-либо пространственная структура среди переменных, поэтому я преобразовал сверточные сети в сети с плотно связанными слоями. Нейроны в плотно связанных слоях подключены к каждому входу и выходу слоя, что позволяет сети изучать свои собственные отношения между функциями. Я буду использовать эту настройку для каждой из архитектур.

Первая GAN, которую я буду оценивать, противопоставляет сеть генератора сети дискриминатора, используя потерю перекрестной энтропии от дискриминатора для обучения сетей. Это оригинальная, «ванильная» архитектура GAN. Второй GAN, который я буду оценивать, добавляет метки классов к данным наподобие условного GAN (CGAN). Эта GAN имеет еще одну переменную в данных, метку класса. Третий GAN будет использовать метрику расстояния Вассерштейна для обучения сетей (WGAN), а последний будет использовать метки классов и метрику расстояния Вассерштейна (WCGAN).

GAN-архитектуры

Мы будем обучать различные GAN, используя обучающий набор данных, состоящий из всех 492 мошеннических транзакций. Мы можем добавить классы в набор данных о мошенничестве, чтобы упростить условную архитектуру GAN. Я изучил несколько различных методов кластеризации в блокноте и выбрал классификацию KMeans, которая сортирует данные о мошенничестве по 2 классам.

Я буду тренировать каждую ГАН на 5000 раундов и попутно изучу результаты. На рисунке 4 мы можем видеть фактические данные о мошенничестве и сгенерированные данные о мошенничестве из различных архитектур GAN по мере прохождения обучения. Мы можем видеть фактические данные о мошенничестве, разделенные на 2 класса KMeans, нанесенные на график с 2 измерениями, которые лучше всего различают эти два класса (функции V10 и V17 из преобразованных функций PCA). Две GAN, которые не используют информацию о классе, GAN и WGAN, имеют свои сгенерированные выходные данные как один класс. Условные архитектуры, CGAN и WCGAN, отображают созданные ими данные по классам. На шаге 0 все сгенерированные данные показывают нормальное распределение случайного ввода, подаваемого на генераторы.

Сравнение выходных данных GAN

Мы можем видеть, что исходная архитектура GAN начинает изучать форму и диапазон фактических данных, но затем сворачивается к небольшому распределению. Это коллапс режима, который обсуждался ранее. Генератор изучил небольшой диапазон данных, которые дискриминатору трудно определить как фальшивые. Архитектура CGAN работает немного лучше, распространяясь и приближаясь к распределениям каждого класса данных о мошенничестве, но затем наступает коллапс режима, как видно на шаге 5000.

В WGAN не происходит коллапс режима, характерный для архитектур GAN и CGAN. Даже без информации о классе он начинает предполагать ненормальное распределение фактических данных о мошенничестве. Архитектура WCGAN работает аналогично и способна генерировать отдельные классы данных.

Мы можем оценить, насколько реалистично выглядят данные, используя тот же алгоритм xgboost, который использовался ранее для обнаружения мошенничества. Он быстрый и мощный, работает в готовом виде без особой настройки. Мы будем обучать классификатор xgboost, используя половину фактических данных о мошенничестве (246 образцов) и равное количество примеров, сгенерированных GAN. Затем мы протестируем классификатор xgboost, используя другую половину фактических данных о мошенничестве и другой набор из 246 примеров, сгенерированных GAN. Этот ортогональный метод (в экспериментальном смысле) даст нам некоторое представление о том, насколько успешен генератор в получении реалистичных данных. С совершенно реалистичными сгенерированными данными алгоритм xgboost должен достигать точности 0,50 (50%) — другими словами, это не лучше, чем угадывать.

Точность

Мы можем видеть, что точность xgboost для данных, сгенерированных GAN, сначала уменьшается, а затем увеличивается после шага обучения 1000 по мере того, как наступает коллапс режима. Архитектура CGAN достигает несколько более реалистичных данных после 2000 шагов, но затем для этой сети наступает коллапс режима как хорошо. Архитектуры WGAN и WCGAN позволяют быстрее получать более реалистичные данные и продолжают обучаться по ходу обучения. WCGAN, по-видимому, не имеет большого преимущества перед WGAN, что позволяет предположить, что эти созданные классы могут быть бесполезны для архитектуры GAN Вассерштейна.

Вы можете узнать больше об архитектуре WGAN здесь и здесь.

Критическая сеть в архитектурах WGAN и WCGAN учится вычислять расстояние Вассерштейна (землеход, EM) между заданным набором данных и фактическими данными о мошенничестве. В идеале он будет измерять расстояние, близкое к нулю, для выборки фактических данных о мошенничестве. Критик, однако, находится в процессе обучения тому, как выполнять этот расчет. Пока она измеряет большее расстояние для сгенерированных данных, чем для реальных данных, сеть может улучшаться. Мы можем наблюдать, как в процессе обучения меняется разница между расстояниями Вассерштейна для сгенерированных и реальных данных. Если оно стабилизируется, дальнейшие тренировки могут не помочь. На рисунке 6 мы видим, что в этом наборе данных, по-видимому, есть дальнейшие улучшения как для WGAN, так и для WCGAN.

Оценка расстояния ЭМ

Чему мы научились?

Теперь мы можем проверить, можем ли мы генерировать новые данные о мошенничестве, достаточно реалистичные, чтобы помочь нам обнаружить фактические данные о мошенничестве. Мы можем взять обученный генератор, который показал наименьшую оценку точности, и использовать его для генерации данных. Для нашего базового обучающего набора мы будем использовать 70% данных, не связанных с мошенничеством (199 020 случаев), и 100 случаев данных о мошенничестве (~ 20% данных о мошенничестве). Затем мы попробуем добавить в этот обучающий набор разное количество реальных или сгенерированных данных о мошенничестве, до 344 случаев (70% данных о мошенничестве). Для тестового набора мы будем использовать остальные 30% случаев, не связанных с мошенничеством (85 295 случаев) и случаев мошенничества (148 случаев). Мы можем попробовать добавить сгенерированные данные из необученной GAN и из наиболее обученной GAN, чтобы проверить, лучше ли сгенерированные данные, чем случайный шум. Из наших тестов видно, что нашей лучшей архитектурой была WCGAN на шаге обучения 4800, где она достигла точности xgboost 70% (помните, что в идеале точность должна быть 50%). Поэтому мы будем использовать эту архитектуру для создания новых данных о мошенничестве.

На рисунке 7 видно, что отзыв (доля фактических образцов мошенничества, точно идентифицированных в тестовом наборе) не увеличивается по мере того, как мы используем больше сгенерированных данных о мошенничестве для обучения. Классификатор xgboost способен сохранить всю информацию, которую он использовал для выявления мошенничества из 100 реальных случаев, и не запутаться в дополнительных сгенерированных данных, даже выбирая их из сотен тысяч обычных случаев. Сгенерированные данные из необученного WCGAN не помогают и не мешают, что неудивительно. Но и сгенерированные данные обученного WCGAN не помогают. Похоже, что данные недостаточно реалистичны. На рис. 7 видно, что когда фактические данные о мошенничестве используются для дополнения обучающей выборки, отзыв значительно увеличивается. Если бы WCGAN просто научился дублировать обучающие примеры, вообще не проявляя творчества, он мог бы достичь более высоких показателей припоминания, как мы видим на реальных данных.

Влияние дополнительных данных

Бесконечность не предел

Хотя нам не удалось получить данные о мошенничестве с кредитными картами, достаточно реалистичные, чтобы помочь нам обнаружить фактическое мошенничество, мы едва коснулись этих методов. Мы могли бы тренироваться дольше, с более крупными сетями и настраивать параметры для архитектур, которые мы пробовали в этой статье. Тенденции в точности xgboost и потерях дискриминатора предполагают, что дополнительное обучение поможет архитектурам WGAN и WCGAN. Другой вариант — пересмотреть выполненную нами очистку данных, возможно, разработать некоторые новые переменные или изменить, если и как мы устраняем асимметрию в функциях. Возможно, помогли бы различные схемы классификации данных о мошенничестве.

Мы также могли бы попробовать другие архитектуры GAN. У DRAGAN есть теоретические и экспериментальные данные, показывающие, что он тренируется быстрее и стабильнее, чем GAN Wasserstein. Мы могли бы интегрировать методы, использующие обучение с полуучителем, которые продемонстрировали многообещающие результаты при обучении на ограниченных обучающих наборах (см. «Улучшенные методы обучения GAN»). Мы могли бы попробовать архитектуру, которая дает нам понятные для человека модели, чтобы мы могли лучше понять структуру данных (см. InfoGAN).

Мы также должны следить за новыми разработками в этой области, и, что не менее важно, мы можем работать над созданием собственных инноваций в этой быстро развивающейся области.

Вы можете найти весь соответствующий код для этой статьи в этом репозитории GitHub.

Связанный: Многие приложения градиентного спуска в TensorFlow