Rozpoznawanie numerów uczenia maszynowego — od zera do zastosowania

Opublikowany: 2022-03-11

Uczenie maszynowe, wizja komputerowa, tworzenie potężnych interfejsów API i tworzenie pięknych interfejsów użytkownika to ekscytujące dziedziny, w których można znaleźć wiele innowacji.

Pierwsze dwa wymagają rozległej matematyki i nauk ścisłych, podczas gdy rozwój API i interfejsu użytkownika koncentruje się na myśleniu algorytmicznym i projektowaniu elastycznych architektur. Są bardzo różne, więc podjęcie decyzji, którego chcesz się nauczyć w następnej kolejności, może być trudne. Celem tego artykułu jest pokazanie, w jaki sposób można wykorzystać wszystkie cztery w tworzeniu aplikacji do przetwarzania obrazu.

Aplikacja, którą zamierzamy zbudować, to prosty aparat do rozpoznawania cyfr. Rysujesz, maszyna przewiduje cyfrę. Prostota jest niezbędna, ponieważ pozwala nam zobaczyć szerszy obraz, a nie skupiać się na szczegółach.

W trosce o prostotę wykorzystamy najpopularniejsze i najłatwiejsze do nauczenia technologie. Część dotycząca uczenia maszynowego będzie używać Pythona do aplikacji zaplecza. Jeśli chodzi o interaktywną stronę aplikacji, będziemy działać za pośrednictwem biblioteki JavaScript, której nie trzeba przedstawiać: React.

Uczenie maszynowe do odgadywania cyfr

Podstawą naszej aplikacji jest algorytm zgadywania wylosowanej liczby. Uczenie maszynowe będzie narzędziem używanym do osiągnięcia dobrej jakości zgadywania. Ten rodzaj podstawowej sztucznej inteligencji pozwala systemowi na automatyczne uczenie się przy określonej ilości danych. W szerszym ujęciu uczenie maszynowe to proces znajdowania zbiegu okoliczności lub zestawu zbiegów okoliczności w danych, aby polegać na nich w celu odgadnięcia wyniku.

Nasz proces rozpoznawania obrazu składa się z trzech kroków:

  • Uzyskaj obrazy narysowanych cyfr do treningu
  • Naucz system, aby odgadywał liczby na podstawie danych treningowych
  • Przetestuj system z nowymi/nieznanymi danymi

Środowisko

Będziemy potrzebować środowiska wirtualnego do pracy z uczeniem maszynowym w Pythonie. Takie podejście jest praktyczne, ponieważ zarządza wszystkimi wymaganymi pakietami Pythona, więc nie musisz się nimi martwić.

Zainstalujmy go za pomocą następujących poleceń terminala:

 python3 -m venv virtualenv source virtualenv/bin/activate

Model szkoleniowy

Zanim zaczniemy pisać kod, musimy wybrać odpowiedniego „nauczyciela” dla naszych maszyn. Zazwyczaj specjaliści od nauki danych wypróbowują różne modele przed wybraniem najlepszego. Pominiemy bardzo zaawansowane modele, które wymagają dużych umiejętności i przejdziemy do algorytmu k-najbliższych sąsiadów.

Jest to algorytm, który pobiera próbki danych i układa je na płaszczyźnie uporządkowanej według zadanego zestawu cech. Aby lepiej to zrozumieć, przyjrzyjmy się następującemu obrazowi:

Obraz: Próbki danych uczenia maszynowego ułożone na płaszczyźnie

Aby wykryć typ Zielonej Kropki , powinniśmy sprawdzić typy k najbliższych sąsiadów, gdzie k jest zestawem argumentów. Biorąc pod uwagę powyższy obrazek, jeśli k jest równe 1, 2, 3 lub 4, przypuszczenie będzie Czarnym Trójkątem , ponieważ większość najbliższych k sąsiadów zielonej kropki to czarne trójkąty. Jeśli zwiększymy k do 5, to większość obiektów to niebieskie kwadraty, stąd przypuszczenie będzie Blue Square .

