热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

分布式深度学习|TensorFlow单主机多GPU/多主机多GPU原理与实现

TF的部署分为单机部署和分布式部署。在训练集数据量很大的情况下,单机跑深度学习程序过于耗时,所以需要分布式并行计算。在分布式部署中,我们需

TF的部署分为单机部署和分布式部署。在训练集数据量很大的情况下,单机跑深度学习程序过于耗时,所以需要分布式并行计算。在分布式部署中,我们需要在不同主机节点,实现client,master,worker.

在这里插入图片描述

1. Single-Device Execution
1.1 单机多GPU训练
构建好图后,使用拓扑算法来决定执行哪一个节点,即对每个节点使用一个计数,值表示所依赖的未完成的节点数目,当一个节点的运算完成时,将依赖该节点的所有节点的计数减一。如果节点的计数为0,将其放入准备队列等待执行。

单机的多GPU训练, tensorflow的官方已经给了一个cifar的例子,已经有比较详细的代码和文档介绍, 这里大致说下多GPU的过程,以便方便引入到多机多GPU的介绍。单机多GPU的训练过程如下:

假设你的机器上有3个GPU,在单机单GPU的训练中,数据是一个batch一个batch的训练。 在单机多GPU中,数据一次处理3个batch(假设是3个GPU训练), 每个GPU处理一个batch的数据计算变量,或者说参数,保存在CPU节点上。刚开始的时候数据由CPU分发给3个GPU, 在GPU上完成计算,得到每个batch要更新的梯度然后在CPU上收集完了3个GPU上的要更新的梯度, 计算一下平均梯度,然后更新参数。继续循环这个过程。

通过这个过程,处理的速度取决于最慢那个GPU的速度。如果3个GPU的处理速度差不多的话, 处理速度就相当于单机单GPU的速度的3倍减去数据在CPU和GPU之间传输的开销,实际的效率提升看CPU和GPU之间数据的速度和处理数据的大小。

1.2 通俗解释

老师给小明和小华布置了10000张纸的乘法题并且把所有的乘法的结果加起来, 每张纸上有128道乘法题。 这里一张纸就是一个batch, batch_size就是128. 小明算加法比较快, 小华算乘法比较快,于是小华就负责计算乘法, 小明负责把小华的乘法结果加起来 。 这样小明就是CPU,小华就是GPU.


这样计算的话, 预计小明和小华两个人得要花费一个星期的时间才能完成老师布置的题目。 于是小明就招来2个算乘法也很快的小红和小亮。 于是每次小明就给小华,小红,小亮各分发一张纸,让他们算乘法, 他们三个人算完了之后, 把结果告诉小明, 小明把他们的结果加起来,然后再给他们没人分发一张算乘法的纸,依次循环,知道所有的算完。


这里小明采用的是同步模式,就是每次要等他们三个都算完了之后, 再统一算加法,算完了加法之后,再给他们三个分发纸张。这样速度就取决于他们三个中算乘法算的最慢的那个人, 和分发纸张的速度。

2. Multi-Device Execution
在分布式系统情况下,事情就变得复杂了很多,还好前述调度用了现有框架。那么对于TF来说,剩下的事情就是:

(1)决定运算在哪个设备上运行
(2)管理设备之间的数据传递

2.1 分布式多主机多GPU训练

随着设计的模型越来越复杂,模型参数越来越多,越来越大, 大到什么程度?多到什么程度? 多参数的个数上百亿个, 训练的数据多到按TB级别来衡量。大家知道每次计算一轮,都要计算梯度,更新参数。 当参数的量级上升到百亿量级甚至更大之后, 参数的更新的性能都是问题。 如果是单机16个GPU, 一个step最多也是处理16个batch, 这对于上TB级别的数据来说,不知道要训练到什么时候。于是就有了分布式的深度学习训练方法,或者说框架。

参数服务器

在介绍tensorflow的分布式训练之前,先说下参数服务器的概念。

前面说道, 当你的模型越来越大, 模型的参数越来越多,多到模型参数的更新,一台机器的性能都不够的时候, 很自然的我们就会想到把参数分开放到不同的机器去存储和更新。
因为碰到上面提到的那些问题, 所有参数服务器就被单独拧出来, 于是就有了参数服务器的概念。 参数服务器可以是多台机器组成的集群, 这个就有点类似分布式的存储架构了, 涉及到数据的同步,一致性等等, 一般是key-value的形式,可以理解为一个分布式的key-value内存数据库,然后再加上一些参数更新的操作。 详细的细节可以去google一下, 这里就不详细说了。 反正就是当性能不够的时候,
几百亿的参数分散到不同的机器上去保存和更新,解决参数存储和更新的性能问题。
借用上面的小明算题的例子,小明觉得自己算加法都算不过来了, 于是就叫了10个小明过来一起帮忙算。

