暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

【Python】纯Python实现CNN识别手写体数字+GUI展示 MNIST数据集【附源码】

修电脑的杂货店 2021-07-07
2188

开发环境:Python3 + Numpy + PyQt5


项目源码文件夹:



在项目中,mnist_cnn_gui_main.py是主程序,运行这个文件即可执行整个项目。



废话不多说,我们首先来看一下这个项目究竟实现了哪些功能吧!


执行程序,我们首先看到的是一个用PyQt5开发的UI界面,我们可以选择两种模式,首先可以随机选择数据集中的灰度图片,其次我们可以使用自己用鼠标生成的灰度图来进行识别,这种方式会让我们的成果更加有趣新鲜。


下面给大家看一下项目演示视频(部分翻车!但总体来说还是有那味的。)


(演示视频)


手写识别的应用:


字符识别处理的信息可分为两大类:一类是文字信息, 处理的主要是用各国家,各民族的文字如:汉字,英文等书写或印刷的文本信息,目前在印刷体和联机手写方面技术已趋向成熟,并推出了很多应用系统;另一类是数据信息,主要是由阿拉伯数字及少量特殊符号组成的各种编号和统计数据,如:邮政编码,统计报表,财务报表,银行票据等等,处理这类信息的核心技术是手写数字识别。因此,手写数字的识别研究有着重大的现实意义。


手写数字识别研究的难点在于:


第一,不同数字之间字形相差不大,使得准确区分某些数字相当困难;

第二,数字虽然只有十种,而且笔划简单,但同一数字写法千差万别。使得手写数字识别的识别率和识别精度很低。


本设计主要分为三大步骤:

第一阶段,预处理;

第二阶段,特征提取;

第三阶段,分类器设计及识别。


第一阶段预处理阶段主要包括定位、二值化、去噪、切分、大小规格化、 细化等步骤,这里将采用一些成熟的算法。

第二阶段手写数字特征的提取结构化特征时主要根据图像像素的走向 ,准确判断出某段数字或字母的结构,如直线、折线、曲线、分叉线等。同时配合中线 特征等建立起较为准确的特征库。

第三阶段分类器设计及识别时将采用 BP神经网络算法设计分类器,通过。

这些算法本身的高容错率和算法本身的模糊判断等特性,再结合之前建立起的准确的特征库,从而提高手写数字识别时的正确率,达到理想的识别效果。


数据集的来源:


MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。



数据集下载网址:

http://yann.lecun.com/exdb/mnist/


算法原理:


k-近邻(kNN, k-NearestNeighbor)算法是一种基本分类与回归方法,
通俗点来说,就是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的 k 个实例,这 k 个实例的多数属于某个类,就把该输入实例分为这个类。

python 第三方库scikit-learn(sklearn)提供了knn的分类器。

MNIST手写数字数据库(Mixed National Institute of Standards and Technology database)包含70000张手写数字图片。这些数字是通过美国国家统计局的员工和美国高校的学生收集的。每张图片都是28x28的灰度图。


用mnist数据集训练出一个knn分类器,对新输入的手写数字进行识别。


注意:在安装PIL库时需要特别注意,PIL库只支持Python2,在Python3版本中只能安装Pillow库,但是调用的时候依然是:

from PIL import Image, ImageQt


源码解析:


数据的预处理:为了分析我们的数据并从中提取见解,有必要在开始建立机器学习模型之前对数据进行处理,即我们需要以模型可以理解的形式转换数据。由于机器无法理解图像,音频等形式的数据。在本次实验中,我们使用了官网上已经分析好的数据集,不需要重新进行数据预处理即可进行试验。但是,还是需要包装一些调用数据集的函数。


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""读入MNIST数据集

Parameters
----------
normalize : 将图像的像素值正规化为0.0~1.0
one_hot_label :
one_hot_label为True的情况下,标签作为one-hot数组返回
one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
flatten : 是否将图像展开为一维数组

Returns
-------
(训练图像, 训练标签), (测试图像, 测试标签)
"""
if not os.path.exists(save_file):
init_mnist()

with open(save_file, 'rb') as f:
dataset = pickle.load(f)

if normalize:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0

if one_hot_label:
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)


return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


算法的选择:使用神经网络算法将数据集进行分类,采集不同图片和对应的标签,分析图片和标签的相关性,后续可根据图片解析出图片属于哪一类标签,即可得到识别出的数字(算法只是算出图片和每个数字的相似度,最后是反馈回一个相似度最大的标签)。


    def numerical_gradient(self, x, t):
"""求梯度(数值微分)


Parameters
----------
x : 输入数据
t : 教师标签


Returns
-------
具有各层的梯度的字典变量
grads['W1']、grads['W2']、...是各层的权重
grads['b1']、grads['b2']、...是各层的偏置
"""
loss_w = lambda w: self.loss(x, t)


grads = {}
for idx in (1, 2, 3):
grads['W' + str(idx)] = numerical_gradient(loss_w, self.params['W' + str(idx)])
grads['b' + str(idx)] = numerical_gradient(loss_w, self.params['b' + str(idx)])


return grads


def gradient(self, x, t):
"""求梯度(误差反向传播法)


Parameters
----------
x : 输入数据
t : 教师标签


Returns
-------
具有各层的梯度的字典变量
grads['W1']、grads['W2']、...是各层的权重
grads['b1']、grads['b2']、...是各层的偏置
"""
# forward
self.loss(x, t)


# backward
dout = 1
dout = self.last_layer.backward(dout)


layers = list(self.layers.values())
layers.reverse()
for layer in layers:
dout = layer.backward(dout)


# 设定
grads = {}
grads['W1'], grads['b1'] = self.layers['Conv1'].dW, self.layers['Conv1'].db
grads['W2'], grads['b2'] = self.layers['Affine1'].dW, self.layers['Affine1'].db
grads['W3'], grads['b3'] = self.layers['Affine2'].dW, self.layers['Affine2'].db


        return grads


对于一些数字,由于数据集的局限性,识别准确度不高,但是这个项目可以把我们领进深度学习的大门,一起探索更神奇的计算机世界!


纸上得来终觉浅,绝知此事要coding...


完整源码请在后台回复“代码”获取!

(部分代码上传不及时需要进入QQ群获取,群文件自取)



对这篇容对你有帮助,或者对本公众号内容有兴趣的同学可以加入官方QQ群详细交流探讨,互相学习共同进步,源码和具体操作流程,也会放到群里,如果有不懂得细节,群里也会有人回答。快加入我们的大家庭QQ群号:559369389  欢迎新成员的到来!


交流分享


官微君是一个立志于

实现电脑代替自己所有工作

而自己能躺在一边数钱的小人物

喜欢硬件编程

项目开发和各种有趣的想法

不管你是有梦想的孩子,还是算法大牛

君君都希望能和你共同进步




从今天起,

小编将在微信后台

以及评论区回复哦~ 

欢迎咨询问题!

小编定当知无不言

言无不尽!


代码:https://gitee.com/dongeast/a-computer-shop/issues


文章转载自修电脑的杂货店,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论