机器学习数字识别——从零到应用

已发表: 2022-03-11

机器学习、计算机视觉、构建强大的 API 和创建漂亮的 UI 是令人兴奋的领域,见证了许多创新。

前两个需要广泛的数学和科学,而 API 和 UI 开发以算法思维和设计灵活的架构为中心。 它们非常不同,因此决定接下来要学习哪一个可能具有挑战性。 本文的目的是演示如何在创建图像处理应用程序时使用这四种方法。

我们要构建的应用程序是一个简单的数字识别器。 你画,机器预测数字。 简单是必不可少的,因为它使我们能够看到大局而不是关注细节。

为简单起见,我们将使用最流行且易于学习的技术。 机器学习部分将使用 Python 作为后端应用程序。 至于应用程序的交互方面,我们将通过一个无需介绍的 JavaScript 库进行操作:React。

机器学习猜测数字

我们应用程序的核心部分是猜测中奖号码的算法。 机器学习将成为用于实现良好猜测质量的工具。 这种基本的人工智能允许系统使用给定数量的数据自动学习。 从广义上讲,机器学习是一个在数据中寻找巧合或一组巧合以依靠它们来猜测结果的过程。

我们的图像识别过程包含三个步骤:

  • 获取用于训练的绘制数字的图像
  • 训练系统通过训练数据猜测数字
  • 使用新/未知数据测试系统

环境

我们需要一个虚拟环境来使用 Python 中的机器学习。 这种方法很实用,因为它管理所有必需的 Python 包,因此您无需担心它们。

让我们使用以下终端命令安装它:

 python3 -m venv virtualenv source virtualenv/bin/activate

训练模型

在我们开始编写代码之前,我们需要为我们的机器选择一个合适的“老师”。 通常,数据科学专业人员会在选择最佳模型之前尝试不同的模型。 我们将跳过需要大量技能的非常高级的模型,并继续使用 k-最近邻算法。

它是一种算法,它获取一些数据样本并将它们排列在一个按给定特征集排序的平面上。 为了更好地理解它,让我们回顾一下下图:

图片:排列在平面上的机器学习数据样本

要检测Green Dot的类型,我们应该检查k最近邻的类型,其中k是参数集。 考虑上图,如果k等于 1、2、3 或 4,则猜测将是黑色三角形,因为大多数绿点最近的k个邻居都是黑色三角形。 如果我们将k增加到 5,那么大多数对象都是蓝色方块,因此猜测将是蓝色方块

创建我们的机器学习模型需要一些依赖项:

  • sklearn.neighbors.KNeighborsClassifier是我们将使用的分类器。
  • sklearn.model_selection.train_test_split是帮助我们将数据拆分为训练数据和用于检查模型正确性的数据的函数。
  • sklearn.model_selection.cross_val_score是为模型的正确性打分的函数。 值越高,正确性越好。
  • sklearn.metrics.classification_report是显示模型猜测的统计报告的函数。
  • sklearn.datasets是用于获取训练数据(数字图像)的包。
  • numpy是一个广泛用于科学的包,因为它提供了一种高效且舒适的方式来在 Python 中操作多维数据结构。
  • matplotlib.pyplot是用于可视化数据的包。

让我们从安装和导入所有这些开始:

 pip install sklearn numpy matplotlib scipy from sklearn.datasets import load_digits from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import train_test_split, cross_val_score import numpy as np import matplotlib.pyplot as plt

现在,我们需要加载 MNIST 数据库。 MNIST 是机器学习领域数千名新手使用的经典手写图像数据集:

 digits = load_digits()

获取并准备好数据后,我们可以进行下一步,将数据分为两部分:训练测试

我们将使用 75% 的数据来训练我们的模型来猜测数字,我们将使用其余的数据来测试模型的正确性:

 (X_train, X_test, y_train, y_test) = train_test_split( digits.data, digits.target, test_size=0.25, random_state=42 )

