weixin_45578223 2019-10-04 19:54 采纳率: 0%
浏览 2331

运行结果如下:执行 `train(generator, discriminator, gan_model, latent_dim)` 时报错 `NameError: name 'train' is not defined`,请问如何解决?

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from numpy import hstack
from numpy import zeros
from numpy import ones
from numpy.random import rand
from numpy.random import randn
from keras.models import Sequential
from keras.layers import Dense
def define_discriminator(n_inputs=2):
    """Build and compile the standalone discriminator model.

    The network is one hidden layer of 25 ReLU units (He-uniform
    initialised) followed by a single sigmoid output unit. It is compiled
    with binary cross-entropy loss, the Adam optimizer and an accuracy
    metric.

    Args:
        n_inputs: Number of features in each input sample.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(
        Dense(
            25,
            activation='relu',
            kernel_initializer='he_uniform',
            input_dim=n_inputs,
        )
    )
    model.add(Dense(1, activation='sigmoid'))
    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return model
def define_generator(latent_dim, n_outputs=2):
    """Build the standalone generator model (left uncompiled).

    The network is one hidden layer of 15 ReLU units (He-uniform
    initialised) followed by a linear output layer. This function does not
    call ``compile`` on it.

    Args:
        latent_dim: Length of each latent input vector.
        n_outputs: Number of values produced per generated sample.

    Returns:
        The uncompiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(
        Dense(
            15,
            activation='relu',
            kernel_initializer='he_uniform',
            input_dim=latent_dim,
        )
    )
    model.add(Dense(n_outputs, activation='linear'))
    return model
def define_gan(generator, discriminator):
    """Stack the generator and discriminator into one compiled model.

    Note that this mutates the ``discriminator`` argument: its ``trainable``
    attribute is set to ``False`` before the combined model is compiled.
    The apparent intent is that updates through the combined model only
    change the generator's weights.

    Args:
        generator: Model mapping latent points to samples.
        discriminator: Model classifying samples as real or fake.

    Returns:
        A ``Sequential`` model (generator followed by discriminator)
        compiled with binary cross-entropy loss and the Adam optimizer.
    """
    # Must happen before compile() of the combined model below.
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model
def generate_real_samples(n):
    """Draw ``n`` real samples lying on the curve ``x2 = x1 ** 2``.

    ``x1`` is drawn uniformly from ``[-0.5, 0.5)`` and ``x2`` is its square.

    Args:
        n: Number of samples to generate.

    Returns:
        A tuple ``(x, y)`` where ``x`` has shape ``(n, 2)`` with columns
        ``x1`` and ``x2``, and ``y`` is an ``(n, 1)`` array of ones (the
        "real" class label).
    """
    # rand() is uniform on [0, 1); shift it to be centred on zero.
    x1 = rand(n) - 0.5
    x2 = x1 * x1
    x = hstack((x1.reshape(n, 1), x2.reshape(n, 1)))
    y = ones((n, 1))
    return x, y
def generate_latent_points(latent_dim, n):
    """Sample ``n`` latent vectors from a standard normal distribution.

    Args:
        latent_dim: Length of each latent vector.
        n: Number of latent vectors to draw.

    Returns:
        An array of shape ``(n, latent_dim)`` of standard-normal values.
    """
    # Draw everything as one flat vector, then fold it into one row per
    # sample.
    x_input = randn(latent_dim * n)
    return x_input.reshape(n, latent_dim)
def generate_fake_samples(generator, latent_dim, n):
    """Produce ``n`` generator outputs labelled as fake.

    Args:
        generator: Model whose ``predict`` maps latent points to samples.
        latent_dim: Length of each latent vector fed to the generator.
        n: Number of samples to generate.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the output of
        ``generator.predict`` on ``n`` fresh latent points and ``y`` is an
        ``(n, 1)`` array of zeros (the "fake" class label).
    """
    x_input = generate_latent_points(latent_dim, n)
    x = generator.predict(x_input)
    y = zeros((n, 1))
    return x, y