Istnieje kilka zależności potrzebnych do stworzenia naszego modelu uczenia maszynowego:

  • sklearn.neighbors.KNeighborsClassifier to klasyfikator, którego będziemy używać.
  • sklearn.model_selection.train_test_split to funkcja, która pomoże nam podzielić dane na dane treningowe i dane służące do sprawdzenia poprawności modelu.
  • sklearn.model_selection.cross_val_score to funkcja do uzyskania oceny poprawności modelu. Im wyższa wartość, tym lepsza poprawność.
  • sklearn.metrics.classification_report to funkcja wyświetlająca raport statystyczny z domysłami modelu.
  • sklearn.datasets to pakiet używany do pobierania danych do uczenia (obrazy cyfr).
  • numpy to pakiet szeroko stosowany w nauce, ponieważ oferuje produktywny i wygodny sposób manipulowania wielowymiarowymi strukturami danych w Pythonie.
  • matplotlib.pyplot to pakiet używany do wizualizacji danych.

Zacznijmy od zainstalowania i zaimportowania wszystkich z nich:

 pip install sklearn numpy matplotlib scipy from sklearn.datasets import load_digits from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import train_test_split, cross_val_score import numpy as np import matplotlib.pyplot as plt

Teraz musimy załadować bazę danych MNIST. MNIST to klasyczny zbiór danych odręcznie pisanych obrazów używany przez tysiące nowicjuszy w dziedzinie uczenia maszynowego:

 digits = load_digits()

Gdy dane są już pobrane i gotowe, możemy przejść do kolejnego kroku podziału danych na dwie części: trenowanie i testowanie .

Użyjemy 75% danych, aby nauczyć nasz model odgadywania cyfr, a resztę danych wykorzystamy do przetestowania poprawności modelu:

 (X_train, X_test, y_train, y_test) = train_test_split( digits.data, digits.target, test_size=0.25, random_state=42 )

Dane są teraz uporządkowane i jesteśmy gotowi do ich użycia. Postaramy się znaleźć najlepszy parametr k dla naszego modelu, aby domysły były bardziej precyzyjne. Na tym etapie nie możemy zapomnieć o wartości k , ponieważ musimy ocenić model z różnymi wartościami k .

Zobaczmy, dlaczego ważne jest rozważenie zakresu wartości k i jak poprawia to dokładność naszego modelu:

 ks = np.arange(2, 10) scores = [] for k in ks: model = KNeighborsClassifier(n_neighbors=k) score = cross_val_score(model, X_train, y_train, cv=5) score.mean() scores.append(score.mean()) plt.plot(scores, ks) plt.xlabel('accuracy') plt.ylabel('k') plt.show()

Wykonanie tego kodu pokaże ci następujący wykres opisujący dokładność algorytmu z różnymi wartościami k .

Obraz: Wykres używany do testowania dokładności algorytmu przy różnych wartościach k.

Jak widać, wartość k wynosząca 3 zapewnia najlepszą dokładność dla naszego modelu i zbioru danych.

Używanie Flask do budowania API

Rdzeń aplikacji, czyli algorytm przewidujący cyfry z obrazów, jest już gotowy. Następnie musimy udekorować algorytm warstwą API, aby był dostępny do użycia. Użyjmy popularnego frameworka internetowego Flask, aby zrobić to czysto i zwięźle.

Zaczniemy od zainstalowania Flaska i zależności związanych z przetwarzaniem obrazu w środowisku wirtualnym:

 pip install Flask Pillow scikit-image

Po zakończeniu instalacji przechodzimy do tworzenia pliku punktu wejścia aplikacji:

 touch app.py

Zawartość pliku będzie wyglądać tak:

 import os from flask import Flask from views import PredictDigitView, IndexView app = Flask(__name__) app.add_url_rule( '/api/predict', view_func=PredictDigitView.as_view('predict_digit'), methods=['POST'] ) app.add_url_rule( '/', view_func=IndexView.as_view('index'), methods=['GET'] ) if __name__ == 'main': port = int(os.environ.get("PORT", 5000)) app.run(host='0.0.0.0', port=port)