数据现在已经安排好了,我们可以使用它了。 我们将尝试为我们的模型找到最佳参数k ,这样猜测会更准确。 在这个阶段,我们不能忘记k值,因为我们必须用不同的k值来评估模型。

让我们看看为什么必须考虑一系列k值以及这如何提高我们模型的准确性:

 ks = np.arange(2, 10) scores = [] for k in ks: model = KNeighborsClassifier(n_neighbors=k) score = cross_val_score(model, X_train, y_train, cv=5) score.mean() scores.append(score.mean()) plt.plot(scores, ks) plt.xlabel('accuracy') plt.ylabel('k') plt.show()

执行此代码将显示以下图表,描述算法在不同k值下的准确性。

图片:用于测试具有不同 k 值的算法准确性的图。

如您所见, k值为 3 可确保我们的模型和数据集的最佳精度。

使用 Flask 构建 API

应用程序核心是一种从图像中预测数字的算法,现已准备就绪。 接下来,我们需要用 API 层来装饰算法以使其可供使用。 让我们使用流行的 Flask Web 框架来简洁明了地完成这项工作。

我们将从在虚拟环境中安装 Flask 和与图像处理相关的依赖项开始:

 pip install Flask Pillow scikit-image

安装完成后,我们开始创建应用程序的入口点文件:

 touch app.py

该文件的内容将如下所示:

 import os from flask import Flask from views import PredictDigitView, IndexView app = Flask(__name__) app.add_url_rule( '/api/predict', view_func=PredictDigitView.as_view('predict_digit'), methods=['POST'] ) app.add_url_rule( '/', view_func=IndexView.as_view('index'), methods=['GET'] ) if __name__ == 'main': port = int(os.environ.get("PORT", 5000)) app.run(host='0.0.0.0', port=port)

您将收到一条错误消息,指出PredictDigitViewIndexView 。 下一步是创建一个将初始化这些视图的文件:

 from flask import render_template, request, Response from flask.views import MethodView, View from flask.views import View from repo import ClassifierRepo from services import PredictDigitService from settings import CLASSIFIER_STORAGE class IndexView(View): def dispatch_request(self): return render_template('index.html') class PredictDigitView(MethodView): def post(self): repo = ClassifierRepo(CLASSIFIER_STORAGE) service = PredictDigitService(repo) image_data_uri = request.json['image'] prediction = service.handle(image_data_uri) return Response(str(prediction).encode(), status=200)

再一次,我们将遇到有关未解决导入的错误。 Views包依赖于我们还没有的三个文件:

  • 设置
  • 回购
  • 服务

我们将一一实施。

Settings是一个具有配置和常量变量的模块。 它将为我们存储序列化分类器的路径。 它引出了一个合乎逻辑的问题:为什么我需要保存分类器?

因为这是提高应用程序性能的一种简单方法。 我们不会在每次收到请求时都训练分类器,而是存储分类器的准备版本,使其能够开箱即用:

 import os BASE_DIR = os.getcwd() CLASSIFIER_STORAGE = os.path.join(BASE_DIR, 'storage/classifier.txt')

设置机制——获取分类器——将在我们列表中的下一个包Repo中初始化。 这是一个有两种方法的类,可以使用 Python 的内置pickle模块检索和更新经过训练的分类器:

 import pickle class ClassifierRepo: def __init__(self, storage): self.storage = storage def get(self): with open(self.storage, 'wb') as out: try: classifier_str = out.read() if classifier_str != '': return pickle.loads(classifier_str) else: return None except Exception: return None def update(self, classifier): with open(self.storage, 'wb') as in_: pickle.dump(classifier, in_)

我们即将完成我们的 API。 现在它只缺少服务模块。 它的目的是什么?

  • 从存储中获取训练好的分类器
  • 将从 UI 传递的图像转换为分类器可以理解的格式
  • 通过分类器使用格式化图像计算预测
  • 返回预测