def summarize_performance(epoch, generator, discriminator, latent_dim, n=100):
    """Print discriminator accuracy and plot real against generated samples.

    Evaluates the discriminator on ``n`` real and ``n`` generated samples,
    prints ``epoch`` with both accuracies, then shows a scatter plot with
    real points in red and generated points in blue.

    Args:
        epoch: Value printed as the epoch identifier; it is not otherwise
            used.
        generator: Model used to create the fake samples.
        discriminator: Compiled model whose ``evaluate`` returns
            ``(loss, accuracy)``.
        latent_dim: Length of each latent vector fed to the generator.
        n: Number of real and of fake samples to evaluate and plot.
    """
    x_real, y_real = generate_real_samples(n)
    _, acc_real = discriminator.evaluate(x_real, y_real, verbose=0)

    x_fake, y_fake = generate_fake_samples(generator, latent_dim, n)
    _, acc_fake = discriminator.evaluate(x_fake, y_fake, verbose=0)

    print(epoch, acc_real, acc_fake)

    plt.scatter(x_real[:, 0], x_real[:, 1], color='red')
    plt.scatter(x_fake[:, 0], x_fake[:, 1], color='blue')
    plt.show()
def train(g_model, d_model, gan_model, latent_dim, n_epochs=10000, n_batch=128, n_eval=2000):
    """Alternately train the discriminator and the generator.

    Each epoch updates the discriminator on half a batch of real samples
    and half a batch of generated samples, then updates the generator
    through the combined model using a full batch of latent points labelled
    as real.

    Args:
        g_model: Generator model.
        d_model: Discriminator model.
        gan_model: Combined generator-plus-discriminator model.
        latent_dim: Length of each latent vector fed to the generator.
        n_epochs: Number of training iterations.
        n_batch: Batch size for the generator update; the discriminator
            sees half of it per class.
        n_eval: Call ``summarize_performance`` every ``n_eval`` epochs.
            A falsy value (``0`` or ``None``) turns periodic evaluation
            off. The final epoch is always evaluated.
    """
    half_batch = int(n_batch / 2)
    for i in range(n_epochs):
        # Discriminator update: one real half-batch, one fake half-batch.
        x_real, y_real = generate_real_samples(half_batch)
        x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_model.train_on_batch(x_real, y_real)
        d_model.train_on_batch(x_fake, y_fake)

        # Generator update: latent points are labelled 1 ("real") so the
        # loss rewards samples the discriminator accepts.
        x_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = ones((n_batch, 1))
        gan_model.train_on_batch(x_gan, y_gan)

        # The original tested (i + 1) % n_epochs, which ignored n_eval and
        # fired only on the last epoch. Use n_eval for periodic reports and
        # keep the final-epoch report that callers already got.
        is_eval_epoch = bool(n_eval) and (i + 1) % n_eval == 0
        is_last_epoch = i + 1 == n_epochs
        if is_eval_epoch or is_last_epoch:
            summarize_performance(i, g_model, d_model, latent_dim)
# Length of the random noise vector fed to the generator.
latent_dim=5
# Build the discriminator first: define_gan() below sets its `trainable`
# attribute to False, so the order of these calls matters.
discriminator=define_discriminator()
generator=define_generator(latent_dim)
gan_model=define_gan(generator,discriminator)
# NOTE: `train` must already be bound at module level when this line runs.
# If `def train` ends up indented inside another function's body, the name
# is never defined here and this call raises
# "NameError: name 'train' is not defined".
train(generator,discriminator,gan_model,latent_dim)

问题

  • 写回答

1条回答

  • dabocaiqq 2019-10-04 23:40
    关注
    评论

报告相同问题?

悬赏问题

  • ¥100 求三轴之间相互配合画圆以及直线的算法
  • ¥100 c语言,请帮蒟蒻写一个题的范例作参考
  • ¥15 名为“Product”的列已属于此 DataTable
  • ¥15 安卓adb backup备份应用数据失败
  • ¥15 eclipse运行项目时遇到的问题
  • ¥15 关于#c##的问题:最近需要用CAT工具Trados进行一些开发
  • ¥15 南大pa1 小游戏没有界面,并且报了如下错误,尝试过换显卡驱动,但是好像不行
  • ¥15 自己瞎改改,结果现在又运行不了了
  • ¥15 链式存储应该如何解决
  • ¥15 没有证书,nginx怎么反向代理到只能接受https的公网网站