Pojawi się błąd mówiący, że PredictDigitView i IndexView nie są zdefiniowane. Następnym krokiem jest utworzenie pliku, który zainicjuje te widoki:

 from flask import render_template, request, Response from flask.views import MethodView, View from flask.views import View from repo import ClassifierRepo from services import PredictDigitService from settings import CLASSIFIER_STORAGE class IndexView(View): def dispatch_request(self): return render_template('index.html') class PredictDigitView(MethodView): def post(self): repo = ClassifierRepo(CLASSIFIER_STORAGE) service = PredictDigitService(repo) image_data_uri = request.json['image'] prediction = service.handle(image_data_uri) return Response(str(prediction).encode(), status=200)

Po raz kolejny napotkamy błąd o nierozwiązanym imporcie. Pakiet Views opiera się na trzech plikach, których jeszcze nie mamy:

  • Ustawienia
  • Repo
  • Usługa

Będziemy je wdrażać jeden po drugim.

Ustawienia to moduł z konfiguracjami i zmiennymi stałymi. Przechowa dla nas ścieżkę do zserializowanego klasyfikatora. Nasuwa się logiczne pytanie: dlaczego muszę zapisywać klasyfikator?

Ponieważ jest to prosty sposób na poprawę wydajności Twojej aplikacji. Zamiast trenować klasyfikator za każdym razem, gdy otrzymujesz żądanie, będziemy przechowywać przygotowaną wersję klasyfikatora, umożliwiając jego działanie po wyjęciu z pudełka:

 import os BASE_DIR = os.getcwd() CLASSIFIER_STORAGE = os.path.join(BASE_DIR, 'storage/classifier.txt')

Mechanizm ustawień — pobieranie klasyfikatora — zostanie zainicjowany w kolejnym pakiecie na naszej liście, Repo . Jest to klasa z dwiema metodami pobierania i aktualizowania wytrenowanego klasyfikatora za pomocą wbudowanego modułu pickle Pythona:

 import pickle class ClassifierRepo: def __init__(self, storage): self.storage = storage def get(self): with open(self.storage, 'wb') as out: try: classifier_str = out.read() if classifier_str != '': return pickle.loads(classifier_str) else: return None except Exception: return None def update(self, classifier): with open(self.storage, 'wb') as in_: pickle.dump(classifier, in_)

Jesteśmy blisko sfinalizowania naszego API. Teraz brakuje mu tylko modułu Serwis . Jaki jest jego cel?

  • Pobierz wyszkolony klasyfikator z magazynu
  • Przekształć obraz przekazany z interfejsu użytkownika do formatu zrozumiałego dla klasyfikatora
  • Oblicz prognozę ze sformatowanym obrazem za pomocą klasyfikatora
  • Zwróć przepowiednię

Zakodujmy ten algorytm:

 from sklearn.datasets import load_digits from classifier import ClassifierFactory from image_processing import process_image class PredictDigitService: def __init__(self, repo): self.repo = repo def handle(self, image_data_uri): classifier = self.repo.get() if classifier is None: digits = load_digits() classifier = ClassifierFactory.create_with_fit( digits.data, digits.target ) self.repo.update(classifier) x = process_image(image_data_uri) if x is None: return 0 prediction = classifier.predict(x)[0] return prediction

Tutaj widać, że PredictDigitService ma dwie zależności: ClassifierFactory i process_image .

Zaczniemy od stworzenia klasy do tworzenia i trenowania naszego modelu:

 from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier class ClassifierFactory: @staticmethod def create_with_fit(data, target): model = KNeighborsClassifier(n_neighbors=3) model.fit(data, target) return model

API jest gotowe do działania. Teraz możemy przejść do etapu przetwarzania obrazu.

Przetwarzanie obrazu