让我们编写这个算法:

 from sklearn.datasets import load_digits from classifier import ClassifierFactory from image_processing import process_image class PredictDigitService: def __init__(self, repo): self.repo = repo def handle(self, image_data_uri): classifier = self.repo.get() if classifier is None: digits = load_digits() classifier = ClassifierFactory.create_with_fit( digits.data, digits.target ) self.repo.update(classifier) x = process_image(image_data_uri) if x is None: return 0 prediction = classifier.predict(x)[0] return prediction

在这里您可以看到PredictDigitService有两个依赖项: ClassifierFactoryprocess_image

我们将首先创建一个类来创建和训练我们的模型:

 from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier class ClassifierFactory: @staticmethod def create_with_fit(data, target): model = KNeighborsClassifier(n_neighbors=3) model.fit(data, target) return model

API 已准备好采取行动。 现在我们可以进行图像处理步骤。

图像处理

图像处理是一种对图像执行某些操作以增强图像或从中提取一些有用信息的方法。 在我们的例子中,我们需要将用户绘制的图像平滑地转换为机器学习模型格式。

Image alt:将绘制的图像转换为机器学习格式。

让我们导入一些助手来实现这个目标:

 import numpy as np from skimage import exposure import base64 from PIL import Image, ImageOps, ImageChops from io import BytesIO

我们可以将过渡分为六个不同的部分:

1.用颜色替换透明背景

Image alt:替换示例图像的背景。

 def replace_transparent_background(image): image_arr = np.array(image) if len(image_arr.shape) == 2: return image alpha1 = 0 r2, g2, b2, alpha2 = 255, 255, 255, 255 red, green, blue, alpha = image_arr[:, :, 0], image_arr[:, :, 1], image_arr[:, :, 2], image_arr[:, :, 3] mask = (alpha == alpha1) image_arr[:, :, :4][mask] = [r2, g2, b2, alpha2] return Image.fromarray(image_arr)

2.修剪开放边界

图像:修剪示例图像的边框。

 def trim_borders(image): bg = Image.new(image.mode, image.size, image.getpixel((0,0))) diff = ImageChops.difference(image, bg) diff = ImageChops.add(diff, diff, 2.0, -100) bbox = diff.getbbox() if bbox: return image.crop(bbox) return image

3.添加大小相等的边框

图像:为示例图像添加预设和相同大小的边框。

 def pad_image(image): return ImageOps.expand(image, border=30, fill='#fff')

4.将图像转换为灰度模式

def to_grayscale(image): return image.convert('L')

5.反转颜色

图像:反转样本图像的颜色。

 def invert_colors(image): return ImageOps.invert(image)

6.调整图片大小为8x8格式

图像:将示例图像调整为 8x8 格式。

 def resize_image(image): return image.resize((8, 8), Image.LINEAR)

现在您可以测试应用程序了。 运行应用程序并输入以下命令,将带有此 iStock 图像的请求发送到 API:

图像: 手绘数字八的股票图像。

 export FLASK_APP=app flask run
 curl "http://localhost:5000/api/predict" -X "POST" -H "Content-Type: application/json" -d "{\"image\": \"data:image/png;base64,$(curl "https://media.istockphoto.com/vectors/number-eight-8-hand-drawn-with-dry-brush-vector-id484207302?k=6&m=484207302&s=170667a&w=0&h=s3YANDyuLS8u2so-uJbMA2uW6fYyyRkabc1a6OTq7iI=" | base64)\"}" -i

您应该看到以下输出:

 HTTP/1.1 100 Continue HTTP/1.0 200 OK Content-Type: text/html; charset=utf-8 Content-Length: 1 Server: Werkzeug/0.14.1 Python/3.6.3 Date: Tue, 27 Mar 2018 07:02:08 GMT 8

示例图像描绘了数字 8,我们的应用程序正确识别了它。

