python编程设计GAN生成对抗网络报错,用的tensorflow框架,mnist数据集。
代码如下:
#导入相关的库
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
#import tensorflow as tf
import keras
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Input,Dense,Dropout,Activation,Flatten,Conv2D,Conv2DTranspose,BatchNormalization,LeakyReLU,Conv1D, GRU
from tensorflow.keras.optimizers import Adam,RMSprop
import numpy as np
import matplotlib.pyplot as plt
from keras.utils.np_utils import to_categorical
#import tensorflow as tf
#from tensorflow.keras.layers import Dense, Conv2D, Flatten, Conv1D, GRU
from tensorflow.keras.models import Sequential
#记录开始时间
import time
start =time.time()
#导入数据并预处理,这里直接使用MNIST数据集
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784)
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255
#设置噪声值
z_dim = np.array([None,100])
f_dim=tf.reshape((2,1,1))
#设置优化器
adam = Adam(lr=0.0002,beta_1=0.5)
#搭建生成器模型
g = Sequential()
g.add(GRU(256,input_shape=f_dim,activation='relu'))
g.add(Dense(512,activation='relu'))
g.add(Dense(1024,activation='relu'))
g.add(Dense(2048,activation='relu'))
g.add(Dense(784,activation='sigmoid'))
g.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
#搭建鉴别器模型,这里先把D模型设置为不可训练,因为我们需要先训练G
d = Sequential()
d.add(Conv1D(1024,kernel_size=3, strides=2, padding='same', activation='relu'))
d.add(Dropout(0.3))
d.add(Conv1D(512,kernel_size=3, strides=2, padding='same', activation='relu'))
d.add(Dropout(0.3))
d.add(Conv1D(256,kernel_size=3, strides=2, padding='same', activation='relu'))
d.add(Dropout(0.3))
d.add(Conv1D(64,kernel_size=3, strides=2, padding='same', activation='relu'))
d.add(Dropout(0.3))
d.add(Dense(1,activation='sigmoid'))
d.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
d.trainable = False
#把两个模型连接起来组成生成对抗网络
inputs = Input(shape=(None,100))
hidden = g(inputs)
output = d(hidden)
gan = Model(inputs,output)
gan.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
#写两个函数用于最后输出损失值和生成器生成的图片
def plot_loss(losses):
d_loss = [v[0] for v in losses["D"]]
g_loss = [v[0] for v in losses["G"]]
plt.figure(figsize=(10,8))
plt.plot(d_loss,label="Discriminator_loss")
plt.plot(g_loss,label="Generator_loss")
plt.legend()
plt.show()
def plot_generatored(n_ex=10,dim=(1,10),figsize=(12,2)):
#####
#np.random.seed(time.time())
noise = np.random.normal(0,1,size=(n_ex,(None,100)))
generatored_images = g.predict(noise)
generatored_images = generatored_images.reshape(n_ex,28,28)
plt.figure(figsize = figsize)
for i in range(generatored_images.shape[0]):
plt.subplot(dim[0],dim[1],i+1)
plt.imshow(generatored_images[i],interpolation='nearest',cmap='gray_r')
plt.axis('off')
plt.tight_layout()
plt.show()
#设置一个字典用于保存损失值
losses = {"D":[],"G":[]}
#下面是train函数
def train(epochs=1,plt_frq=1,BATCH_SIZE=128):
batchCount = int(x_train.shape[0]/BATCH_SIZE)
print("Epochs:",epochs)
print("Batch size:",BATCH_SIZE)
print("Batches per epoch:",batchCount)
for e in range(1,epochs+1):
if e == 1 or e%plt_frq == 0:
print('-'*15,'Epoch %d' %e,'-'*15)
for _ in range(batchCount):
image_batch = x_train[np.random.randint(0,x_train.shape[0],size=BATCH_SIZE)]
########
#np.random.seed(time.time())
noise = np.random.normal(0,1,size=(BATCH_SIZE,(None,100)))
generatored_images = g.predict(noise)
#train d
#set data set which is composed of 2 parts
x = np.concatenate((image_batch, generatored_images))
#y are labels
y = np.zeros(2*BATCH_SIZE)
y[:BATCH_SIZE] = 0.9
d.trainable = True
d_loss = d.train_on_batch(x,y)
#train g
#set up data set
noise = np.random.normal(0,1,size=(BATCH_SIZE,(None,100)))
y2 = np.ones(BATCH_SIZE)
d.trainable = False
g_loss = gan.train_on_batch(noise,y2)
losses["D"].append(d_loss)
losses["G"].append(g_loss)
if e==1 or e%plt_frq==0:
plot_generatored()
plot_loss(losses)
#开始训练
train(200,20,128)
#记录结束时间,并输出所用时间
end=time.time()
print('Running time: %s Seconds'%(end-start))
想用tf.reshape函数二维转三维,但是报错了。
D:\Software\Anaconda\envs\test02\python.exe E:/PythonProject/test02/gan2.py
Using TensorFlow backend.
Traceback (most recent call last):
File "E:/PythonProject/test02/gan2.py", line 49, in
f_dim=tf.reshape((2,1,1))
TypeError: reshape() missing 1 required positional argument: 'shape'
进程已结束,退出代码 1
截图如下: