機器學習數字識別——從零到應用

已發表: 2022-03-11

機器學習、計算機視覺、構建強大的 API 和創建漂亮的 UI 是令人興奮的領域,見證了許多創新。

前兩個需要廣泛的數學和科學,而 API 和 UI 開發以算法思維和設計靈活的架構為中心。 它們非常不同,因此決定接下來要學習哪一個可能具有挑戰性。 本文的目的是演示如何在創建圖像處理應用程序時使用這四種方法。

我們要構建的應用程序是一個簡單的數字識別器。 你畫,機器預測數字。 簡單是必不可少的,因為它使我們能夠看到大局而不是關注細節。

為簡單起見,我們將使用最流行且易於學習的技術。 機器學習部分將使用 Python 作為後端應用程序。 至於應用程序的交互方面,我們將通過一個無需介紹的 JavaScript 庫進行操作:React。

機器學習猜測數字

我們應用程序的核心部分是猜測中獎號碼的算法。 機器學習將成為用於實現良好猜測質量的工具。 這種基本的人工智能允許系統使用給定數量的數據自動學習。 從廣義上講,機器學習是一個在數據中尋找巧合或一組巧合以依靠它們來猜測結果的過程。

我們的圖像識別過程包含三個步驟:

  • 獲取用於訓練的繪製數字的圖像
  • 訓練系統通過訓練數據猜測數字
  • 使用新/未知數據測試系統

環境

我們需要一個虛擬環境來使用 Python 中的機器學習。 這種方法很實用,因為它管理所有必需的 Python 包,因此您無需擔心它們。

讓我們使用以下終端命令安裝它:

 python3 -m venv virtualenv source virtualenv/bin/activate

訓練模型

在我們開始編寫代碼之前,我們需要為我們的機器選擇一個合適的“老師”。 通常,數據科學專業人員會在選擇最佳模型之前嘗試不同的模型。 我們將跳過需要大量技能的非常高級的模型,並繼續使用 k-最近鄰算法。

它是一種算法,它獲取一些數據樣本並將它們排列在一個按給定特徵集排序的平面上。 為了更好地理解它,讓我們回顧一下下圖:

圖片:排列在平面上的機器學習數據樣本

要檢測Green Dot的類型,我們應該檢查k最近鄰的類型,其中k是參數集。 考慮到上圖,如果k等於 1、2、3 或 4,則猜測將是黑色三角形,因為大多數綠點最近的k個鄰居都是黑色三角形。 如果我們將k增加到 5,那麼大多數對像都是藍色方塊,因此猜測將是藍色方塊

創建我們的機器學習模型需要一些依賴項:

  • sklearn.neighbors.KNeighborsClassifier是我們將使用的分類器。
  • sklearn.model_selection.train_test_split是幫助我們將數據拆分為訓練數據和用於檢查模型正確性的數據的函數。
  • sklearn.model_selection.cross_val_score是為模型的正確性打分的函數。 值越高,正確性越好。
  • sklearn.metrics.classification_report是顯示模型猜測的統計報告的函數。
  • sklearn.datasets是用於獲取訓練數據(數字圖像)的包。
  • numpy是一個廣泛用於科學的包,因為它提供了一種高效且舒適的方式來在 Python 中操作多維數據結構。
  • matplotlib.pyplot是用於可視化數據的包。

讓我們從安裝和導入所有這些開始:

 pip install sklearn numpy matplotlib scipy from sklearn.datasets import load_digits from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import train_test_split, cross_val_score import numpy as np import matplotlib.pyplot as plt

現在,我們需要加載 MNIST 數據庫。 MNIST 是機器學習領域數千名新手使用的經典手寫圖像數據集:

 digits = load_digits()

獲取並準備好數據後,我們可以進行下一步,將數據分為兩部分:訓練測試

我們將使用 75% 的數據來訓練我們的模型來猜測數字,我們將使用其餘的數據來測試模型的正確性:

 (X_train, X_test, y_train, y_test) = train_test_split( digits.data, digits.target, test_size=0.25, random_state=42 )

數據現在已經安排好了,我們可以使用它了。 我們將嘗試為我們的模型找到最佳參數k ,這樣猜測會更準確。 在這個階段,我們不能忘記k值,因為我們必須用不同的k值來評估模型。

讓我們看看為什麼必須考慮一系列k值以及這如何提高我們模型的準確性:

 ks = np.arange(2, 10) scores = [] for k in ks: model = KNeighborsClassifier(n_neighbors=k) score = cross_val_score(model, X_train, y_train, cv=5) score.mean() scores.append(score.mean()) plt.plot(scores, ks) plt.xlabel('accuracy') plt.ylabel('k') plt.show()

執行此代碼將顯示以下圖表,描述算法在不同k值下的準確性。

圖片:用於測試具有不同 k 值的算法準確性的圖。

如您所見, k值為 3 可確保我們的模型和數據集的最佳精度。

使用 Flask 構建 API

應用程序核心是一種從圖像中預測數字的算法,現已準備就緒。 接下來,我們需要用 API 層來裝飾算法以使其可供使用。 讓我們使用流行的 Flask Web 框架來簡潔明了地完成這項工作。

我們將從在虛擬環境中安裝 Flask 和與圖像處理相關的依賴項開始:

 pip install Flask Pillow scikit-image

安裝完成後,我們開始創建應用程序的入口點文件:

 touch app.py

該文件的內容將如下所示:

 import os from flask import Flask from views import PredictDigitView, IndexView app = Flask(__name__) app.add_url_rule( '/api/predict', view_func=PredictDigitView.as_view('predict_digit'), methods=['POST'] ) app.add_url_rule( '/', view_func=IndexView.as_view('index'), methods=['GET'] ) if __name__ == 'main': port = int(os.environ.get("PORT", 5000)) app.run(host='0.0.0.0', port=port)

您將收到一條錯誤消息,指出PredictDigitViewIndexView 。 下一步是創建一個將初始化這些視圖的文件:

 from flask import render_template, request, Response from flask.views import MethodView, View from flask.views import View from repo import ClassifierRepo from services import PredictDigitService from settings import CLASSIFIER_STORAGE class IndexView(View): def dispatch_request(self): return render_template('index.html') class PredictDigitView(MethodView): def post(self): repo = ClassifierRepo(CLASSIFIER_STORAGE) service = PredictDigitService(repo) image_data_uri = request.json['image'] prediction = service.handle(image_data_uri) return Response(str(prediction).encode(), status=200)

再一次,我們將遇到有關未解決導入的錯誤。 Views包依賴於我們還沒有的三個文件:

  • 設置
  • 回購
  • 服務

我們將一一實施。

Settings是一個具有配置和常量變量的模塊。 它將為我們存儲序列化分類器的路徑。 它引出了一個合乎邏輯的問題:為什麼我需要保存分類器?

因為這是提高應用程序性能的一種簡單方法。 我們不會在每次收到請求時都訓練分類器,而是存儲分類器的準備版本,使其能夠開箱即用:

 import os BASE_DIR = os.getcwd() CLASSIFIER_STORAGE = os.path.join(BASE_DIR, 'storage/classifier.txt')

設置機制——獲取分類器——將在我們列表中的下一個包Repo中初始化。 這是一個有兩種方法的類,可以使用 Python 的內置pickle模塊檢索和更新經過訓練的分類器:

 import pickle class ClassifierRepo: def __init__(self, storage): self.storage = storage def get(self): with open(self.storage, 'wb') as out: try: classifier_str = out.read() if classifier_str != '': return pickle.loads(classifier_str) else: return None except Exception: return None def update(self, classifier): with open(self.storage, 'wb') as in_: pickle.dump(classifier, in_)

我們即將完成我們的 API。 現在它只缺少服務模塊。 它的目的是什麼?

  • 從存儲中獲取訓練好的分類器
  • 將從 UI 傳遞的圖像轉換為分類器可以理解的格式
  • 通過分類器使用格式化圖像計算預測
  • 返回預測

讓我們編寫這個算法:

 from sklearn.datasets import load_digits from classifier import ClassifierFactory from image_processing import process_image class PredictDigitService: def __init__(self, repo): self.repo = repo def handle(self, image_data_uri): classifier = self.repo.get() if classifier is None: digits = load_digits() classifier = ClassifierFactory.create_with_fit( digits.data, digits.target ) self.repo.update(classifier) x = process_image(image_data_uri) if x is None: return 0 prediction = classifier.predict(x)[0] return prediction

在這裡您可以看到PredictDigitService有兩個依賴項: ClassifierFactoryprocess_image

我們將首先創建一個類來創建和訓練我們的模型:

 from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier class ClassifierFactory: @staticmethod def create_with_fit(data, target): model = KNeighborsClassifier(n_neighbors=3) model.fit(data, target) return model

API 已準備好採取行動。 現在我們可以進行圖像處理步驟。

圖像處理

圖像處理是一種對圖像執行某些操作以增強圖像或從中提取一些有用信息的方法。 在我們的例子中,我們需要將用戶繪製的圖像平滑地轉換為機器學習模型格式。

Image alt:將繪製的圖像轉換為機器學習格式。

讓我們導入一些助手來實現這個目標:

 import numpy as np from skimage import exposure import base64 from PIL import Image, ImageOps, ImageChops from io import BytesIO

我們可以將過渡分為六個不同的部分:

1.用顏色替換透明背景

Image alt:替換示例圖像的背景。

 def replace_transparent_background(image): image_arr = np.array(image) if len(image_arr.shape) == 2: return image alpha1 = 0 r2, g2, b2, alpha2 = 255, 255, 255, 255 red, green, blue, alpha = image_arr[:, :, 0], image_arr[:, :, 1], image_arr[:, :, 2], image_arr[:, :, 3] mask = (alpha == alpha1) image_arr[:, :, :4][mask] = [r2, g2, b2, alpha2] return Image.fromarray(image_arr)

2.修剪開放邊界

圖像:修剪示例圖像的邊框。

 def trim_borders(image): bg = Image.new(image.mode, image.size, image.getpixel((0,0))) diff = ImageChops.difference(image, bg) diff = ImageChops.add(diff, diff, 2.0, -100) bbox = diff.getbbox() if bbox: return image.crop(bbox) return image

3.添加大小相等的邊框

圖像:為示例圖像添加預設和相同大小的邊框。

 def pad_image(image): return ImageOps.expand(image, border=30, fill='#fff')

4.將圖像轉換為灰度模式

def to_grayscale(image): return image.convert('L')

5.反轉顏色

圖像:反轉樣本圖像的顏色。

 def invert_colors(image): return ImageOps.invert(image)

6.調整圖片大小為8x8格式

圖像:將示例圖像調整為 8x8 格式。

 def resize_image(image): return image.resize((8, 8), Image.LINEAR)

現在您可以測試應用程序了。 運行應用程序並輸入以下命令,將帶有此 iStock 圖像的請求發送到 API:

圖像: 手繪數字八的股票圖像。

 export FLASK_APP=app flask run
 curl "http://localhost:5000/api/predict" -X "POST" -H "Content-Type: application/json" -d "{\"image\": \"data:image/png;base64,$(curl "https://media.istockphoto.com/vectors/number-eight-8-hand-drawn-with-dry-brush-vector-id484207302?k=6&m=484207302&s=170667a&w=0&h=s3YANDyuLS8u2so-uJbMA2uW6fYyyRkabc1a6OTq7iI=" | base64)\"}" -i

