热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

后端session前端访问_Tensorflow源码剖析:session创建

在TensorFlow中,用户是通过运行图来进行模型训练的,而启动图的第一步就是创建一个session对象。在日常编写Python代码时,

TensorFlow中,用户是通过运行图来进行模型训练的,而启动图的第一步就是创建一个session对象。在日常编写Python代码时,有的直接通过编写sess=tf.Session()来创建session,也有的在分布式TensorFlow中通过ChiefSessionCreatorWorkerSessionCreatorcreate_session()来创建session。这里简单说明下,create_session()实质上对tf.Session()的封装,只是里面添加了很多其他的功能,后期会对SessionCreator进行详细的介绍。鉴于前期读者反馈说看不大懂,所以今天,谱哥主要是想带大家来了解下sess=tf.Session()背后的实现原理,并介绍allocatorsession创建时在哪里有体现。

 TensorFlow系统分为前端系统和后端系统,前端系统提供编程模型,重点负责图的构造,目前主流编程语言是Python;后端系统主要负责图的执行,用C++语言来进行编写;Swig作为前端系统和后端系统建立连接的桥梁,使得前端Python创建session能够触发后端C++进行session创建。因此,接下来,将按照前端Python层、Swig以及后端C++层三个方面来详细说明sess=tf.Session()底部实现原理。1. 前端:Python层在前端系统中,session相关类的继承关系如下所示: 

b97dcc4b4f961442b45cd943c44edbcc.png

从中可知,session分为两种,普通Session和交互式InteractiveSession。后者自带with上下文管理器,并且在初始化的时候将自身作为默认的session,因此适合在Python交互式环境下使用。普通Session和交互式InteractiveSession都继承BaseSession,BaseSession继承SessionInterface。当用户层执行sess=tf.Session()时,会依次调用SessionInterfaceBaseSessionSession的初始化函数。在BaseSession的初始化函数中有以下几行代码:

from tensorflow.python import pywrap_tensorflow as tf_sessionclass BaseSession(SessionInterface): def __init__(self, target='', graph=None, config=None): ...... self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts)由上可知,session是在BaseSession初始化的时候执行tf_session.TF_NewSession()来创建,传入的参数opts通过tf_session.TF_NewSessionOptions创建,是一个SessionOptions结构体,完成envtargetconfig的简单封装。target参数主要用来判断是创建DirectSession还是GrpcSession

struct SessionOptions { Env* env; string target; ConfigProto config; SessionOptions();};tf_session实质指的是pywrap_tensorflow.py模块,该模块内部导入了pywrap_tensorflow_internal.py模块。而pywrap_tensorflow_internal.py是在系统启动Swig的时候通过tensorflow.i自动生成的适配文件,因此要想知道tf_session.TF_NewSession()内部到底干了啥,需要了解TensorfFlow是怎样使用Swig2. Swig包装器TensorFlow启动Swig的时候,会通过tensorflow.i生成两个文件: (1) pywrap_tensorflow_internal.py:对接前端python接口调用;
(2) pywrap_tensorflow_internal.cc:对接后端C++接口调用。 pywrap_tensorflow_internal.py模块在pywrap_tensorflow.py模块中被导入的时候,会自动加载_pywrap_tensorflow_internal.so的动态链接库,该库包含了整个TensorFlow运行时的所有符号。因此,在pywrap_tensorflow_internal.py模块中,可以通过_pywrap_tensorflow_internal转发,实现Python接口到_pywrap_tensorflow_internal.so的函数调用。

def TF_NewSession(graph, opts): return _pywrap_tensorflow_internal.TF_NewSession(graph, opts)TF_NewSession = _pywrap_tensorflow_internal.TF_NewSessionpywrap_tensorflow_internal.cc注册了一个函数符号表,实现Python函数到C函数名的映射。