gRPC (google remote procedure call)
TensorFlow分布式并行计算基于gRPC通信框架,其中包括一个master创建Session,还有多个worker负责执行计算图中的任务。

gRPC首先是一个RPC,即远程过程调用。通俗的解释是:假设你在本机上执行一段代码num=add(a,b),它调用了一个过程 call,然后返回了一个值num,你感觉这段代码只是在本机上执行的, 但实际情况,本机上的add方法是将参数打包发送给服务器,然后服务器运行服务器端的add方法,返回的结果再将数据打包返回给客户端。

结构

Cluster是Job的集合,Job是Task的集合。即一个Cluster可以切分多个Job,一个Job指一类特定的任务,每个Job包含多个Task,比如parameter server(ps)、worker。在大多数情况下,一个主机上只运行一个Task。

在分布式深度学习框架中,我们一般把Job划分为Parameter Server和Worker:Parameter Job是管理参数的存储和更新工作;Worker Job是来运行ops。如果参数的数量太大,一台主机处理不了,这就要需要多个Tasks。

2.2 分布式TensorFlow模式
In-graph 模式
模型并行,将模型的计算图的不同部分放在不同的主机上执行。

In-graph模式和单机多GPU模型有点类似。 还是一个小明算加法, 但是算乘法的就可以不止是他们一个教室的小华,小红,小亮了。 可以是其他教师的小张,小李。>In-graph模式, 把计算已经从单机多GPU,已经扩展到了多机多GPU了, 不过数据分发还是在一个节点。 这样的好处是配置简单, 其他多机多GPU的计算节点,只要起个join操作, 暴露一个网络接口,等在那里接受任务就好了。 这些计算节点暴露出来的网络接口,使用起来就跟本机的一个GPU的使用一样, 只要在操作的时候指定tf.device(“/job:worker/task:n”), 就可以向指定GPU一样,把操作指定到一个计算节点上计算,使用起来和多GPU的类似。 但是这样的坏处是训练数据的分发依然在一个节点上, 要把训练数据分发到不同的机器上, 严重影响并发训练速度。在大数据训练的情况下, 不推荐使用这种模式。

Between-graph 模式
数据并行,每台主机使用完全相同的计算图。

Between-graph模式下,训练的参数保存在参数服务器, 数据不用分发, 数据分片的保存在各个计算节点, 各个计算节点自己算自己的, 算完了之后, 把要更新的参数告诉参数服务器,参数服务器更新参数。这种模式的优点是不用训练数据的分发了, 尤其是在数据量在TB级的时候, 节省了大量的时间,所以大数据深度学习还是推荐使用Between-graph模式。

2.3 同步更新和异步更新
in-graph模式和between-graph模式都支持同步和异步更新。

在同步更新的时, 每次梯度更新,都要等所有分发出去的数据计算完成后,返回结果之后,把梯度累加算了均值之后,再更新参数。 这样的好处是loss的下降比较稳定, 但是这个的坏处也很明显, 处理的速度取决于最慢的那个分片计算的时间。

在异步更新时, 所有的计算节点,各自算自己的, 更新参数也是自己更新自己计算的结果, 这样的优点就是计算速度快, 计算资源能得到充分利用,但是缺点是loss的下降不稳定, 抖动大。

在数据量较小的情况下, 各个节点的计算能力比较均衡, 推荐使用同步模式;在数据量很大情况下,各个主机的计算性能掺差不齐的情况下,推荐使用异步的方式。

3. 应用实例
tensorflow官方有个分布式tensorflow的文档,但是例子没有完整的代码, 这里写了一个最简单的可以跑起来的例子,供大家参考,这里也傻瓜式给大家解释一下代码,以便更加通俗的理解。本文的实例——基于分布式Tensorflow的手写字体识别,都加了代码中文注释,更加通俗易懂。

3.1 运行步骤

ps节点执行:

python distributed.py --job_name=ps --task_index=0

worker1节点执行:

python distributed.py --job_name=worker --task_index=0

worker2 节点执行:

python distributed.py --job_name=worker --task_index=1

3.2 实验代码


# encoding:utf-8
import math
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dataflags = tf.app.flags
IMAGE_PIXELS = 28
# 定义默认训练参数和数据路径
flags.DEFINE_string('data_dir', '/tmp/mnist-data', 'Directory for storing mnist data')
flags.DEFINE_integer('hidden_units', 100, 'Number of units in the hidden layer of the NN')
flags.DEFINE_integer('train_steps', 10000, 'Number of training steps to perform')
flags.DEFINE_integer('batch_size', 100, 'Training batch size ')
flags.DEFINE_float('learning_rate', 0.01, 'Learning rate')
# 定义分布式参数
# 参数服务器parameter server节点
flags.DEFINE_string('ps_hosts', '192.168.32.145:22221', 'Comma-separated list of hostname:port pairs')
# 两个worker节点
flags.DEFINE_string('worker_hosts', '192.168.32.146:22221,192.168.32.160:22221','Comma-separated list of hostname:port pairs')
# 设置job name参数
flags.DEFINE_string('job_name', None, 'job name: worker or ps')
# 设置任务的索引
flags.DEFINE_integer('task_index', None, 'Index of task within the job')
# 选择异步并行,同步并行
flags.DEFINE_integer("issync", None, "是否采用分布式的同步模式,1表示同步模式,0表示异步模式")FLAGS = flags.FLAGSdef main(unused_argv):mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)if FLAGS.job_name is None or FLAGS.job_name == '':raise ValueError('Must specify an explicit job_name !')else:print 'job_name : %s' % FLAGS.job_nameif FLAGS.task_index is None or FLAGS.task_index == '':raise ValueError('Must specify an explicit task_index!')else:print 'task_index : %d' % FLAGS.task_indexps_spec = FLAGS.ps_hosts.split(',')worker_spec = FLAGS.worker_hosts.split(',')# 创建集群num_worker = len(worker_spec)cluster = tf.train.ClusterSpec({'ps': ps_spec, 'worker': worker_spec})server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)if FLAGS.job_name == 'ps':server.join()is_chief = (FLAGS.task_index == 0)# worker_device = '/job:worker/task%d/cpu:0' % FLAGS.task_indexwith tf.device(tf.train.replica_device_setter(cluster=cluster)):global_step = tf.Variable(0, name='global_step', trainable=False) # 创建纪录全局训练步数变量hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],stddev=1.0 / IMAGE_PIXELS), name='hid_w')hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name='hid_b')sm_w = tf.Variable(tf.truncated_normal([FLAGS.hidden_units, 10],stddev=1.0 / math.sqrt(FLAGS.hidden_units)), name='sm_w')sm_b = tf.Variable(tf.zeros([10]), name='sm_b')x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])y_ = tf.placeholder(tf.float32, [None, 10])hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)hid = tf.nn.relu(hid_lin)y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))opt = tf.train.AdamOptimizer(FLAGS.learning_rate)train_step = opt.minimize(cross_entropy, global_step=global_step)# 生成本地的参数初始化操作init_opinit_op = tf.global_variables_initializer()train_dir = tempfile.mkdtemp()sv = tf.train.Supervisor(is_chief=is_chief, logdir=train_dir, init_op=init_op, recovery_wait_secs=1,global_step=global_step)if is_chief:print 'Worker %d: Initailizing session...' % FLAGS.task_indexelse:print 'Worker %d: Waiting for session to be initaialized...' % FLAGS.task_indexsess = sv.prepare_or_wait_for_session(server.target)print 'Worker %d: Session initialization complete.' % FLAGS.task_indextime_begin = time.time()print 'Traing begins @ %f' % time_beginlocal_step = 0while True:batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)train_feed = {x: batch_xs, y_: batch_ys}_, step = sess.run([train_step, global_step], feed_dict=train_feed)local_step += 1now = time.time()print '%f: Worker %d: traing step %d dome (global step:%d)' % (now, FLAGS.task_index, local_step, step)if step >= FLAGS.train_steps:breaktime_end = time.time()print 'Training ends @ %f' % time_endtrain_time = time_end - time_beginprint 'Training elapsed time:%f s' % train_timeval_feed = {x: mnist.validation.images, y_: mnist.validation.labels}val_xent = sess.run(cross_entropy, feed_dict=val_feed)print 'After %d training step(s), validation cross entropy = %g' % (FLAGS.train_steps, val_xent)sess.close()if __name__ == '__main__':tf.app.run()

4 . 参考文章
http://blog.csdn.net/luodongri/article/details/52596780
http://blog.csdn.net/u012436149/article/details/53140869
http://blog.csdn.net/stdcoutzyx/article/details/51645396
https://blog.csdn.net/yjk13703623757/article/details/80956268
TensorFlow实战[黄文坚 唐源 著]
https://blog.csdn.net/xbinworld/article/details/74781605


推荐阅读
  • t-io 2.0.0发布-法网天眼第一版的回顾和更新说明
    本文回顾了t-io 1.x版本的工程结构和性能数据,并介绍了t-io在码云上的成绩和用户反馈。同时,还提到了@openSeLi同学发布的t-io 30W长连接并发压力测试报告。最后,详细介绍了t-io 2.0.0版本的更新内容,包括更简洁的使用方式和内置的httpsession功能。 ... [详细]
  • 一句话解决高并发的核心原则
    本文介绍了解决高并发的核心原则,即将用户访问请求尽量往前推,避免访问CDN、静态服务器、动态服务器、数据库和存储,从而实现高性能、高并发、高可扩展的网站架构。同时提到了Google的成功案例,以及适用于千万级别PV站和亿级PV网站的架构层次。 ... [详细]
  • Sleuth+zipkin链路追踪SpringCloud微服务的解决方案
    在庞大的微服务群中,随着业务扩展,微服务个数增多,系统调用链路复杂化。Sleuth+zipkin是解决SpringCloud微服务定位和追踪的方案。通过TraceId将不同服务调用的日志串联起来,实现请求链路跟踪。通过Feign调用和Request传递TraceId,将整个调用链路的服务日志归组合并,提供定位和追踪的功能。 ... [详细]
  • 云原生应用最佳开发实践之十二原则(12factor)
    目录简介一、基准代码二、依赖三、配置四、后端配置五、构建、发布、运行六、进程七、端口绑定八、并发九、易处理十、开发与线上环境等价十一、日志十二、进程管理当 ... [详细]
  • ejava,刘聪dejava
    本文目录一览:1、什么是Java?2、java ... [详细]
  • 14亿人的大项目,腾讯云数据库拿下!
    全国人 ... [详细]
  • 知识图谱表示概念:知识图谱是由一些相互连接的实体和他们的属性构成的。换句话说,知识图谱是由一条条知识组成,每条知识表示为一个SPO三元组(Subject-Predicate-Obj ... [详细]
  • 云原生的十大开源项目是什么
    这篇“云原生的十大开源项目是什么”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值 ... [详细]
  • 本文介绍了OpenStack的逻辑概念以及其构成简介,包括了软件开源项目、基础设施资源管理平台、三大核心组件等内容。同时还介绍了Horizon(UI模块)等相关信息。 ... [详细]
  • 如何使用代理服务器进行网页抓取?
    本文介绍了如何使用代理服务器进行网页抓取,并探讨了数据驱动对竞争优势的重要性。通过网页抓取,企业可以快速获取并分析大量与需求相关的数据,从而制定营销战略。同时,网页抓取还可以帮助电子商务公司在竞争对手的网站上下载数百页的有用数据,提高销售增长和毛利率。 ... [详细]
  • 统一知识图谱学习和建议:更好地理解用户偏好
    本文介绍了一种将知识图谱纳入推荐系统的方法,以提高推荐的准确性和可解释性。与现有方法不同的是,本方法考虑了知识图谱的不完整性,并在知识图谱中传输关系信息,以更好地理解用户的偏好。通过大量实验,验证了本方法在推荐任务和知识图谱完成任务上的优势。 ... [详细]
  • 篇首语:本文由编程笔记#小编为大家整理,主要介绍了软件测试知识点之数据库压力测试方法小结相关的知识,希望对你有一定的参考价值。 ... [详细]
  • MySQL数据库锁机制及其应用(数据库锁的概念)
    本文介绍了MySQL数据库锁机制及其应用。数据库锁是计算机协调多个进程或线程并发访问某一资源的机制,在数据库中,数据是一种供许多用户共享的资源,如何保证数据并发访问的一致性和有效性是数据库必须解决的问题。MySQL的锁机制相对简单,不同的存储引擎支持不同的锁机制,主要包括表级锁、行级锁和页面锁。本文详细介绍了MySQL表级锁的锁模式和特点,以及行级锁和页面锁的特点和应用场景。同时还讨论了锁冲突对数据库并发访问性能的影响。 ... [详细]
  • Servlet多用户登录时HttpSession会话信息覆盖问题的解决方案
    本文讨论了在Servlet多用户登录时可能出现的HttpSession会话信息覆盖问题,并提供了解决方案。通过分析JSESSIONID的作用机制和编码方式,我们可以得出每个HttpSession对象都是通过客户端发送的唯一JSESSIONID来识别的,因此无需担心会话信息被覆盖的问题。需要注意的是,本文讨论的是多个客户端级别上的多用户登录,而非同一个浏览器级别上的多用户登录。 ... [详细]
  • 如果说以比特币为代表的货币区块链技术为1.0,以以太坊为代表的合同区块链技术为2.0,那么实现了完备的权限控制和安全保障的Hyperledger项目毫无疑问代表着区块链技术3.0 ... [详细]
author-avatar
手机用户2502902345
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有