Przetwarzanie obrazu to metoda wykonywania pewnych operacji na obrazie w celu jego ulepszenia lub wydobycia z niego przydatnych informacji. W naszym przypadku musimy płynnie przenieść obraz narysowany przez użytkownika do formatu modelu uczenia maszynowego.

Image alt: Przekształcanie narysowanych obrazów do formatu uczenia maszynowego.

Zaimportujmy kilku pomocników, aby osiągnąć ten cel:

 import numpy as np from skimage import exposure import base64 from PIL import Image, ImageOps, ImageChops from io import BytesIO

Przejście możemy podzielić na sześć odrębnych części:

1. Zamień przezroczyste tło na kolor

Image alt: zastąpienie tła na przykładowym obrazie.

 def replace_transparent_background(image): image_arr = np.array(image) if len(image_arr.shape) == 2: return image alpha1 = 0 r2, g2, b2, alpha2 = 255, 255, 255, 255 red, green, blue, alpha = image_arr[:, :, 0], image_arr[:, :, 1], image_arr[:, :, 2], image_arr[:, :, 3] mask = (alpha == alpha1) image_arr[:, :, :4][mask] = [r2, g2, b2, alpha2] return Image.fromarray(image_arr)

2. Przytnij otwarte granice

Obraz: przycinanie granic na przykładowym obrazie.

 def trim_borders(image): bg = Image.new(image.mode, image.size, image.getpixel((0,0))) diff = ImageChops.difference(image, bg) diff = ImageChops.add(diff, diff, 2.0, -100) bbox = diff.getbbox() if bbox: return image.crop(bbox) return image

3. Dodaj obramowania o równym rozmiarze

Obraz: dodawanie gotowych ramek o takim samym rozmiarze do przykładowego obrazu.

 def pad_image(image): return ImageOps.expand(image, border=30, fill='#fff')

4. Konwertuj obraz do trybu skali szarości

 def to_grayscale(image): return image.convert('L')

5. Odwróć kolory

Obraz: odwracanie kolorów przykładowego obrazu.

 def invert_colors(image): return ImageOps.invert(image)

6. Zmień rozmiar obrazu na format 8x8

Obraz: zmiana rozmiaru przykładowego obrazu do formatu 8x8.

 def resize_image(image): return image.resize((8, 8), Image.LINEAR)

Teraz możesz przetestować aplikację. Uruchom aplikację i wprowadź poniższe polecenie, aby wysłać żądanie z tym obrazem iStock do interfejsu API:

Wizerunek: Podstawowy wizerunek ręcznie rysowanej liczby osiem.

 export FLASK_APP=app flask run
 curl "http://localhost:5000/api/predict" -X "POST" -H "Content-Type: application/json" -d "{\"image\": \"data:image/png;base64,$(curl "https://media.istockphoto.com/vectors/number-eight-8-hand-drawn-with-dry-brush-vector-id484207302?k=6&m=484207302&s=170667a&w=0&h=s3YANDyuLS8u2so-uJbMA2uW6fYyyRkabc1a6OTq7iI=" | base64)\"}" -i

Powinieneś zobaczyć następujące dane wyjściowe:

 HTTP/1.1 100 Continue HTTP/1.0 200 OK Content-Type: text/html; charset=utf-8 Content-Length: 1 Server: Werkzeug/0.14.1 Python/3.6.3 Date: Tue, 27 Mar 2018 07:02:08 GMT 8

Przykładowy obraz przedstawiał numer 8, a nasza aplikacja poprawnie go jako taka zidentyfikowała.

Tworzenie okienka rysunku za pomocą React

Aby szybko załadować aplikację frontendową, użyjemy boilerplate’u CRA:

 create-react-app frontend cd frontend

Po założeniu miejsca pracy potrzebujemy również zależności do rysowania cyfr. Pakiet React-Sketch idealnie pasuje do naszych potrzeb:

 npm i react-sketch

Aplikacja składa się tylko z jednego komponentu. Możemy podzielić ten składnik na dwie części: logikę i widok .

