热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

不要怂,就是GAN(生成式对抗网络)(四):训练和测试GAN

在homeyour_nameTensorFlowDCGAN下新建文件train.py,同时新建文件夹logs和文件夹samples,前者用来保存训练过程中的日志和模型,后者用来保存

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保存训练过程中采样器的采样图片,在 train.py 中输入如下代码:

# -*- coding: utf-8 -*-
import tensorflow as tf
import os

from read_data import *
from utils import *
from ops import *
from model import *
from model import BATCH_SIZE


def train():

    # 设置 global_step ,用来记录训练过程中的 step        
    global_step = tf.Variable(0, name = 'global_step', trainable = False)
    # 训练过程中的日志保存文件
    train_dir = '/home/your_name/TensorFlow/DCGAN/logs'

    # 放置三个 placeholder,y 表示约束条件,images 表示送入判别器的图片,
    # z 表示随机噪声
    y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y')
    images = tf.placeholder(tf.float32, [64, 28, 28, 1], name='real_images')
    z = tf.placeholder(tf.float32, [None, 100], name='z')

    # 由生成器生成图像 G
    G = generator(z, y)
    # 真实图像送入判别器
    D, D_logits  = discriminator(images, y)
    # 采样器采样图像
    samples = sampler(z, y)
    # 生成图像送入判别器
    D_, D_logits_ = discriminator(G, y, reuse = True)
    
    # 损失计算
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_)))
    d_loss = d_loss_real + d_loss_fake
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_)))

    # 总结操作
    z_sum = tf.histogram_summary("z", z)
    d_sum = tf.histogram_summary("d", D)
    d__sum = tf.histogram_summary("d_", D_)
    G_sum = tf.image_summary("G", G)

    d_loss_real_sum = tf.scalar_summary("d_loss_real", d_loss_real)
    d_loss_fake_sum = tf.scalar_summary("d_loss_fake", d_loss_fake)
    d_loss_sum = tf.scalar_summary("d_loss", d_loss)                                                
    g_loss_sum = tf.scalar_summary("g_loss", g_loss)
    
    # 合并各自的总结
    g_sum = tf.merge_summary([z_sum, d__sum, G_sum, d_loss_fake_sum, g_loss_sum])
    d_sum = tf.merge_summary([z_sum, d_sum, d_loss_real_sum, d_loss_sum])

    # 生成器和判别器要更新的变量,用于 tf.train.Optimizer 的 var_list
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    saver = tf.train.Saver()
    
    # 优化算法采用 Adam
    d_optim = tf.train.AdamOptimizer(0.0002, beta1 = 0.5) \
                .minimize(d_loss, var_list = d_vars, global_step = global_step)
    g_optim = tf.train.AdamOptimizer(0.0002, beta1 = 0.5) \
                .minimize(g_loss, var_list = g_vars, global_step = global_step)
        
    
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.InteractiveSession(cOnfig=config)

    init = tf.initialize_all_variables()   
    writer = tf.train.SummaryWriter(train_dir, sess.graph)
    
    # 这个自己理解吧
    data_x, data_y = read_data()
    sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