通过 React 创建绘图窗格

为了快速引导前端应用程序,我们将使用 CRA 样板:

 create-react-app frontend cd frontend

设置工作场所后,我们还需要一个依赖项来绘制数字。 react-sketch 包完全符合我们的需求:

 npm i react-sketch

该应用程序只有一个组件。 我们可以将这个组件分为两部分:逻辑和视图

视图部分负责表示绘图窗格、提交重置按钮。 在交互时,我们还应该表示预测或错误。 从逻辑上看,它有以下职责:提交图片清除草图

每当用户单击Submit时,该组件将从草图组件中提取图像并调用 API 模块的makePrediction函数。 如果对后端的请求成功,我们将设置预测状态变量。 否则,我们将更新错误状态。

当用户单击Reset时,草图将清除:

 import React, { useRef, useState } from "react"; import { makePrediction } from "./api"; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return null }

逻辑就足够了。 现在我们可以给它添加可视化界面:

 import React, { useRef, useState } from "react"; import { SketchField, Tools } from "react-sketch"; import { makePrediction } from "./api"; import logo from "./logo.svg"; import "./App.css"; const pixels = (count) => `${count}px`; const percents = (count) => `${count}%`; const MAIN_CONTAINER_WIDTH_PX = 200; const MAIN_CONTAINER_HEIGHT = 100; const MAIN_CONTAINER_STYLE = { width: pixels(MAIN_CONTAINER_WIDTH_PX), height: percents(MAIN_CONTAINER_HEIGHT), margin: "0 auto", }; const SKETCH_CONTAINER_STYLE = { border: "1px solid black", width: pixels(MAIN_CONTAINER_WIDTH_PX - 2), height: pixels(MAIN_CONTAINER_WIDTH_PX - 2), backgroundColor: "white", }; const App = () => { const sketchRef = useRef(null); const [error, setError] = useState(); const [prediction, setPrediction] = useState(); const handleSubmit = () => { const image = sketchRef.current.toDataURL(); setPrediction(undefined); setError(undefined); makePrediction(image).then(setPrediction).catch(setError); }; const handleClear = (e) => sketchRef.current.clear(); return ( <div className="App" style={MAIN_CONTAINER_STYLE}> <div> <header className="App-header"> <img src={logo} className="App-logo" alt="logo" /> <h1 className="App-title">Draw a digit</h1> </header> <div style={SKETCH_CONTAINER_STYLE}> <SketchField ref={sketchRef} width="100%" height="100%" tool={Tools.Pencil} imageFormat="jpg" lineColor="#111" lineWidth={10} /> </div> {prediction && <h3>Predicted value is: {prediction}</h3>} <button onClick={handleClear}>Clear</button> <button onClick={handleSubmit}>Guess the number</button> {error && <p style={{ color: "red" }}>Something went wrong</p>} </div> </div> ); }; export default App;

组件已准备就绪,通过执行并转到localhost:3000对其进行测试:

 npm run start

演示应用程序可在此处获得。 您也可以在 GitHub 上浏览源代码。

包起来

这个分类器的质量并不完美,我并不假装它是完美的。 我们用于训练的数据与来自 UI 的数据之间的差异是巨大的。 尽管如此,我们还是在不到 30 分钟的时间内从头开始创建了一个工作应用程序。

图片:动画显示最终确定的应用程序识别手写数字。

在此过程中,我们在四个领域磨练了我们的技能:

  • 机器学习
  • 后端开发
  • 图像处理
  • 前端开发

能够识别手写数字的软件不乏潜在的用例,从教育和管理软件到邮政和金融服务。

因此,我希望这篇文章能激励你提高机器学习能力、图像处理能力以及前后端开发能力,并利用这些技能设计出精彩实用的应用程序。

如果您想拓宽机器学习和图像处理方面的知识,您可能需要查看我们的对抗性机器学习教程。