热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

keras并行九三_keras多gpu并行运行案例

一、多张gpu的卡上使用keras有多张gpu卡时,推荐使用tensorflow作为后端。使用多张gpu运行model,可以分为两种情况,

一、多张gpu的卡上使用keras

有多张gpu卡时,推荐使用tensorflow 作为后端。使用多张gpu运行model,可以分为两种情况,一是数据并行,二是设备并行。

二、数据并行

数据并行将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据。

利用multi_gpu_model实现

keras.utils.multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False)

具体来说,该功能实现了单机多 GPU 数据并行性。 它的工作原理如下:

将模型的输入分成多个子批次。

在每个子批次上应用模型副本。 每个模型副本都在专用 GPU 上执行。

将结果(在 CPU 上)连接成一个大批量。

例如, 如果你的 batch_size 是 64,且你使用 gpus=2, 那么我们将把输入分为两个 32 个样本的子批次, 在 1 个 GPU 上处理 1 个子批次,然后返回完整批次的 64 个处理过的样本。

参数

model: 一个 Keras 模型实例。为了避免OOM错误,该模型可以建立在 CPU 上, 详见下面的使用样例。

gpus: 整数 >= 2 或整数列表,创建模型副本的 GPU 数量, 或 GPU ID 的列表。

cpu_merge: 一个布尔值,用于标识是否强制合并 CPU 范围内的模型权重。

cpu_relocation: 一个布尔值,用来确定是否在 CPU 的范围内创建模型的权重。如果模型没有在任何一个设备范围内定义,您仍然可以通过激活这个选项来拯救它。

返回

一个 Keras Model 实例,它可以像初始 model 参数一样使用,但它将工作负载分布在多个 GPU 上。

例子

import tensorflow as tf

from keras.applications import Xception

from keras.utils import multi_gpu_model

import numpy as np

num_samples = 1000

height = 224

width = 224

num_classes = 1000

# 实例化基础模型(或者「模版」模型)。

# 我们推荐在 CPU 设备范围内做此操作,

# 这样模型的权重就会存储在 CPU 内存中。

# 否则它们会存储在 GPU 上,而完全被共享。

with tf.device("/cpu:0"):

model = Xception(weights=None,

input_shape=(height, width, 3),

classes=num_classes)

# 复制模型到 8 个 GPU 上。

# 这假设你的机器有 8 个可用 GPU。

parallel_model = multi_gpu_model(model, gpus=8)

parallel_model.compile(loss="categorical_crossentropy",

optimizer="rmsprop")

# 生成虚拟数据

x = np.random.random((num_samples, height, width, 3))

y = np.random.random((num_samples, num_classes))

# 这个 `fit` 调用将分布在 8 个 GPU 上。

# 由于 batch size 是 256, 每个 GPU 将处理 32 个样本。

parallel_model.fit(x, y, epochs=20, batch_size=256)

# 通过模版模型存储模型(共享相同权重):

model.save("my_model.h5")

注意:

要保存多 GPU 模型,请通过模板模型(传递给 multi_gpu_model 的参数)调用 .save(fname) 或 .save_weights(fname) 以进行存储,而不是通过 multi_gpu_model 返回的模型。

即要用model来保存,而不是parallel_model来保存。

使用ModelCheckpoint() 遇到的问题

使用ModelCheckpoint()会遇到下面的问题:

TypeError: can"t pickle ...(different text at different situation) objects

这个问题和保存问题类似,ModelCheckpoint() 会自动调用parallel_model.save()来保存,而不是model.save(),因此我们要自己写一个召回函数,使得ModelCheckpoint()用model.save()。

修改方法:

class ParallelModelCheckpoint(ModelCheckpoint):

def __init__(self,model,filepath, monitor="val_loss", verbose=0,

save_best_only=False, save_weights_only=False,

mode="auto", period=1):

self.single_model = model

super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

def set_model(self, model):

super(ParallelModelCheckpoint,self).set_model(self.single_model)

checkpoint = ParallelModelCheckpoint(original_model)