#    sample_images = data_x[0: 64]
    sample_labels = data_y[0: 64]
    sess.run(init)    
    
    # 循环 25 个 epoch 训练网络
    for epoch in range(25):
        batch_idxs = 1093
        for idx in range(batch_idxs):        
            batch_images = data_x[idx*64: (idx+1)*64]
            batch_labels = data_y[idx*64: (idx+1)*64]
            batch_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))            
            
            # 更新 D 的参数
            _, summary_str = sess.run([d_optim, d_sum], 
                                      feed_dict = {images: batch_images, 
                                                   z: batch_z, 
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)

            # 更新 G 的参数
            _, summary_str = sess.run([g_optim, g_sum], 
                                      feed_dict = {z: batch_z, 
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)

            # 更新两次 G 的参数确保网络的稳定
            _, summary_str = sess.run([g_optim, g_sum], 
                                      feed_dict = {z: batch_z,
                                                   y: batch_labels})
            writer.add_summary(summary_str, idx+1)
            
            # 计算训练过程中的损失,打印出来
            errD_fake = d_loss_fake.eval({z: batch_z, y: batch_labels})
            errD_real = d_loss_real.eval({images: batch_images, y: batch_labels})
            errG = g_loss.eval({z: batch_z, y: batch_labels})

            if idx % 20 == 0:
                print("Epoch: [%2d] [%4d/%4d] d_loss: %.8f, g_loss: %.8f" \
                        % (epoch, idx, batch_idxs, errD_fake+errD_real, errG))
            
            # 训练过程中,用采样器采样,并且保存采样的图片到 
            # /home/your_name/TensorFlow/DCGAN/samples/
            if idx % 100 == 1:
                sample = sess.run(samples, feed_dict = {z: sample_z, y: sample_labels})
                samples_path = '/home/your_name/TensorFlow/DCGAN/samples/'
                save_images(sample, [8, 8], 
                            samples_path + 'test_%d_epoch_%d.png' % (epoch, idx))
                print 'save down'
            
            # 每过 500 次迭代,保存一次模型
            if idx % 500 == 2:
                checkpoint_path = os.path.join(train_dir, 'DCGAN_model.ckpt')
                saver.save(sess, checkpoint_path, global_step = idx+1)
                
    sess.close()


if __name__ == '__main__':
    train()    

 输入完成后点击运行,运行过程中,可以看到,生成的每个图片对应行对应列都是一样的数字,这是因为我们加了条件约束;采样器 sampler 采样的图片被保存在 samples 文件夹下,由模糊到清晰,由刚开始的噪声,慢慢变成手写字符,最后完全区分不出来是生成图片还是真实图片,反正我是区分不出来,you can you up。

 

不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN 不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

 与此同时,要是在训练的时候打开 TensorBoard,可以看到 D 的分布,大致在趋于 0.5 左右的附件徘徊,说明判别器 D 已经趋于判别不出来了,只能随机猜测,正确率大致 0.5。

 

不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

 

 

 

 

 

 

 

 

 

 

 

 

 

讲道理,我们的 GAN 到这一步,已经算是完成了,测试的过程,我们已经在训练的时候通过采样完成了,如果嫌不够,非要单独写个测试的文件,也不是不可以:

在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 eval.py 和文件夹 eval,eval 文件夹用来保存测试结果图片,在 eval.py 中输入如下代码:

 

# -*- coding: utf-8 -*-
import tensorflow as tf
import os

from read_data import *
from utils import *
from ops import *
from model import *
from model import BATCH_SIZE


def eval():
    # 用于存放测试图片
    test_dir = '/home/your_name/TensorFlow/DCGAN/eval/'
    # 从此处加载模型
    checkpoint_dir = '/home/your_name/TensorFlow/DCGAN/logs/'
    
    y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y')
    z = tf.placeholder(tf.float32, [None, 100], name='z')
    
    G = generator(z, y)    
    data_x, data_y = read_data()
    sample_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
    sample_labels = data_y[120: 184]
    
    # 读取 ckpt 需要 sess,saver
    print("Reading checkpoints...")
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    
    # saver
    saver = tf.train.Saver(tf.all_variables())
    
    # sess
    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.InteractiveSession(cOnfig=config)
    
    # 从保存的模型中恢复变量
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)        
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    
    # 用恢复的变量进行生成器的测试
    test_sess = sess.run(G, feed_dict = {z: sample_z, y: sample_labels})
    
    # 保存测试的生成器图片到特定文件夹
    save_images(test_sess, [8, 8], test_dir + 'test_%d.png' % 500)
    
    sess.close()


if  __name__ == '__main__':

    eval()    

 点击运行,在 eval 文件夹下生成test_500.png 文件,可以看到,生成器 G 已经可以生成不错的结果。

 

训练测试完,可以打开 TensorBoard 查看网络的 Graph,可以看到,由于没有细致采用 namespace 和 variable_scope ,画出来的 Graph 比较凌乱,只能依稀的看出来网络的一些结构。

不要怂,就是GAN (生成式对抗网络) (四):训练和测试 GAN

 

至此,我们的 TensorFlow GAN 工作基本完成,细心的朋友会发现,我们的程序存在以下几个问题:

1)在写 eval() 函数的时候,对于生成函数 generator(),没有指定 train = False,也就是在 BN 层,没有体现出训练和测试的区别;

