Hfunter 2018-07-19 04:58 采纳率: 50%
浏览 4872
已采纳

'Datasets' object has no attribute 'train_step'

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATAH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 50000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"

def backward(mnist):

x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, REGULARIZER)
global_step = tf.Variable(0, trainable=False)

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.arg_max(y_, 1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses'))

learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATAH_SIZE,
                                           LEARNING_RATE_DECAY, staircase=True)

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')

saver = tf.train.Saver()

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    for i in range(STEPS):
        xs, ys = mnist.train_step.next_batch(BATAH_SIZE)
        _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
        if i % 1000 == 0:
            print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

def main():
mnist = input_data.read_data_sets("./data/", one_hot=True)
backward(mnist)

if name == '__main__':
main()

运行程序后报错:

File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 54, in
main()
File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 51, in main
backward(mnist)
File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 43, in backward
xs, ys = mnist.train_step.next_batch(BATAH_SIZE)
AttributeError: 'Datasets' object has no attribute 'train_step'

展开全部

  • 写回答

2条回答 默认 最新

  • allenjan1988 2018-07-19 05:46
    关注

    请把xs, ys = mnist.train_step.next_batch(BATAH_SIZE)修改为xs, ys = mnist.train.next_batch(BATAH_SIZE),就可以正常运行了

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)
编辑
预览

报告相同问题?

悬赏问题

  • ¥15 KeiI中头文件找不到怎么解决
  • ¥15 QT6将音频采样数据转PCM
  • ¥15 本地安装org.Hs.eg.dby一直这样的图片报错如何解决?
  • ¥15 下面三个文件分别是OFDM波形的数据,我的思路公式和我写的成像算法代码,有没有人能帮我改一改,如何解决?
  • ¥15 Ubuntu打开gazebo模型调不出来,如何解决?
  • ¥100 有chang请一位会arm和dsp的朋友解读一个工程
  • ¥50 求代做一个阿里云百炼的小实验
  • ¥15 查询优化:A表100000行,B表2000 行,内存页大小只有20页,运行时3页,设计两个表等值连接的最简单的算法
  • ¥15 led数码显示控制(标签-流程图)
  • ¥20 为什么在复位后出现错误帧
手机看
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回
顶部