暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

机器学习 | 使用Scikit-Learn识别手写数字

数艺学苑 2020-10-13
2722

手写数字识别是一个十分具有挑战性的问题。在实际生活中,对手写文本或数字进行识别分类是一个很重要的功能,数字识别的应用包括邮政分拣、银行支票处理、表单数据输入等领域。

走进机器学习

尽管深度学习算法在识别手写数字这一任务上使用CNN(卷积神经网络)已经取得了突破性的进展,但是我们依然需要学习掌握sklearn这个工具的分类算法,特别是针对这种非结构化的数据。本文旨在通过一个简单的案例来演示如何使用Scikit-Learn识别手写数字,抛砖引玉来引导读者们进入机器学习的知识殿堂。

sklearn简介

Sklearn (全称 Scikit-Learn) 是基于 Python 语言的机器学习工具。它建立在 NumPy, SciPy, Pandas 和 Matplotlib 之上,里面的 API 的设计非常好,所有对象的接口简单,很适合新手上路。

在 Sklearn 里面有六大任务模块:分别是分类、回归、聚类、降维、模型选择和预处理

分类:识别某个对象属于哪个类别,常用的算法有:SVM(支持向量机)、nearest neighbors(最近邻)、random forest(随机森林),常见的应用有:垃圾邮件识别、图像识别。

回归:预测与对象相关联的连续值属性,常见的算法有:SVR(支持向量机)、 ridge regression(岭回归)、Lasso,常见的应用有:药物反应,预测股价。

聚类:将相似对象自动分组,常用的算法有:k-Means、spectral clustering、mean-shift,常见的应用有:客户细分,分组实验结果。

降维:减少要考虑的随机变量的数量,常见的算法有:PCA(主成分分析)、feature selection(特征选择)、non-negative matrix factorization(非负矩阵分解),常见的应用有:可视化,提高效率。

模型选择:比较,验证,选择参数和模型,常用的模块有:grid search(网格搜索)、cross validation(交叉验证)、 metrics(度量)。它的目标是通过参数调整提高精度。

预处理:特征提取和归一化,常用的模块有:preprocessing,feature extraction,常见的应用有:把输入数据(如文本)转换为机器学习算法可用的数据。

安装SKlearn

Scikit-learn需要:

  • Python(> = 2.7或> = 3.4),

  • NumPy(> = 1.8.2),

  • SciPy(> = 0.13.3)。

【注意】Scikit-learn 0.20是支持Python 2.7和Python 3.4的最后一个版本。Scikit-learn 0.21将需要Python 3.5或更高版本。

如果你已经安装了numpy和scipy,那么安装scikit-learn的最简单方法就是使用 pip
或者conda

1pip install -U scikit-learn
2conda install scikit-learn

如果你尚未安装NumPy或SciPy,你也可以使用conda或pip安装它们。使用pip时,请确保使用binary wheels,并且不会从源头重新编译NumPy和SciPy,这可能在使用特定配置的操作系统和硬件(例如Raspberry Pi上的Linux)时发生。从源代码构建numpy和scipy可能很复杂(特别是在Windows上),需要仔细配置以确保它们与线性代数例程的优化实现相关联。为了方便,我们可以使用如下所述的第三方发行版本。

数据集介绍

Digits数据集包含1797张8x8像素的图像,每个图像都是一个灰色的手写数字。

获取方法

1from sklearn import datasets
2digits = datasets.load_digits()

算法选择

sklearn 实现了很多算法,面对这么多的算法,如何去选择呢?其实选择的主要考虑的就是需要解决的问题以及数据量的大小。sklearn官方提供了一个选择算法的引导图。这里提供翻译好的中文版本,供大家参考:

在本文中,我们将使用Scikit-Learn的数字数据集正确识别手写的单个数字(即0-9),这是一个Python库,其中包含许多有用的算法,我们将在这些算法上进行修改与实现,并使用Logistic回归 的分类器来完成手写数字的识别,并预测一些未知的手写数字的值。

Scikit-Learn库的四步建模模式:

  1. 导入所需要的使用的模型。

  2. 制作模型实例。

  3. 依据数据训练模型并存储从数据中所学习到的信息。

  4. 使用训练过程中学习到的模型信息来预测识别新数据的结果。

准备工作

首先为模型导入必要的库并加载数字数据集,导入scikit-learn库的svm模块。我们可以创建一个SVC类型的估算器,然后选择一个初始设置,分配值C和gamma通用值,这些值可以在后续的分析过程中进行对应的调整。

1from sklearn import svm
2svc = svm.SVC(gamma=0.001, C=100)
3
4import numpy as np
5import pandas as pd
6import matplotlib.pyplot as plt
7from sklearn import datasets
8digits = datasets.load_digits()
9# 导入库并加载数据集

数据展示

手写数字的图像包含在digits.images数组中。此数组的每个元素都是一副图像,该图像由一个8x8数值矩阵标识,矩阵中的数值代表从0(白色)到15(黑色)的灰度。

1digits.images[0]
2# 矩阵数组


我们的数据集以数字存储。通过以下命令可以获取数字的灰度图像。

1%matplotlib inline
2plt.imshow(digits.images[0],cmap=plt.cm.gray_r,interpolation='nearest')


