weixin_45578223 2019-10-04 19:54 采纳率: 0%
浏览 2331

运行结果如下：执行 `train(generator, discriminator, gan_model, latent_dim)` 时报错 `NameError: name 'train' is not defined`，请问如何解决？

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from numpy import hstack
from numpy import zeros
from numpy import ones
from numpy.random import rand
from numpy.random import randn
from keras.models import Sequential
from keras.layers import Dense
def define_discriminator(n_inputs=2):
    """Build and compile the discriminator network.

    The network maps one ``n_inputs``-wide sample to a single sigmoid score.
    It is compiled with binary cross-entropy, the Adam optimizer and an
    accuracy metric.

    Args:
        n_inputs: Width of each input sample. The default of 2 matches the
            (x, x*x) points produced by generate_real_samples.

    Returns:
        A compiled Keras ``Sequential`` model.
    """
    net = Sequential([
        # One ReLU hidden layer with He-uniform initialisation.
        Dense(25, activation='relu', kernel_initializer='he_uniform',
              input_dim=n_inputs),
        # Sigmoid output: a score in (0, 1) for "real".
        Dense(1, activation='sigmoid'),
    ])
    net.compile(loss='binary_crossentropy', optimizer='adam',
                metrics=['accuracy'])
    return net
def define_generator(latent_dim, n_outputs=2):
    """Build the generator network (left uncompiled).

    The generator is not compiled here. Its weights are updated only
    through the combined model returned by define_gan.

    Args:
        latent_dim: Width of each latent input vector.
        n_outputs: Width of each generated sample (2 to match the real data).

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    net = Sequential([
        Dense(15, activation='relu', kernel_initializer='he_uniform',
              input_dim=latent_dim),
        # Linear activation, so generated coordinates are not squashed.
        Dense(n_outputs, activation='linear'),
    ])
    return net
def define_gan(generator, discriminator):
    """Chain the generator into the discriminator and compile the result.

    The discriminator is marked non-trainable before the combined model is
    compiled. Training the combined model is therefore intended to update
    only the generator.

    NOTE(review): this relies on ``discriminator`` having already been
    compiled on its own while trainable (as define_discriminator does), so
    its standalone train_on_batch calls still update it.

    Args:
        generator: Model mapping latent points to samples.
        discriminator: Model mapping samples to a real/fake score.

    Returns:
        The compiled combined ``Sequential`` model.
    """
    discriminator.trainable = False
    combined = Sequential([generator, discriminator])
    combined.compile(loss='binary_crossentropy', optimizer='adam')
    return combined
def generate_real_samples(n):
    """Draw ``n`` points on the parabola y = x*x, labelled as real.

    The x values are uniform in [-0.5, 0.5).

    Args:
        n: Number of samples to draw.

    Returns:
        A tuple ``(x, y)``. ``x`` has shape (n, 2) and holds the columns
        [x, x*x]. ``y`` has shape (n, 1) and is all ones.
    """
    xs = rand(n) - 0.5
    samples = hstack((xs.reshape(n, 1), (xs * xs).reshape(n, 1)))
    labels = ones((n, 1))
    return samples, labels
def generate_latent_points(latent_dim, n):
    """Sample ``n`` standard-normal latent vectors.

    Args:
        latent_dim: Width of each latent vector.
        n: Number of vectors.

    Returns:
        An array of shape (n, latent_dim).
    """
    return randn(latent_dim * n).reshape(n, latent_dim)
def generate_fake_samples(generator, latent_dim, n):
    """Run the generator on fresh latent points and label the output fake.

    Args:
        generator: Model whose ``predict`` maps latent points to samples.
        latent_dim: Width of each latent vector.
        n: Number of samples to generate.

    Returns:
        A tuple ``(x, y)``. ``x`` is the generator output and ``y`` is an
        (n, 1) array of zeros.
    """
    noise = generate_latent_points(latent_dim, n)
    fakes = generator.predict(noise)
    return fakes, zeros((n, 1))
def summarize_performance(epoch, generator, discriminator, latent_dim, n=100):
    """Report discriminator accuracy on real vs. fake data and plot both sets.

    Prints ``epoch``, the accuracy on real samples and the accuracy on fake
    samples. It then shows a scatter plot with real points in red and
    generated points in blue. ``plt.show()`` is called at the end.

    Args:
        epoch: Step index to print alongside the accuracies.
        generator: Generator model used to produce fake samples.
        discriminator: Compiled discriminator; ``evaluate`` is expected to
            return ``[loss, accuracy]``.
        latent_dim: Width of each latent vector.
        n: Number of real and of fake samples to evaluate.
    """
    real_x, real_y = generate_real_samples(n)
    _, real_acc = discriminator.evaluate(real_x, real_y, verbose=0)
    fake_x, fake_y = generate_fake_samples(generator, latent_dim, n)
    _, fake_acc = discriminator.evaluate(fake_x, fake_y, verbose=0)
    print(epoch, real_acc, fake_acc)
    for pts, colour in ((real_x, 'red'), (fake_x, 'blue')):
        plt.scatter(pts[:, 0], pts[:, 1], color=colour)
    plt.show()
def train(g_model, d_model, gan_model, latent_dim, n_epochs=10000, n_batch=128,
          n_eval=2000):
    """Alternate discriminator and generator updates for ``n_epochs`` steps.

    Each step does two things:
      1. Trains ``d_model`` on half a batch of real samples (label 1) and then
         on half a batch of generated samples (label 0).
      2. Trains ``gan_model`` on a full batch of latent points labelled 1.
         With the discriminator frozen inside ``gan_model`` (as define_gan
         does), this pushes the generator toward outputs the discriminator
         scores as real.

    Every ``n_eval`` steps, progress is reported via summarize_performance.

    Args:
        g_model: Generator model.
        d_model: Compiled, trainable discriminator model.
        gan_model: Compiled generator+discriminator stack.
        latent_dim: Width of each latent vector.
        n_epochs: Number of training steps.
        n_batch: Batch size for the generator update. The discriminator sees
            ``n_batch // 2`` real and ``n_batch // 2`` fake samples.
        n_eval: Reporting interval in steps. A value <= 0 disables reporting.
    """
    half_batch = n_batch // 2
    for i in range(n_epochs):
        # Discriminator update: separate real and fake half-batches.
        x_real, y_real = generate_real_samples(half_batch)
        x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_model.train_on_batch(x_real, y_real)
        d_model.train_on_batch(x_fake, y_fake)
        # Generator update: label fakes as real so the loss rewards fooling
        # the discriminator.
        x_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = ones((n_batch, 1))
        gan_model.train_on_batch(x_gan, y_gan)
        # Report every n_eval steps. Testing against n_epochs (as before)
        # ignored n_eval and reported only once, after the final step.
        if n_eval > 0 and (i + 1) % n_eval == 0:
            summarize_performance(i, g_model, d_model, latent_dim)
# Script entry: build the three models and run training.
# NOTE(review): "NameError: name 'train' is not defined" at the call below
# means `train` was not bound at module level when this line ran. Most
# likely its `def` had been indented into another function, or the call was
# placed above the definition. Every `def` above must start at column 0 and
# come before these lines.
latent_dim = 5
discriminator = define_discriminator()
generator = define_generator(latent_dim=latent_dim)
gan_model = define_gan(generator=generator, discriminator=discriminator)
train(g_model=generator, d_model=discriminator, gan_model=gan_model,
      latent_dim=latent_dim)

问题

  • 写回答

1条回答 默认 最新

  • dabocaiqq 2019-10-04 23:40
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
  • ¥200 uniapp长期运行卡死问题解决
  • ¥15 请教:如何用postman调用本地虚拟机区块链接上的合约?
  • ¥15 为什么使用javacv转封装rtsp为rtmp时出现如下问题:[h264 @ 000000004faf7500]no frame?
  • ¥15 乘性高斯噪声在深度学习网络中的应用
  • ¥15 关于docker部署flink集成hadoop的yarn,请教个问题 flink启动yarn-session.sh连不上hadoop,这个整了好几天一直不行,求帮忙看一下怎么解决
  • ¥15 深度学习根据CNN网络模型,搭建BP模型并训练MNIST数据集
  • ¥15 C++ 头文件/宏冲突问题解决
  • ¥15 用comsol模拟大气湍流通过底部加热(温度不同)的腔体
  • ¥50 安卓adb backup备份子用户应用数据失败