您應該看到以下輸出:

 HTTP/1.1 100 Continue HTTP/1.0 200 OK Content-Type: text/html; charset=utf-8 Content-Length: 1 Server: Werkzeug/0.14.1 Python/3.6.3 Date: Tue, 27 Mar 2018 07:02:08 GMT 8

示例圖像描繪了數字 8,我們的應用程序正確識別了它。

通過 React 創建繪圖窗格

為了快速引導前端應用程序,我們將使用 CRA 樣板:

 create-react-app frontend cd frontend

設置工作場所後,我們還需要一個依賴項來繪製數字。 react-sketch 包完全符合我們的需求:

 npm i react-sketch

該應用程序只有一個組件。 我們可以將這個組件分為兩部分:邏輯和視圖

視圖部分負責表示繪圖窗格、提交重置按鈕。 在交互時,我們還應該表示預測或錯誤。 從邏輯上看,它有以下職責:提交圖片清除草圖

每當用戶單擊Submit時,該組件將從草圖組件中提取圖像並調用 API 模塊的makePrediction函數。 如果對後端的請求成功,我們將設置預測狀態變量。 否則,我們將更新錯誤狀態。

當用戶單擊Reset時,草圖將清除:

 import React, { useRef, useState } from "react"; import { makePrediction } from "./api"; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return null }

邏輯就足夠了。 現在我們可以給它添加可視化界面:

 import React, { useRef, useState } from "react"; import { SketchField, Tools } from "react-sketch"; import { makePrediction } from "./api"; import logo from "./logo.svg"; import "./App.css"; const pixels = (count) => `${count}px`; const percents = (count) => `${count}%`; const MAIN_CONTAINER_WIDTH_PX = 200; const MAIN_CONTAINER_HEIGHT = 100; const MAIN_CONTAINER_STYLE = { width: pixels(MAIN_CONTAINER_WIDTH_PX), height: percents(MAIN_CONTAINER_HEIGHT), margin: "0 auto", }; const SKETCH_CONTAINER_STYLE = { border: "1px solid black", width: pixels(MAIN_CONTAINER_WIDTH_PX - 2), height: pixels(MAIN_CONTAINER_WIDTH_PX - 2), backgroundColor: "white", }; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return ( <div className="App" style={MAIN_CONTAINER_STYLE}> <div> <header className="App-header"> <img src={logo} className="App-logo" alt="logo" /> <h1 className="App-title">Draw a digit</h1> </header> <div style={SKETCH_CONTAINER_STYLE}> <SketchField ref={sketchRef} width="100%" height="100%" tool={Tools.Pencil} imageFormat="jpg" lineColor="#111" lineWidth={10} /> </div> {prediction && <h3>Predicted value is: {prediction}</h3>} <button onClick={handleClear}>Clear</button> <button onClick={handleSubmit}>Guess the number</button> {error && <p style={{ color: "red" }}>Something went wrong</p>} </div> </div> ); }; export default App;

組件已準備就緒,通過執行並轉到localhost:3000對其進行測試:

 npm run start

演示應用程序可在此處獲得。 您也可以在 GitHub 上瀏覽源代碼。

包起來

這個分類器的質量並不完美,我並不假裝它是完美的。 我們用於訓練的數據與來自 UI 的數據之間的差異是巨大的。 儘管如此,我們還是在不到 30 分鐘的時間內從頭開始創建了一個工作應用程序。

圖片:動畫顯示最終確定的應用程序識別手寫數字。

在此過程中,我們在四個領域磨練了我們的技能:

  • 機器學習
  • 後端開發
  • 圖像處理
  • 前端開發

能夠識別手寫數字的軟件不乏潛在的用例,從教育和管理軟件到郵政和金融服務。

因此,我希望這篇文章能激勵你提高機器學習能力、圖像處理能力和前後端開發能力,並利用這些技能設計出精彩實用的應用程序。

如果您想拓寬機器學習和圖像處理方面的知識,您可能需要查看我們的對抗性機器學習教程。