作者:micheals | 来源:互联网 | 2023-09-17 07:11
01 JIT编译方式 《TensorFlow技术解析与实战》15 TensorFlow线性代数编译框架XLA 通过XLA运行TensorFlow计算有两种方法,一是打开CPU或GPU设备上的JIT编译,二是将操作符放在XLA_CPU或XLA_GPU设备上……
01 JIT编译方式
《TensorFlow技术解析与实战》15 TensorFlow线性代数编译框架XLA
通过XLA运行TensorFlow计算有两种方法,一是打开CPU或GPU设备上的JIT编译,二是将操作符放在XLA_CPU或XLA_GPU设备上。
01.01 打开JIT编译的两种方式
# NOTE: this snippet assumes `import tensorflow as tf` and `import numpy as np`.

# Method 1: turn JIT compilation on for the whole session through its config.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)

# Method 2: turn JIT compilation on only for operators built inside a scope.
jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
x = tf.placeholder(np.float32)
with jit_scope():
    # Operators created inside this block are built under jit_scope.
    y = tf.add(x, x)
01.02 将操作符放在XLA设备上
# Place the operator explicitly on the XLA_GPU device.
# `input1` and `input2` are tensors assumed to be defined by the caller.
with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
    output = tf.add(input1, input2)
02 测试方法
# 《TensorFlow技术解析与实战》15 TensorFlow线性代数编译框架XLA
# win10 Tensorflow-gpu1.2.0-rc0 python3.5.3
# CUDA v8.0 cudnn-8.0-windows10-x64-v5.1
# https://github.com/tensorflow/tensorflow/blob/v1.2.0-rc0/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
# 测试方法:
# 00 cd tensorflow\tensorflow\examples\tutorials\mnist
# 01 python mnist_softmax_xla.py --xla=false
# 02
set TF_XLA_FLAGS=--xla_generate_hlo_graph=.*
python mnist_softmax_xla.py
03 python mnist_softmax_xla.py 代码
"""Simple MNIST classifier example with JIT XLA and timelines.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.client import timeline
FLAGS = None
def main(_):
    """Train a softmax MNIST classifier, optionally with XLA JIT, and print accuracy.

    Reads the module-level ``FLAGS`` (``data_dir`` and ``xla``). The last
    training step is run with full tracing and its timeline is written to
    ``timeline.ctf.json`` in Chrome trace format.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.
    """
    # Load the data set with one-hot encoded labels.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Model: a single linear layer mapping 784 inputs to 10 logits.
    x = tf.placeholder(tf.float32, [None, 784])
    w = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, w) + b

    # Loss and optimizer.
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    # Session configuration: request XLA JIT compilation only when --xla is set;
    # otherwise the global JIT level is left at 0.
    config = tf.ConfigProto()
    jit_level = tf.OptimizerOptions.ON_1 if FLAGS.xla else 0
    config.graph_options.optimizer_options.global_jit_level = jit_level
    run_metadata = tf.RunMetadata()

    # The context manager closes the session even if training raises.
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run(session=sess)

        train_loops = 1000
        for i in range(train_loops):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            feed = {x: batch_xs, y_: batch_ys}

            if i != train_loops - 1:
                sess.run(train_step, feed_dict=feed)
                continue

            # Final step: collect a full trace and export it as a timeline.
            sess.run(train_step,
                     feed_dict=feed,
                     options=tf.RunOptions(
                         trace_level=tf.RunOptions.FULL_TRACE),
                     run_metadata=run_metadata)
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as trace_file:
                trace_file.write(trace.generate_chrome_trace_format())

        # Evaluate the trained model on the test split.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(sess.run(accuracy,
                       feed_dict={x: mnist.test.images,
                                  y_: mnist.test.labels}))
def parse_bool_flag(value):
    """Convert a command-line string into a bool.

    ``type=bool`` must not be used with argparse: ``bool('false')`` is True
    because every non-empty string is truthy, so ``--xla=false`` would leave
    XLA switched on.

    Args:
        value: The raw string from the command line (or an actual bool).

    Returns:
        True for true/t/yes/y/on/1, False for false/f/no/n/off/0 or an empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays False, matching the previous bool('') behaviour.
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help='Directory for storing input data')
    parser.add_argument(
        '--xla',
        type=parse_bool_flag,
        default=True,
        help='Turn xla via JIT on')
    # Unrecognised arguments are forwarded to tf.app.run untouched.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)