{ (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},_wrap_TF_NewSession将调用c_api.h对其开放的API接口:TF_NewSession,从而进入系统后端C++层。3. 后端:C++层3.1 TF_NewSessionc_api.cc文件中,TF_NewSession()相关定义如下所示:

TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; //创建session status->status = NewSession(opt->options, &session); if (status->status.ok()) { //创建TF_Session对象,实现session和graph的绑定 TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); graph->sessions[new_session] = ""; } //返回TF_Session对象 return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; }}从上可以看出,TF_NewSession()干了两件事儿:(1)创建了session对象;(2)创建TF_Session对象,实现sessiongraph的绑定,并最终返回TF_Session,而不是session。 TF_Session的相关定义如下:

struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; TF_Graph* const graph; int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call // ExtendSessionGraphHelper before running the graph std::atomic extend_before_run;};

3.2 NewSession

那么NewSession()又是怎么定义的呢?于是追踪到了session.cc文件,相关代码如下所示:

Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; const Status s &#61; SessionFactory::GetFactory(options, &factory); if (!s.ok()) { *out_session &#61; nullptr; LOG(ERROR) <NewSession(options); if (!*out_session) { return errors::Internal("Failed to create session."); } return Status::OK();}可知&#xff0c;NewSession采用了工厂模式&#xff0c;先根据options去找出符合要求的工厂factory&#xff0c;然后在指定的工厂里创建Session3.3 SessionFactory::GetFactory 可能大家又会问又是如何去根据options找出对应的factory&#xff1f;于是&#xff0c;追踪到了session_factory.cc文件&#xff0c;在剖析GetFactory()之前&#xff0c;需要先来理解session_factories()的相关概念&#xff0c;如下所示&#xff1a;

typedef std::unordered_map SessionFactories;SessionFactories* session_factories() { static SessionFactories* factories &#61; new SessionFactories; return factories;}可知&#xff0c;session_factories()其实是创建了一个静态的SessionFactories&#xff0c;而这个SessionFactories是一个unordered_map&#xff0c;实现string类型的runtime_typeSessionFactory指针的映射。那么既然是unordered_map&#xff0c;就必然涉及到keyvalue的存储&#xff0c;在这里是通过SessionFactory::Register()来进行注册的。SessionFactory::Register()的相关定义如下所示&#xff1a;

void SessionFactory::Register(const string& runtime_type, SessionFactory* factory) { mutex_lock l(*get_session_factory_lock()); if (!session_factories()->insert({runtime_type, factory}).second) { ...... }}由上可知&#xff0c;SessionFactory::Register()的本质就是将runtime_typefactory指针插入到unordered_map中。紧接着&#xff0c;我追踪到direct_session.ccgrpc_session.cc文件&#xff0c;发现了相关session的注册。 DirectSessionFactory的注册如下所示&#xff1a;

class DirectSessionRegistrar { public: DirectSessionRegistrar() { SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); }};static DirectSessionRegistrar registrar;GrpcSessionFactory的注册如下所示&#xff1a;

class GrpcSessionRegistrar { public: GrpcSessionRegistrar() { SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory()); }};static GrpcSessionRegistrar registrar;现在来看看SessionFactory::GetFactory()核心代码&#xff0c;如下所示&#xff1a;

