我找了很多教程,但没有一个教程适用于我这一种情况。我曾尝试看自己的x和batch的类型,发现他们并不是None类型的数据。我怀疑是由于tf.session.run()中的fetches是一个tf.Operation,它会返回None造成的,但我不知道怎么改。请问有遇到这种情况的吗?可以和我说一下您当时是怎么解决的吗?
报错代码片段如下:
片段1:
for epoch in t_bar:
supervised_g_losses = []
self.gen_loader.reset_pointer()
for it in range(self.gen_loader.num_batch):
batch = self.gen_loader.next_batch()
_, g_loss, g_pred = self.generator.pretrain_step(self.sess,
batch)
supervised_g_losses.append(g_loss)
# print results
mean_g_loss = np.mean(supervised_g_losses)
t_bar.set_postfix(G_loss=mean_g_loss)
samples = self.generate_samples(self.SAMPLE_NUM)
self.mle_loader.create_batches(samples)
片段2:
def pretrain_step(self, session, x):
"""Performs a pretraining step on the generator."""
outputs = session.run([self.pretrain_updates, self.pretrain_loss,
self.g_predictions], feed_dict={self.x: x})
return outputs