热门标签 | HotTags
当前位置:  开发笔记 > 前端 > 正文

tensorflowLSTM+CTC实现端到端OCR

最近在做OCR相关的东西,关于OCR真的是有悠久了历史了,最开始用tesseract然而效果总是不理想,其中字符分割真的是个博大精深的问题,那么多年那么多算法,然而应用到实际总是有诸多问题。比如说非等

最近在做OCR相关的东西,关于OCR真的是有悠久了历史了,最开始用tesseract然而效果总是不理想,其中字符分割真的是个博大精深的问题,那么多年那么多算法,然而应用到实际总是有诸多问题。比如说非等间距字体的分割,汉字的分割,有光照阴影的图片的字体分割等等,针对特定的问题,特定的算法能有不错的效果,但也仅限于特定问题,很难有一些通用的结果。于是看了Xlvector的博客之后,发现可以端到端来实现OCR,他是基于mxnet的,于是我想把它转到tensorflow这个框架来,顺便还能熟悉一下这个框架。本文主要介绍实现思路,更加细节的实现方法见另一篇。

正文

生成数据

利用captcha来生成验证码,具体生成验证码的代码请见这里,共生成4-6位包含数字和英文大小写的训练图片128000张和测试图片400张。命名规则就是num_label.png,生成的图片如下图
code = image_name.split('/')[2].split('_')[1].split('.')[0] code = [SPACE_INDEX if code == SPACE_TOKEN else maps[c] for c in list(code)] self.labels.append(code) print(image_name,' ',code) @property def size(self): return len(self.labels) def input_index_generate_batch(self,index=None): if index: image_batch=[self.image[i] for i in index] label_batch=[self.labels[i] for i in index] else: # get the whole data as input image_batch=self.image label_batch=self.labels def get_input_lens(sequences): lengths = np.asarray([len(s) for s in sequences], dtype=np.int64) return sequences,lengths batch_inputs,batch_seq_len = get_input_lens(np.array(image_batch)) batch_labels = sparse_tuple_from_label(label_batch) return batch_inputs,batch_seq_len,batch_labels

需要注意的是tensorflow lstm输入格式的问题,其label tensor应该是稀疏矩阵,所以读取图片和label之后,还要进行一些处理,具体可以看代码
关于载入图片,发现12.8w张图一次读进内存,内存也就涨了5G,如果训练数据加大,还是加一个pipeline来读比较好。

网络结构

然后是网络结构

1234567891011121314151617181920212223242526272829303132333435363738
graph = tf.Graph()with graph.as_default():    inputs = tf.placeholder(tf.float32, [None, None, num_features])    labels = tf.sparse_placeholder(tf.int32)    seq_len = tf.placeholder(tf.int32, [None])    # Stacking rnn cells    stack = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(FLAGS.num_hidden,state_is_tuple=True) for i in range(FLAGS.num_layers)] , state_is_tuple=True)    # The second output is the last state and we will no use that    outputs, _ = tf.nn.dynamic_rnn(stack, inputs, seq_len, dtype=tf.float32)    shape = tf.shape(inputs)    batch_s, max_timesteps = shape[0], shape[1]    # Reshaping to apply the same weights over the timesteps    outputs = tf.reshape(outputs, [-1, FLAGS.num_hidden])    # Truncated normal with mean 0 and stdev=0.1    W = tf.Variable(tf.truncated_normal([FLAGS.num_hidden,                                         num_classes],                                        stddev=0.1),name='W')    b = tf.Variable(tf.constant(0., shape=[num_classes],name='b'))    # Doing the affine projection    logits = tf.matmul(outputs, W) + b    # Reshaping back to the original shape    logits = tf.reshape(logits, [batch_s, -1, num_classes])    # Time major    logits = tf.transpose(logits, (1, 0, 2))    global_step = tf.Variable(0,trainable=False)    loss = tf.nn.ctc_loss(labels=labels,inputs=logits, sequence_length=seq_len)    cost = tf.reduce_mean(loss)    #optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,    # momentum=FLAGS.momentum).minimize(cost,global_step=global_step)    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.initial_learning_rate,            beta1=FLAGS.beta1,beta2=FLAGS.beta2).minimize(loss,global_step=global_step)    # Option 2: tf.contrib.ctc.ctc_beam_search_decoder    # (it's slower but you'll get better results)    #decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len,merge_repeated=False)    decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len,merge_repeated=False)    # Inaccuracy: label error rate    lerr = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))