Część widoku odpowiada za reprezentację okienka rysunku, przycisków Prześlij i Resetuj . Kiedy wchodzimy w interakcję, powinniśmy również reprezentować przewidywanie lub błąd. Z punktu widzenia logiki ma następujące obowiązki: przesyłanie zdjęć i czyszczenie szkicu .

Za każdym razem, gdy użytkownik kliknie przycisk Prześlij , komponent wyodrębni obraz z komponentu szkicu i odwoła się do funkcji makePrediction modułu API. Jeśli żądanie do zaplecza się powiedzie, ustawimy zmienną stanu przewidywania. W przeciwnym razie zaktualizujemy stan błędu.

Gdy użytkownik kliknie Resetuj , szkic zostanie wyczyszczony:

 import React, { useRef, useState } from "react"; import { makePrediction } from "./api"; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return null }

Logika jest wystarczająca. Teraz możemy dodać do niego interfejs wizualny:

 import React, { useRef, useState } from "react"; import { SketchField, Tools } from "react-sketch"; import { makePrediction } from "./api"; import logo from "./logo.svg"; import "./App.css"; const pixels = (count) => `${count}px`; const percents = (count) => `${count}%`; const MAIN_CONTAINER_WIDTH_PX = 200; const MAIN_CONTAINER_HEIGHT = 100; const MAIN_CONTAINER_STYLE = { width: pixels(MAIN_CONTAINER_WIDTH_PX), height: percents(MAIN_CONTAINER_HEIGHT), margin: "0 auto", }; const SKETCH_CONTAINER_STYLE = { border: "1px solid black", width: pixels(MAIN_CONTAINER_WIDTH_PX - 2), height: pixels(MAIN_CONTAINER_WIDTH_PX - 2), backgroundColor: "white", }; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return ( <div className="App" style={MAIN_CONTAINER_STYLE}> <div> <header className="App-header"> <img src={logo} className="App-logo" alt="logo" /> <h1 className="App-title">Draw a digit</h1> </header> <div style={SKETCH_CONTAINER_STYLE}> <SketchField ref={sketchRef} width="100%" height="100%" tool={Tools.Pencil} imageFormat="jpg" lineColor="#111" lineWidth={10} /> </div> {prediction && <h3>Predicted value is: {prediction}</h3>} <button onClick={handleClear}>Clear</button> <button onClick={handleSubmit}>Guess the number</button> {error && <p style={{ color: "red" }}>Something went wrong</p>} </div> </div> ); }; export default App;

Komponent jest gotowy, przetestuj go, uruchamiając i przechodząc do localhost:3000 po:

 npm run start

Aplikacja demonstracyjna jest dostępna tutaj. Możesz także przeglądać kod źródłowy na GitHub.

Zawijanie

Jakość tego klasyfikatora nie jest idealna i nie udaję, że jest. Różnica między danymi, które wykorzystaliśmy do treningu, a danymi pochodzącymi z interfejsu użytkownika, jest ogromna. Mimo to stworzyliśmy działającą aplikację od podstaw w niecałe 30 minut.

Obraz: Animacja przedstawiająca ukończoną aplikację identyfikującą odręczne cyfry.

W tym czasie doskonaliliśmy nasze umiejętności w czterech obszarach:

  • Nauczanie maszynowe
  • Rozwój zaplecza
  • Przetwarzanie obrazu
  • Rozwój frontendu

Nie brakuje potencjalnych przypadków użycia oprogramowania zdolnego do rozpoznawania odręcznych cyfr, od oprogramowania edukacyjnego i administracyjnego po usługi pocztowe i finansowe.

Dlatego mam nadzieję, że ten artykuł zmotywuje Cię do poprawy umiejętności uczenia maszynowego, przetwarzania obrazu oraz rozwoju front-endu i back-endu, a także wykorzystania tych umiejętności do projektowania wspaniałych i użytecznych aplikacji.

Jeśli chcesz poszerzyć swoją wiedzę na temat uczenia maszynowego i przetwarzania obrazów, możesz zapoznać się z naszym samouczkiem dotyczącym uczenia maszynowego.