2)在我的这篇 http://www.cnblogs.com/Charles-Wan/p/6197019.html 博客中,提到了我采用了 tfrecords 进行 GAN 数据的输入处理,但是此程序并没有体现出来;

3)没有细致的采用 namespace 和 variable_scope ,画出来的 Graph 比较凌乱;

4)程序中太多不明含义的数字,路径名字全都采用绝对路径;

5)训练过程中不能断点续训练等。

针对以上问题,我们在下一节的不加约束 GAN 上将进行改进。

 

 

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

 


推荐阅读
  • 树莓派语音控制的配置方法和步骤
    本文介绍了在树莓派上实现语音控制的配置方法和步骤。首先感谢博主Eoman的帮助,文章参考了他的内容。树莓派的配置需要通过sudo raspi-config进行,然后使用Eoman的控制方法,即安装wiringPi库并编写控制引脚的脚本。具体的安装步骤和脚本编写方法在文章中详细介绍。 ... [详细]
  • 花瓣|目标值_Compose 动画边学边做夏日彩虹
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了Compose动画边学边做-夏日彩虹相关的知识,希望对你有一定的参考价值。引言Comp ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 本文介绍了在git中如何对指定的commit id打标签,并解决了忘记打标签的问题。通过查找历史提交的commit id,可以在任意时间点打上标签。同时,还介绍了git中的一些常用命令和操作。 ... [详细]
  • iOS超签签名服务器搭建及其优劣势
    本文介绍了搭建iOS超签签名服务器的原因和优势,包括不掉签、用户可以直接安装不需要信任、体验好等。同时也提到了超签的劣势,即一个证书只能安装100个,成本较高。文章还详细介绍了超签的实现原理,包括用户请求服务器安装mobileconfig文件、服务器调用苹果接口添加udid等步骤。最后,还提到了生成mobileconfig文件和导出AppleWorldwideDeveloperRelationsCertificationAuthority证书的方法。 ... [详细]
  • 本文介绍了使用Python编写购物程序的实现步骤和代码示例。程序启动后,用户需要输入工资,并打印商品列表。用户可以根据商品编号选择购买商品,程序会检测余额是否充足,如果充足则直接扣款,否则提醒用户。用户可以随时退出程序,在退出时打印已购买商品的数量和余额。附带了完整的代码示例。 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 图片复制到服务器 方向变了_双服务器热备更新配置文件步骤问题及解决方法
    本文介绍了在将图片复制到服务器并进行方向变换的过程中,双服务器热备更新配置文件所出现的问题及解决方法。通过停止所有服务、更新配置、重启服务等操作,可以避免数据中断和操作不规范导致的问题。同时还提到了注意事项,如Avimet版本的差异以及配置文件和批处理文件的存放路径等。通过严格执行切换步骤,可以成功进行更新操作。 ... [详细]
  • 本文讨论了Kotlin中扩展函数的一些惯用用法以及其合理性。作者认为在某些情况下,定义扩展函数没有意义,但官方的编码约定支持这种方式。文章还介绍了在类之外定义扩展函数的具体用法,并讨论了避免使用扩展函数的边缘情况。作者提出了对于扩展函数的合理性的质疑,并给出了自己的反驳。最后,文章强调了在编写Kotlin代码时可以自由地使用扩展函数的重要性。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文介绍了在CentOS上安装Python2.7.2的详细步骤,包括下载、解压、编译和安装等操作。同时提供了一些注意事项,以及测试安装是否成功的方法。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • Oracle优化新常态的五大禁止及其性能隐患
    本文介绍了Oracle优化新常态中的五大禁止措施,包括禁止外键、禁止视图、禁止触发器、禁止存储过程和禁止JOB,并分析了这些禁止措施可能带来的性能隐患。文章还讨论了这些禁止措施在C/S架构和B/S架构中的不同应用情况,并提出了解决方案。 ... [详细]
author-avatar
你一句话就逼我撤退
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有