在TensorFlow
中,用户是通过运行图来进行模型训练的,而启动图的第一步就是创建一个session
对象。在日常编写Python
代码时,有的直接通过编写sess=tf.Session()
来创建session
,也有的在分布式TensorFlow
中通过ChiefSessionCreator
和WorkerSessionCreator
的create_session()
来创建session
。这里简单说明下,create_session()
实质上对tf.Session()
的封装,只是里面添加了很多其他的功能,后期会对SessionCreator
进行详细的介绍。鉴于前期读者反馈说看不大懂,所以今天,谱哥主要是想带大家来了解下sess=tf.Session()
背后的实现原理,并介绍allocator
在session
创建时在哪里有体现。
TensorFlow
系统分为前端系统和后端系统,前端系统提供编程模型,重点负责图的构造,目前主流编程语言是Python
;后端系统主要负责图的执行,用C++语言来进行编写;Swig
作为前端系统和后端系统建立连接的桥梁,使得前端Python
创建session
能够触发后端C++进行session
创建。因此,接下来,将按照前端Python
层、Swig
以及后端C++层三个方面来详细说明sess=tf.Session()
底部实现原理。1. 前端:Python层在前端系统中,session相关类的继承关系如下所示: 从中可知,session
分为两种,普通Session
和交互式InteractiveSession
。后者自带with上下文管理器,并且在初始化的时候将自身作为默认的session
,因此适合在Python
交互式环境下使用。普通Session
和交互式InteractiveSession
都继承BaseSession
,BaseSession
继承SessionInterface
。当用户层执行sess=tf.Session()
时,会依次调用SessionInterface
、BaseSession
和Session
的初始化函数。在BaseSession
的初始化函数中有以下几行代码:from tensorflow.python import pywrap_tensorflow as tf_sessionclass BaseSession(SessionInterface): def __init__(self, target='', graph=None, config=None): ...... self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts)
由上可知,session是在BaseSession初始化的时候执行tf_session.TF_NewSession()来创建,传入的参数opts
通过tf_session.TF_NewSessionOptions
创建,是一个SessionOptions
结构体,完成env
、target
和config
的简单封装。target
参数主要用来判断是创建DirectSession
还是GrpcSession
。
struct SessionOptions { Env* env; string target; ConfigProto config; SessionOptions();};
tf_session
实质指的是pywrap_tensorflow.py
模块,该模块内部导入了pywrap_tensorflow_internal.py
模块。而pywrap_tensorflow_internal.py
是在系统启动Swig
的时候通过tensorflow.i
自动生成的适配文件,因此要想知道tf_session.TF_NewSession()
内部到底干了啥,需要了解TensorfFlow
是怎样使用Swig
。2. Swig包装器在TensorFlow
启动Swig
的时候,会通过tensorflow.i
生成两个文件: (1) pywrap_tensorflow_internal.py:对接前端python接口调用;
(2) pywrap_tensorflow_internal.cc:对接后端C++接口调用。 pywrap_tensorflow_internal.py
模块在pywrap_tensorflow.py
模块中被导入的时候,会自动加载_pywrap_tensorflow_internal.so
的动态链接库,该库包含了整个TensorFlow
运行时的所有符号。因此,在pywrap_tensorflow_internal.py
模块中,可以通过_pywrap_tensorflow_internal
转发,实现Python
接口到_pywrap_tensorflow_internal.so
的函数调用。
def TF_NewSession(graph, opts): return _pywrap_tensorflow_internal.TF_NewSession(graph, opts)TF_NewSession = _pywrap_tensorflow_internal.TF_NewSession
pywrap_tensorflow_internal.cc
注册了一个函数符号表,实现Python函数到C函数名的映射。
{ (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},
而_wrap_TF_NewSession
将调用c_api.h
对其开放的API接口:TF_NewSession
,从而进入系统后端C++层。3. 后端:C++层3.1 TF_NewSession在c_api.cc
文件中,TF_NewSession()
相关定义如下所示:
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; //创建session status->status = NewSession(opt->options, &session); if (status->status.ok()) { //创建TF_Session对象,实现session和graph的绑定 TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); graph->sessions[new_session] = ""; } //返回TF_Session对象 return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; }}
从上可以看出,TF_NewSession()
干了两件事儿:(1)创建了session
对象;(2)创建TF_Session
对象,实现session
和graph
的绑定,并最终返回TF_Session,而不是session。 TF_Session
的相关定义如下:
struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; TF_Graph* const graph; int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call // ExtendSessionGraphHelper before running the graph std::atomic extend_before_run;};
NewSession()
又是怎么定义的呢?于是追踪到了session.cc
文件&#xff0c;相关代码如下所示&#xff1a;Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; const Status s &#61; SessionFactory::GetFactory(options, &factory); if (!s.ok()) { *out_session &#61; nullptr; LOG(ERROR) <
可知&#xff0c;NewSession(options); if (!*out_session) { return errors::Internal("Failed to create session."); } return Status::OK();}NewSession
采用了工厂模式&#xff0c;先根据options
去找出符合要求的工厂factory
&#xff0c;然后在指定的工厂里创建Session
。3.3 SessionFactory::GetFactory 可能大家又会问又是如何去根据options
找出对应的factory
&#xff1f;于是&#xff0c;追踪到了session_factory.cc
文件&#xff0c;在剖析GetFactory()
之前&#xff0c;需要先来理解session_factories()
的相关概念&#xff0c;如下所示&#xff1a;
typedef std::unordered_map SessionFactories;SessionFactories* session_factories() { static SessionFactories* factories &#61; new SessionFactories; return factories;}
可知&#xff0c;session_factories()
其实是创建了一个静态的SessionFactories
&#xff0c;而这个SessionFactories
是一个unordered_map
&#xff0c;实现string
类型的runtime_type
到SessionFactory
指针的映射。那么既然是unordered_map
&#xff0c;就必然涉及到key
和value
的存储&#xff0c;在这里是通过SessionFactory::Register()
来进行注册的。SessionFactory::Register()
的相关定义如下所示&#xff1a;
void SessionFactory::Register(const string& runtime_type, SessionFactory* factory) { mutex_lock l(*get_session_factory_lock()); if (!session_factories()->insert({runtime_type, factory}).second) { ...... }}
由上可知&#xff0c;SessionFactory::Register()
的本质就是将runtime_type
和factory
指针插入到unordered_map
中。紧接着&#xff0c;我追踪到direct_session.cc
和grpc_session.cc
文件&#xff0c;发现了相关session
的注册。 DirectSessionFactory
的注册如下所示&#xff1a;
class DirectSessionRegistrar { public: DirectSessionRegistrar() { SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); }};static DirectSessionRegistrar registrar;
GrpcSessionFactory
的注册如下所示&#xff1a;
class GrpcSessionRegistrar { public: GrpcSessionRegistrar() { SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory()); }};static GrpcSessionRegistrar registrar;
现在来看看SessionFactory::GetFactory()
核心代码&#xff0c;如下所示&#xff1a;
Status SessionFactory::GetFactory(const SessionOptions& options, SessionFactory** out_factory) { std::vector<:pair sessionfactory>> candidate_factories; for (const auto& session_factory : *session_factories()) { if (session_factory.second->AcceptsOptions(options)) { candidate_factories.push_back(session_factory); } } if (candidate_factories.size() &#61;&#61; 1) { *out_factory &#61; candidate_factories[0].second; return Status::OK(); } else if (candidate_factories.size() > 1) { //报错 }}
从GetFactory()
代码中可知&#xff0c;其本质就是遍历session_factories()
中的unordered_map
&#xff0c;然后通过unordered_map
中的SessionFactory
(如DirectSessionFactory
和GrpcSessionFactory
)是否AcceptsOptions
来进行选择&#xff0c;并且硬性要求有且仅有一个factory
满足要求&#xff0c;否则报错。不同SessionFactory
的AcceptsOptions()
的定义如下&#xff1a; DirectSessionFactory:
bool AcceptsOptions(const SessionOptions& options) override { return options.target.empty(); }
从上可知&#xff0c;若options.target
为空&#xff0c;则应选择DirectSessionFactory
,用于本地训练。 GrpcSessionFactory:
const char* const kSchemePrefix &#61; "grpc://";bool AcceptsOptions(const SessionOptions& options) override { return str_util::StartsWith(options.target, kSchemePrefix); }
从上可知&#xff0c;若options.target
是以grpc://
开头的&#xff0c;则应选择GrpcSessionFactory
&#xff0c;用于分布式TensorFlow
。3.4 factory->NewSession延续3.2节所讲&#xff0c;根据3.3节SessionFactory::GetFactory
返回值的SessionFactory
类型&#xff0c;去调用对应SessionFactory
的NewSession
接口。这里以DirectSessionFactory::NewSession
为例&#xff0c;代码如下&#xff1a;
从上可知&#xff0c;Session* NewSession(const SessionOptions& options) override { // Must do this before the CPU allocator is created. if (options.config.graph_options().build_cost_model() > 0) { EnableCPUAllocatorFullStats(true); } std::vector devices; const Status s &#61; DeviceFactory::AddDevices( options, "/job:localhost/replica:0/task:0", &devices); if (!s.ok()) { LOG(ERROR) <
从上可知&#xff0c;DirectSessionFactory::NewSession()
不单单只是创建DirectSession
&#xff0c;还要完成相关的devices
收集。相关设备收集通过调用 DeviceFactory::AddDevices()
来完成&#xff0c;相关代码在device_factory.cc
中&#xff0c;如下所示&#xff1a;Status DeviceFactory::AddDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { //先获取CPU对应的设备工厂cpu_factory auto cpu_factory &#61; GetFactory("CPU"); //创建设备并记录保存到devices TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices)); ... //遍历device_factories()&#xff0c;创建设备集&#xff0c;包括GPU for (auto& p : device_factories()) { auto factory &#61; p.second.factory.get(); if (factory !&#61; cpu_factory) { TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices)); } } return Status::OK();}
DeviceFactory::AddDevices()
也是采用工厂模式&#xff0c;主要完成的是遍历device_factories()
&#xff0c;然后调用每个factory
中的CreateDevices
接口&#xff0c;创建设备并把相应指针存储到devices vector
中。在此&#xff0c;有几个接口函数需要说明下&#xff1a;DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { //根据device_type查找对应的DeviceFactory auto it &#61; device_factories().find(device_type); if (it &#61;&#61; device_factories().end()) { return nullptr; } return it->second.factory.get();}
DeviceFactory::GetFactory()
主要是通过输入参数device_type
在device_factories()
中查找对应的设备工厂DeviceFactory
。device_factories()的本质类似于session_factories()&#xff0c;函数里创建了一个静态的unordered_map&#xff0c;表示device_type到FactoryItem的映射。FactoryItem是个结构体&#xff0c;包括factory指针和相应的优先级。struct FactoryItem { std::unique_ptr factory; int priority;};std::unordered_map& device_factories() { static std::unordered_map* factories &#61; new std::unordered_map; return *factories;}
跟SessionFactory::Register()类似&#xff0c;既然涉及到对unordered_map的读取&#xff0c;那么肯定存在对key和value的存储操作。该操作主要是通过DeviceFactory::Register()
接口来完成相关DeviceFactory
的注册。在TensorFlow中&#xff0c;专门为此进行了宏定义&#xff0c;如下所示&#xff1a;#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...)INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, __COUNTER__, ##__VA_ARGS__)
以下是相关DeviceFactory注册的部分代码&#xff0c;分别在threadpool_device_factory.cc和gpu_device_factory.cc文件中。从这也可以看出&#xff0c;GPUDeviceFactory
的优先级要明显高于ThreadPoolDeviceFactory
。REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory, 60);REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 70);REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory, 210);
CreateDevices 遍历device_factories()
里unordered_map
的时候&#xff0c;都会让每个DeviceFactory
调用设备创建CreateDevices()
&#xff0c;并存储到std::vector* devices
中。下面以ThreadPoolDeviceFactory::CreateDevices为例来介绍其具体细节。Status ThreadPoolDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) override { int n &#61; 1; auto iter &#61; options.config.device_count().find("CPU"); if (iter !&#61; options.config.device_count().end()) { n &#61; iter->second; } for (int i &#61; 0; i
从第9行可以看出&#xff0c;ThreadPoolDeviceFactory::CreateDevices
主要是创建了一个ThreadPoolDevice
&#xff0c;ThreadPoolDevice
中存有一个allocator
用来分配和释放内存。该allocator
由cpu_allocator()
来获取。cpu_allocator()
相关定义在allocator.cc
文件中&#xff0c;如下所示&#xff1a;Allocator* cpu_allocator() { static Allocator* cpu_alloc &#61; AllocatorRegistry::Global()->GetAllocator(); if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { cpu_alloc &#61; new TrackingAllocator(cpu_alloc, true); } return cpu_alloc;}
不知道大家现在对第2行的接口有没有丝丝熟悉感&#xff1f;对的&#xff0c;在Allocator(基础篇)已经对AllocatorRegistry
相关接口进行了说明。从第2行可看出&#xff0c;每次执行AllocatorRegistry::Global()->GetAllocator()
都会返回AllocatorRegistry
当前优先级最高的allocator
。如果该allocator
想收集状态但是TracksAllocationSizes()
又为false&#xff0c;那么就可以对该allocator
进行封装&#xff0c;在此基础上创建一个TrackingAllocator
即可进行记录追踪。当然&#xff0c;如果想知道底层是在哪里使用了BFCAllocator
&#xff0c;则推荐阅读GPUDeviceFactory::CreateDevices
&#xff0c;这里不再做过多说明。4. 总结本篇根据前端Python
层、Swig
以及后端C&#43;&#43;
层三个方面来详细说明sess&#61;tf.Session()
底部实现原理。前端Python
层介绍了Session
和BaseSession
等的概念和相互联系&#xff1b;Swig
主要完成将Python
层的session
创建转发到C&#43;&#43;
层的session
创建&#xff1b;后端C&#43;&#43;
层session
创建根据SessionOptions
找到相应的SessionFactory
来执行NewSession
操作。而NewSession
函数不仅要完成session
创建&#xff0c;还要根据DeviceFactory
完成设备的创建并收集&#xff0c;这里自然离不开各种allocator
。