ParallelModelCheckpoint调用的时候,model应该为原来的model而不是parallel_model。

EarlyStopping 没有此类问题

二、设备并行

设备并行适用于多分支结构,一个分支用一个gpu。

这种并行方法可以通过使用TensorFlow device scopes实现,下面是一个例子:

# Model where a shared LSTM is used to encode two different sequences in parallel

input_a = keras.Input(shape=(140, 256))

input_b = keras.Input(shape=(140, 256))

shared_lstm = keras.layers.LSTM(64)

# Process the first sequence on one GPU

with tf.device_scope("/gpu:0"):

encoded_a = shared_lstm(tweet_a)

# Process the next sequence on another GPU

with tf.device_scope("/gpu:1"):

encoded_b = shared_lstm(tweet_b)

# Concatenate results on CPU

with tf.device_scope("/cpu:0"):

merged_vector = keras.layers.concatenate([encoded_a, encoded_b],

axis=-1)

三、分布式运行

keras的分布式是利用TensorFlow实现的,要想完成分布式的训练,你需要将Keras注册在连接一个集群的TensorFlow会话上:

server = tf.train.Server.create_local_server()

sess = tf.Session(server.target)

from keras import backend as K

K.set_session(sess)

以上这篇keras 多gpu并行运行案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持云海天教程。

原文链接:https://blog.csdn.net/xingkongyidian/article/details/88343115



推荐阅读
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了Perl的测试框架Test::Base,它是一个数据驱动的测试框架,可以自动进行单元测试,省去手工编写测试程序的麻烦。与Test::More完全兼容,使用方法简单。以plural函数为例,展示了Test::Base的使用方法。 ... [详细]
  • 本文介绍了一个题目的解法,通过二分答案来解决问题,但困难在于如何进行检查。文章提供了一种逃逸方式,通过移动最慢的宿管来锁门时跑到更居中的位置,从而使所有合格的寝室都居中。文章还提到可以分开判断两边的情况,并使用前缀和的方式来求出在任意时刻能够到达宿管即将锁门的寝室的人数。最后,文章提到可以改成O(n)的直接枚举来解决问题。 ... [详细]
  • Android工程师面试准备及设计模式使用场景
    本文介绍了Android工程师面试准备的经验,包括面试流程和重点准备内容。同时,还介绍了建造者模式的使用场景,以及在Android开发中的具体应用。 ... [详细]
  • NotSupportedException无法将类型“System.DateTime”强制转换为类型“System.Object”
    本文介绍了在使用LINQ to Entities时出现的NotSupportedException异常,该异常是由于无法将类型“System.DateTime”强制转换为类型“System.Object”所导致的。同时还介绍了相关的错误信息和解决方法。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 上图是InnoDB存储引擎的结构。1、缓冲池InnoDB存储引擎是基于磁盘存储的,并将其中的记录按照页的方式进行管理。因此可以看作是基于磁盘的数据库系统。在数据库系统中,由于CPU速度 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 本文介绍了如何使用MATLAB调用摄像头进行人脸检测和识别。首先需要安装扩展工具,并下载安装OS Generic Video Interface。然后使用MATLAB的机器视觉工具箱中的VJ算法进行人脸检测,可以直接调用CascadeObjectDetector函数进行检测。同时还介绍了如何调用摄像头进行人脸识别,并对每一帧图像进行识别。最后,给出了一些相关的参考资料和实例。 ... [详细]
  • Parity game(poj1733)题解及思路分析
    本文是对题目"Parity game(poj1733)"的解题思路进行分析。题目要求判断每次给出的区间内1的个数是否和之前的询问相冲突,如果冲突则结束。本文首先介绍了离线算法的思路,然后详细解释了带权并查集的基本操作。同时,本文还对异或运算进行了学习,并给出了具体的操作步骤。最后,本文给出了完整的代码实现,并进行了测试。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 花瓣|目标值_Compose 动画边学边做夏日彩虹
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了Compose动画边学边做-夏日彩虹相关的知识,希望对你有一定的参考价值。引言Comp ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
author-avatar
天高云淡-tgyd
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有