Status SessionFactory::GetFactory(const SessionOptions& options, SessionFactory** out_factory) { std::vector<:pair sessionfactory>> candidate_factories; for (const auto& session_factory : *session_factories()) { if (session_factory.second->AcceptsOptions(options)) { candidate_factories.push_back(session_factory); } } if (candidate_factories.size() &#61;&#61; 1) { *out_factory &#61; candidate_factories[0].second; return Status::OK(); } else if (candidate_factories.size() > 1) { //报错 }}GetFactory()代码中可知&#xff0c;其本质就是遍历session_factories()中的unordered_map&#xff0c;然后通过unordered_map中的SessionFactory(如DirectSessionFactoryGrpcSessionFactory)是否AcceptsOptions来进行选择&#xff0c;并且硬性要求有且仅有一个factory满足要求&#xff0c;否则报错。不同SessionFactoryAcceptsOptions()的定义如下&#xff1a; DirectSessionFactory:

bool AcceptsOptions(const SessionOptions& options) override { return options.target.empty(); }从上可知&#xff0c;若options.target为空&#xff0c;则应选择DirectSessionFactory,用于本地训练。 GrpcSessionFactory:

const char* const kSchemePrefix &#61; "grpc://";bool AcceptsOptions(const SessionOptions& options) override { return str_util::StartsWith(options.target, kSchemePrefix); }从上可知&#xff0c;若options.target是以grpc://开头的&#xff0c;则应选择GrpcSessionFactory&#xff0c;用于分布式TensorFlow3.4 factory->NewSession延续3.2节所讲&#xff0c;根据3.3节SessionFactory::GetFactory返回值的SessionFactory类型&#xff0c;去调用对应SessionFactoryNewSession接口。这里以DirectSessionFactory::NewSession为例&#xff0c;代码如下&#xff1a;

Session* NewSession(const SessionOptions& options) override { // Must do this before the CPU allocator is created. if (options.config.graph_options().build_cost_model() > 0) { EnableCPUAllocatorFullStats(true); } std::vector devices; const Status s &#61; DeviceFactory::AddDevices( options, "/job:localhost/replica:0/task:0", &devices); if (!s.ok()) { LOG(ERROR) <从上可知&#xff0c;DirectSessionFactory::NewSession()不单单只是创建DirectSession&#xff0c;还要完成相关的devices收集。相关设备收集通过调用 DeviceFactory::AddDevices()来完成&#xff0c;相关代码在device_factory.cc中&#xff0c;如下所示&#xff1a;

Status DeviceFactory::AddDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { //先获取CPU对应的设备工厂cpu_factory auto cpu_factory &#61; GetFactory("CPU"); //创建设备并记录保存到devices TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices)); ... //遍历device_factories()&#xff0c;创建设备集&#xff0c;包括GPU for (auto& p : device_factories()) { auto factory &#61; p.second.factory.get(); if (factory !&#61; cpu_factory) { TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices)); } } return Status::OK();}

从上可知&#xff0c;DeviceFactory::AddDevices()也是采用工厂模式&#xff0c;主要完成的是遍历device_factories()&#xff0c;然后调用每个factory中的CreateDevices接口&#xff0c;创建设备并把相应指针存储到devices vector中。在此&#xff0c;有几个接口函数需要说明下&#xff1a;

DeviceFactory::GetFactory

DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { //根据device_type查找对应的DeviceFactory auto it &#61; device_factories().find(device_type); if (it &#61;&#61; device_factories().end()) { return nullptr; } return it->second.factory.get();}DeviceFactory::GetFactory()主要是通过输入参数device_typedevice_factories()中查找对应的设备工厂DeviceFactory。device_factories()的本质类似于session_factories()&#xff0c;函数里创建了一个静态的unordered_map&#xff0c;表示device_type到FactoryItem的映射。FactoryItem是个结构体&#xff0c;包括factory指针和相应的优先级。

struct FactoryItem { std::unique_ptr factory; int priority;};std::unordered_map& device_factories() { static std::unordered_map* factories &#61; new std::unordered_map; return *factories;}跟SessionFactory::Register()类似&#xff0c;既然涉及到对unordered_map的读取&#xff0c;那么肯定存在对key和value的存储操作。该操作主要是通过DeviceFactory::Register()接口来完成相关DeviceFactory的注册。在TensorFlow中&#xff0c;专门为此进行了宏定义&#xff0c;如下所示&#xff1a;

#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...)INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, __COUNTER__, ##__VA_ARGS__)以下是相关DeviceFactory注册的部分代码&#xff0c;分别在threadpool_device_factory.cc和gpu_device_factory.cc文件中。从这也可以看出&#xff0c;GPUDeviceFactory的优先级要明显高于ThreadPoolDeviceFactory

REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory, 60);REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 70);REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory, 210);CreateDevices 遍历device_factories()unordered_map的时候&#xff0c;都会让每个DeviceFactory调用设备创建CreateDevices()&#xff0c;并存储到std::vector* devices中。下面以ThreadPoolDeviceFactory::CreateDevices为例来介绍其具体细节。

Status ThreadPoolDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) override { int n &#61; 1; auto iter &#61; options.config.device_count().find("CPU"); if (iter !&#61; options.config.device_count().end()) { n &#61; iter->second; } for (int i &#61; 0; i push_back(new ThreadPoolDevice( options, name, Bytes(256 <<20), DeviceLocality(), cpu_allocator())); } return Status::OK(); }从第9行可以看出&#xff0c;ThreadPoolDeviceFactory::CreateDevices主要是创建了一个ThreadPoolDevice&#xff0c;ThreadPoolDevice中存有一个allocator用来分配和释放内存。该allocatorcpu_allocator()来获取。cpu_allocator()相关定义在allocator.cc文件中&#xff0c;如下所示&#xff1a;

Allocator* cpu_allocator() { static Allocator* cpu_alloc &#61; AllocatorRegistry::Global()->GetAllocator(); if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { cpu_alloc &#61; new TrackingAllocator(cpu_alloc, true); } return cpu_alloc;}不知道大家现在对第2行的接口有没有丝丝熟悉感&#xff1f;对的&#xff0c;在Allocator(基础篇)已经对AllocatorRegistry相关接口进行了说明。从第2行可看出&#xff0c;每次执行AllocatorRegistry::Global()->GetAllocator()都会返回AllocatorRegistry当前优先级最高的allocator。如果该allocator想收集状态但是TracksAllocationSizes()又为false&#xff0c;那么就可以对该allocator进行封装&#xff0c;在此基础上创建一个TrackingAllocator即可进行记录追踪。当然&#xff0c;如果想知道底层是在哪里使用了BFCAllocator&#xff0c;则推荐阅读GPUDeviceFactory::CreateDevices&#xff0c;这里不再做过多说明。4. 总结本篇根据前端Python层、Swig以及后端C&#43;&#43;层三个方面来详细说明sess&#61;tf.Session()底部实现原理。前端Python层介绍了SessionBaseSession等的概念和相互联系&#xff1b;Swig主要完成将Python层的session创建转发到C&#43;&#43;层的session创建&#xff1b;后端C&#43;&#43;session创建根据SessionOptions找到相应的SessionFactory来执行NewSession操作。而NewSession函数不仅要完成session创建&#xff0c;还要根据DeviceFactory完成设备的创建并收集&#xff0c;这里自然离不开各种allocator




推荐阅读
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • Java学习笔记之面向对象编程(OOP)
    本文介绍了Java学习笔记中的面向对象编程(OOP)内容,包括OOP的三大特性(封装、继承、多态)和五大原则(单一职责原则、开放封闭原则、里式替换原则、依赖倒置原则)。通过学习OOP,可以提高代码复用性、拓展性和安全性。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • PHP反射API的功能和用途详解
    本文详细介绍了PHP反射API的功能和用途,包括动态获取信息和调用对象方法的功能,以及自动加载插件、生成文档、扩充PHP语言等用途。通过反射API,可以获取类的元数据,创建类的实例,调用方法,传递参数,动态调用类的静态方法等。PHP反射API是一种内建的OOP技术扩展,通过使用Reflection、ReflectionClass和ReflectionMethod等类,可以帮助我们分析其他类、接口、方法、属性和扩展。 ... [详细]
  • 代码如下:#coding:utf-8importstring,os,sysimportnumpyasnpimportmatplotlib.py ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • 计算机存储系统的层次结构及其优势
    本文介绍了计算机存储系统的层次结构,包括高速缓存、主存储器和辅助存储器三个层次。通过分层存储数据可以提高程序的执行效率。计算机存储系统的层次结构将各种不同存储容量、存取速度和价格的存储器有机组合成整体,形成可寻址存储空间比主存储器空间大得多的存储整体。由于辅助存储器容量大、价格低,使得整体存储系统的平均价格降低。同时,高速缓存的存取速度可以和CPU的工作速度相匹配,进一步提高程序执行效率。 ... [详细]
  • 本文介绍了UVALive6575题目Odd and Even Zeroes的解法,使用了数位dp和找规律的方法。阶乘的定义和性质被介绍,并给出了一些例子。其中,部分阶乘的尾零个数为奇数,部分为偶数。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了在使用Python中的aiohttp模块模拟服务器时出现的连接失败问题,并提供了相应的解决方法。文章中详细说明了出错的代码以及相关的软件版本和环境信息,同时也提到了相关的警告信息和函数的替代方案。通过阅读本文,读者可以了解到如何解决Python连接服务器失败的问题,并对aiohttp模块有更深入的了解。 ... [详细]
  • SpringBoot简单日志配置
     在生产环境中,只打印error级别的错误,在测试环境中,可以调成debugapplication.properties文件##默认使用logbacklogging.level.r ... [详细]
author-avatar
厉害了
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有