热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于TensorFlow的Keras高级API实现手写体数字识别

前言这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为

前言

这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为大家提供代码哈哈哈哈


一、Keras?


1.Keras简介

Keras是由纯python编写的基于theano/tensorflow的深度学习框架。 Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结果,如果有如下需求,可以优先选择Keras。


2.为什么

目前Keras已经被TensorFlow收录,添加到TensorFlow 中,成为其默认的框架,成为TensorFlow官方的高级API。Keras简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性),用户友好:Keras是为人类而不是天顶星人设计的API。用户的使用体验始终是我们考虑的首要和中心内容。Keras遵循减少认知困难的最佳实践:Keras提供一致而简洁的API, 能够极大减少一般应用下用户的工作量,同时,Keras提供清晰和具有实践意义的bug反馈。


二、全连接神经网络实现


1.思路

导入数据-------> 选择模型------>设计神经网络------->编译------->训练权重参数------->预测


2.实现代码

定义函数 train() 实现(导入数据———>训练权重参数)。
定义函数 text() 实现 预测及输出结果。

导入数据:mnist = tf.keras.datasets.mnist #导入mnist
选择模型:model = tf.keras.models.Sequential()
有两种类型的模型,序贯模型(Sequential)和函数式模型(Model),函数式模型应用更为广泛,序贯模型是函数式模型的一种特殊情况。
序贯模型(Sequential) :单输入单输出,一条路通到底,层与层之间只有相邻关系,没有跨层连接。这种模型编译速度快,操作也比较简单;

设计神经网络:

tf.keras.layers.Flatten(input_shape=(28,28)),tf.keras.layers.Dense(512,activation='relu'),tf.keras.layers.Dense(128,activation='relu'),tf.keras.layers.Dense(10,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())

编译:

model.compile(optimizer = 优化器,loss = 损失函数,metrics = ["准确率”]')

训练权重参数:

history = model.fit(x_train,y_train,batch_size=每次训练图片数量,epochs=训练次数,
validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()

train函数全部代码

def train():mnist = tf.keras.datasets.mnist #导入mnist(x_train,y_train),(x_test,y_test) = mnist.load_data() #分割x_train,x_test =x_train/255.0, x_test/255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28,28)),tf.keras.layers.Dense(512,activation='relu'),tf.keras.layers.Dense(128,activation='relu'),tf.keras.layers.Dense(10,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())])model.compile(optimizer= 'adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])#评价指标 categorical_accuracy和 sparse_categorical_accuracy#注意修改路径checkpoint_save_path="C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt"if os.path.exists(checkpoint_save_path + '.index'):print('------load the model--------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)#断点续训history = model.fit(x_train,y_train,batch_size=25,epochs=30,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])model.summary()#以下为打印训练准确率及损失率等acc = history.history['sparse_categorical_accuracy']val_acc = history.history['val_sparse_categorical_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']f = Figure(figsize=(6,6),dpi=60)a = f.add_subplot(1,2,1)a.plot(acc,label = 'Training Accuracy')a.plot(val_acc,label = 'Validation Accuracy')#验证精度a.legend() b = f.add_subplot(1,2,2)b.plot(loss,label = 'Training Loss')b.plot(val_loss,label = 'Validation Loss')b.legend() canvas = FigureCanvasTkAgg(f,master=root)canvas.draw()canvas.get_tk_widget().place(x=60,y=100)

test函数全部代码

#预测结果打印
def text():#注意修改路径与函数train上面保存的路径一致model_save_path &#61; "C:/Users/VULCAN/sxti/TEST/Disconnect_detection/mnist.ckpt"model &#61; tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape&#61;(28,28)),tf.keras.layers.Dense(512,activation&#61;&#39;relu&#39;),tf.keras.layers.Dense(128,activation&#61;&#39;relu&#39;),tf.keras.layers.Dense(10,activation&#61;&#39;softmax&#39;,kernel_regularizer&#61;tf.keras.regularizers.l2())])model.load_weights(model_save_path)for i in range(1):img &#61; Image.open("tem2.png")#强制压缩为28&#xff0c;28img &#61; img.resize((28,28),Image.ANTIALIAS)#将原有图像转换为灰度图img_arr &#61; np.array(img.convert("L"))#图片反相for i in range(28):for j in range(28):if img_arr[i][j]<100:img_arr[i][j]&#61;255else:img_arr[i][j]&#61; 0img_arr &#61; img_arr/255.0x_predict &#61; img_arr[tf.newaxis,...]result &#61; model.predict(x_predict)pred &#61; np.argmax(result , axis &#61; 1)#在GUI界面显示结果e4 &#61; l &#61; tk.Label(root,text &#61; pred, bg&#61;"white",font&#61;("Arial,12"),width&#61;8)e4.place(x&#61;990,y&#61;440)

三、GUI设计

这部分我直接附上代码并在代码中作必要的注释。

全部所需的库函数&#xff1a;

#使用Tkinter前需要先导入
import tkinter as tk
#导入对话框模块
import tkinter.filedialog
#创建画布需要的库
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
#创建工具栏需要的库
from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk
#快捷键需要的模块2
from matplotlib.backend_bases import key_press_handler
#导入绘图需要的模块
from matplotlib.figure import Figure
import cv2
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image,ImageTk

其他关于图片文件的导入及摄像头调用的函数定义代码&#xff1a;

#调取摄像头并拍摄图片
def buttonl():capture &#61; cv2.VideoCapture(0) #cv2模块调取摄像头while(capture.isOpened()):ret,frame &#61; capture.read() #ret表示捕获是否成功frame &#61; frame[:,80:560] #拍照默认为640*480cv2.imwrite("tem1.png",frame)dig_Gray &#61; cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)ref2,dig_Gray &#61; cv2.threshold(dig_Gray,100,255,cv2.THRESH_BINARY)cv2.imwrite("tem2.png",dig_Gray)breakglobal photo1,photo2#将图片显示到界面上img1 &#61; Image.open("tem1.png")img1 &#61; img1.resize((128,128))photo1 &#61; ImageTk.PhotoImage(img1)l1 &#61; tk.Label(root,bg&#61;"red",image &#61; photo1).place(x&#61;950,y&#61;100)img2 &#61; Image.open("tem2.png")img2 &#61; img2.resize((128,128))photo2 &#61; ImageTk.PhotoImage(img2)l2 &#61; tk.Label(root,bg&#61;"red",image &#61; photo2).place(x&#61;950,y&#61;250)#保存当前摄像头画面
def frame():capture &#61; cv2.VideoCapture(0)#控件定义while(capture.isOpened()):ref,frame &#61; capture.read()frame &#61; frame[:,80:560]cvimage &#61; cv2.cvtColor(frame,cv2.COLOR_BGR2RGBA)pilImage &#61; Image.fromarray(cvimage)pilImage &#61; pilImage.resize((360,360),Image.ANTIALIAS)tkImage &#61; ImageTk.PhotoImage(image &#61; pilImage)canvas.create_image(0,0,anchor &#61; "nw",image &#61; tkImage)root.update()root.after(10)
#选择文件
def select_pic():file_path &#61; tk.filedialog.askopenfilename(title&#61;"选择文件",initialdir &#61; (os.path.expanduser(r"")))image &#61; Image.open(file_path)image.save("tem1.png")gray &#61; image.convert("L")gray.save("tem2.png")global photo3,photo4
#将图片显示在界面上img3 &#61; Image.open("tem1.png")img3 &#61; image.resize((128,128))photo3 &#61; ImageTk.PhotoImage(img3)l3 &#61; tk.Label(root,bg&#61;"red",image &#61; photo3).place(x&#61;950,y&#61;100)img4 &#61; Image.open("tem2.png")img4 &#61; img4.resize((128,128))photo4 &#61; ImageTk.PhotoImage(img4)l4 &#61; tk.Label(root,bg&#61;"red",image &#61; photo4).place(x&#61;950,y&#61;250)

主函数部分&#xff1a;

if __name__ &#61;&#61;&#39;__main__&#39;:root &#61; tk.Tk()#第二步&#xff0c;给窗口的可视化起名字root.title(&#39;手写体数字识别&#39;)#第三步&#xff0c;设定窗口的大小&#xff08;长*宽&#xff09;root.geometry(&#39;1176x520&#39;) #这里的乘是小xroot.configure(bg &#61; "#C0C0C0")f &#61; Figure(figsize&#61;(6,6), dpi&#61;60)a&#61;f.add_subplot(1,2,1) #添加子图&#xff1a;1行1列第一个a.plot(0,0)b&#61;f.add_subplot(1,2,2) #添加子图&#xff0c;1行1列第二个b.plot(0,0)#将绘制的图形显示到tkinter&#xff1a;创建属于root的canvas画布&#xff0c;并将图f置于画布上 canvas&#61;FigureCanvasTkAgg(f,master&#61;root)canvas.draw()#注意show方法已经过时&#xff0c;改用drawcanvas.get_tk_widget().place(x&#61;60,y&#61;100)b1 &#61; tk.Button(root,text&#61;&#39;训练&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;train).place(x&#61;168,y&#61;35)b2 &#61; tk.Button(root,text&#61;&#39;拍照&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;frame).place(x&#61;550,y&#61;35)b3 &#61; tk.Button(root,text&#61;&#39;测试&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;text).place(x&#61;960,y&#61;35)b4 &#61; tk.Button(root,text&#61;&#39;导入图片&#39;,bg&#61;&#39;white&#39;,font&#61;(&#39;Arial&#39;,12),width&#61;12,height&#61;1,command&#61;select_pic).place(x&#61;680,y&#61;35)b5 &#61; tk.Button(root,text&#61;&#39;识别结果&#39;,font&#61;(&#39;Arial&#39;,12),bg&#61;&#39;white&#39;).place(x&#61;990,y&#61;400)canvas&#61;tk.Canvas(root,bg&#61;"white",width&#61;360,height&#61;360) #绘制画布#控件位置设置canvas.place(x&#61;500,y&#61;100)b6&#61;tk.Button(root,text&#61;"保存",bg&#61;"white",width&#61;15,height&#61;2,command&#61;buttonl).place(x&#61;620,y&#61;420)#第六步&#xff0c;主窗口循环显示root.mainloop()

最后附上界面

在这里插入图片描述


推荐阅读
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文详细介绍了Spring的JdbcTemplate的使用方法,包括执行存储过程、存储函数的call()方法,执行任何SQL语句的execute()方法,单个更新和批量更新的update()和batchUpdate()方法,以及单查和列表查询的query()和queryForXXX()方法。提供了经过测试的API供使用。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • springmvc学习笔记(十):控制器业务方法中通过注解实现封装Javabean接收表单提交的数据
    本文介绍了在springmvc学习笔记系列的第十篇中,控制器的业务方法中如何通过注解实现封装Javabean来接收表单提交的数据。同时还讨论了当有多个注册表单且字段完全相同时,如何将其交给同一个控制器处理。 ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
author-avatar
徐涛
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有