这里我参考了stackoverflow的一篇帖子写的,根据tensorflow 1.0.1的版本做了微调,使用了Adam作为optimizer。
需要注意的是ctc_beam_search_decoder是非常耗时的,见下图
ctc_beam_search_decoder
和greedy_decoder的区别是,greedy_decoder根据当前序列预测下一个字符,并且取概率最高的作为结果,再此基础上再进行下一次预测。而beam_search_decoder每次会保存取k个概率最高的结果,以此为基础再进行预测,并将下一个字符出现的概率与当前k个出现的概率相乘,这样就可以减缓贪心造成的丢失好解的情况,当k=1的时候,二者就一样了。

结果

—update—
稍微调一调,网络可以跑到85%以上。
把网络用在识别身份证号,试了73张网上爬的(不同分辨率下的)真实图片,错了一张,准确率在98%左右(不过毕竟身份证号比较简单)

大概14个epoch后,准确率过了50%,现在跑到了73%的正确率。
accuracy
最后,代码托管在Github上。

后记

百度出了一个warpCTC可以加速CTC的计算,试用了一下CPU的版本发现好像没什么速度的提升,不知道是不是姿势不对,回头再试试GPU的版本。
对于更加细节的实现方法(输入输出的构造,以及warpCTC和内置ctc_loss的异同)放在了另一篇博客。

  • warpCTC的GPU版本试过之后发现速度差不多,但是能极大的减少CPU的占用
  • 对于不同的优化器,数据,同样的参数是不能普适的。往往之前的参数可以收敛,换个optimizer,数据,网络就不能收敛了。这个时候要微调参数。对于不同的优化器之间区别,文末有一篇神文可以参考

如果有发现问题,请前辈们一定要不吝赐教,在下方留言指出,或者在github上提出issue


推荐阅读
  • 本文介绍了Java工具类库Hutool,该工具包封装了对文件、流、加密解密、转码、正则、线程、XML等JDK方法的封装,并提供了各种Util工具类。同时,还介绍了Hutool的组件,包括动态代理、布隆过滤、缓存、定时任务等功能。该工具包可以简化Java代码,提高开发效率。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Python使用Pillow包生成验证码图片的方法
    本文介绍了使用Python中的Pillow包生成验证码图片的方法。通过随机生成数字和符号,并添加干扰象素,生成一幅验证码图片。需要配置好Python环境,并安装Pillow库。代码实现包括导入Pillow包和随机模块,定义随机生成字母、数字和字体颜色的函数。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 目前正在做毕业设计,一个关于校园服务的app,我会抽取已完成的相关代码写到文章里。一是为了造福这个曾经帮助过我的社区,二是写文章的同时更能巩固相关知识的记忆。一、前言在爬取教务系统 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 点击上方“新机器视觉”,选择加”星标”或“置顶”重磅干货,第一时间送达很早就想总结一下前段时间学习HALCON的心得,但由于其他的事情总是抽不出时间。去年有过一段时间的集中学习,做 ... [详细]
  • 知识图谱表示概念:知识图谱是由一些相互连接的实体和他们的属性构成的。换句话说,知识图谱是由一条条知识组成,每条知识表示为一个SPO三元组(Subject-Predicate-Obj ... [详细]
  • 图片文字转换成word软件好用吗?
      图片文字都需要进行转换才能进行二次利用,因为这些文字都是不能编辑和复制的“死文字”word图片导出。进行转换的话就需要借助软件帮忙,图片文字转换成word软 ... [详细]
  • java编写一个为网站生成验证码的程序_Java后端产生验证码后台验证功能的实现代码...
    直接跳severlet在java后台生成验证码:RequestMapping(valueyzm.action)publicvoidYzm(HttpSessions ... [详细]
  • [Hei.Captcha]Asp.NetCore跨平台图形验证码实现
    写在前面说起来比较丢脸。我们有个手机的验证码发送逻辑需要使用验证码,这块本来项目里面就有验证码绘制逻辑,.NetFramework的,使用的包是System.Drawing,我 ... [详细]
  • Java实现验证码的制作
    验证码概述为什么使用验证码?验证码(CAPTCHA)是一种全自动程序。主要是为了区分“进行操作的是不是人”。如果没有验证码机制,将会导致以下的问题:对特定网站不断进行登录,破解密码 ... [详细]
  • 开发笔记:GD库的基本信息,图像的旋转水印缩略图验证码,以及图像类的封装
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了GD库的基本信息,图像的旋转水印缩略图验证码,以及图像类的封装相关的知识,希望对你有一定的参考价值。GD ... [详细]
author-avatar
小帅哥小羊儿_309
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有