在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同時(shí)新建文件夾 logs 和文件夾 samples,前者用來保存訓(xùn)練過程中的日志和模型,后者用來保存訓(xùn)練過程中采樣器的采樣圖片,在 train.py 中輸入如下代碼:
# -*- coding: utf-8 -*-import tensorflow as tfimport osfrom read_data import *from utils import *from ops import *from model import *from model import BATCH_SIZEdef train(): # 設(shè)置 global_step ,用來記錄訓(xùn)練過程中的 step global_step = tf.Variable(0, name = 'global_step', trainable = False) # 訓(xùn)練過程中的日志保存文件 train_dir = '/home/your_name/TensorFlow/DCGAN/logs' # 放置三個(gè) placeholder,y 表示約束條件,images 表示送入判別器的圖片, # z 表示隨機(jī)噪聲 y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y') images = tf.placeholder(tf.float32, [64, 28, 28, 1], name='real_images') z = tf.placeholder(tf.float32, [None, 100], name='z') # 由生成器生成圖像 G G = generator(z, y) # 真實(shí)圖像送入判別器 D, D_logits = discriminator(images, y) # 采樣器采樣圖像 samples = sampler(z, y) # 生成圖像送入判別器 D_, D_logits_ = discriminator(G, y, reuse = True)