由图像(即目标)表示的数值包含在digit.targets数组中。数据集也是由1797张图像组成的训练集,可以通过以下命令查看。

1digits.target
2digits.target.size


数据分组

该数据集包含1797个元素,我们选择将前1791个数据视为训练集,并使用后6个数据作为验证集来进行验证。我们通过使用matplotlib库可以详细看到这六位数字。

 1import matplotlib.pyplot as plt
2%matplotlib inline
3plt.subplot(321)
4plt.imshow(digits.images[1791],cmap=plt.cm.gray_r,interpolation='nearest')
5plt.subplot(322)
6plt.imshow(digits.images[1792],cmap=plt.cm.gray_r,interpolation='nearest')
7plt.subplot(323)
8plt.imshow(digits.images[1793],cmap=plt.cm.gray_r,interpolation='nearest')
9plt.subplot(324)
10plt.imshow(digits.images[1794],cmap=plt.cm.gray_r,interpolation='nearest')
11plt.subplot(325)
12plt.imshow(digits.images[1795],cmap=plt.cm.gray_r,interpolation='nearest')
13plt.subplot(326)
14plt.imshow(digits.images[1796],cmap=plt.cm.gray_r,interpolation='nearest')


SVC估算

接下来开始训练我们先前定义的SVC估算器。

1svc.fit(digits.data[1:1790],digits.target[1:1790])
2# 拟合模型

然后测试估计器,使其识别验证数据集的后六位数字。

1svc.predict(digits.data[1791:1796])
2# 预测模型


如我们所见,svc估算器已完成学习并可识别大部分手写数字,并正确预估了六位数字中的五位。

使用Logistic算法训练模型

logistic算法简介

Logistic回归又称logistic回归分析,是一种广义的线性回归分析模型,虽然名字中有“回归”二字,但实际却是一种分类学习方法。对于回归这个概念,简单的说,回归就是用一条线对N个数据点进行一个拟合,这个拟合的过程就叫做回归。Logistic回归分类算法就是对数据集建立回归公式,以此进行分类。

Logistic回归优点:

  1、实现简单;

  2、分类时计算量非常小,速度很快,存储资源低;

缺点:

  1、容易欠拟合,一般准确度不太高

  2、只能处理两分类问题(在此基础上衍生出来的softmax可以用于多分类),但必须线性可分。

接下来让我们来看看Scikit-Learn四步建模模式。

我们需要将数据集分为训练集和测试集,以确保在训练模型后能够较好的预测识别新数据。

1from sklearn.model_selection import train_test_split
2x_train,x_test,y_train,y_test = train_test_split(digits.data,digits.target,test_size=0.25,random_state=0)
3# 将数据集分为训练和测试集

  1. 步骤一:导入我们所需要使用的模型

     from sklearn.linear_model import LogisticRegression
     # 使用Logistic回归导入
  2. 步骤二:制作模型的实例

     logisticRegr = LogisticRegression()
     # 制作模型实例
  3. 步骤三:训练模型

     logisticRegr.fit(x_train,y_train)
     # 训练模型
  4. 步骤四:验证预测的新数据并评估该模型的性能

     predictions = logisticRegr.predict(x_test)
     score = logisticRegr.score(x_test,y_test)
     print(score)

评估准确性

什么是混淆矩阵?混淆矩阵是一个误差矩阵,通常我们可以通过混淆矩阵来评定监督学习算法的性能。在监督学习中混淆矩阵为方阵,方阵的大小通常为一个(真实值,预测值)或者(预测值,真实值),所以通过混淆矩阵我们更清晰的看出,预测集与真实集中混合的一部分。

混淆矩阵可以清晰的反映出真实值与预测值相互吻合的部分,也可以反映出与预测值不吻合的部分,如下图所示。


实现预测值与真实值在相同特征下的比较,如果同时成立则放入相对应的矩阵位置,如果不成立则放入不相匹配的矩阵位置,
将真实值与预测值相互匹配与不匹配项放入矩阵中,我们称这个矩阵为混淆矩阵。

使用sklearn绘制混淆矩阵

 1import matplotlib.pyplot as plt
2import seaborn as sns
3from sklearn import metrics
4cm = metrics.confusion_matrix(y_test,predictions)
5plt.figure(figsize=(9,9))
6sns.heatmap(cm,annot=True,fmt=".3f",linewidths=.5,square=True,cmap='RdPu',linecolor="pink")
7plt.ylabel('Actual label')
8plt.xlabel('Predicted label')
9all_sample_title='Accuracy score:{0}'.format(score)
10plt.title(all_sample_title,size=15)


矩阵的第九行数据表示 有关数字8的数据有48个,有3个被误判为1,1个被误判为2,1个被误判为9。从整体观察混淆矩阵,效果还是不错的。

从本文中,我们可以学习如何轻松导入数据集、使用Scikit-Learn构建模型、训练模型、使用模型进行预测以及找到我们识别数字的准确性。

在本案例中,我们可以清楚的看到该模型在95%的情况下识别准确度为100%,因此,该模型在95%的情况下都可以正常进行运作。

本文作者




指导老师





长按二维码关注我们



欢迎关注微信公众号

“沈浩老师“



原文链接:

https://medium.com/@navyashree.raghupatro/recognizing-handwritten-digits-with-scikit-learn-8d248dc01b6d

文章转载自数艺学苑,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论