Hfunter
2018-07-19 12:58
采纳率: 50%
浏览 4.1k

'Datasets' object has no attribute 'train_step'

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

# Training hyper-parameters and checkpoint settings used by backward().
# Mini-batch size passed to next_batch(); also used to derive the decay period
# (the "BATAH" spelling is kept because backward() references this exact name).
BATAH_SIZE = 200
# Initial learning rate fed to tf.train.exponential_decay.
LEARNING_RATE_BASE = 0.1
# Multiplicative decay rate fed to tf.train.exponential_decay.
LEARNING_RATE_DECAY = 0.99
# Passed to mnist_forward.forward(); presumably the regularization weight whose
# terms end up in the 'losses' collection summed into the loss — confirm in mnist_forward.
REGULARIZER = 0.0001
# Number of training iterations (mini-batches) to run.
STEPS = 50000
# Decay used by tf.train.ExponentialMovingAverage over the trainable variables.
MOVING_AVERAGE_DECAY = 0.99
# Directory where checkpoints are written.
MODEL_SAVE_PATH = "./model/"
# Checkpoint file-name prefix inside MODEL_SAVE_PATH.
MODEL_NAME = "mnist_model"

def backward(mnist):
    """Build the MNIST training graph and run mini-batch gradient descent.

    The network comes from ``mnist_forward.forward``. This function adds:
    the mean cross-entropy loss plus everything in the ``'losses'``
    collection, an exponentially decaying learning rate, and an exponential
    moving average of all trainable variables. It then trains for ``STEPS``
    iterations, printing the batch loss and saving a checkpoint every
    1000 iterations.

    Args:
        mnist: The object returned by ``input_data.read_data_sets``; its
            ``train`` split supplies ``num_examples`` and ``next_batch``.
    """
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    y = mnist_forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # y_ is fed one-hot labels (main() reads the data with one_hot=True), but the
    # sparse cross-entropy op expects class indices, so take the argmax per row.
    # tf.argmax replaces the deprecated tf.arg_max alias.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once every num_examples / BATAH_SIZE steps
    # (one pass over the training set); staircase=True makes the decay stepwise.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATAH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # Group the gradient step and the moving-average update into one op to run.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    # TF 1.x Saver.save raises if the checkpoint's parent directory does not
    # exist, so create it before training starts.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(STEPS):
            # The training split is exposed as `mnist.train`; `mnist.train_step`
            # does not exist and raised the AttributeError.
            xs, ys = mnist.train.next_batch(BATAH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

def main():
    """Read the MNIST data sets from ./data/ with one-hot labels and train on them."""
    backward(input_data.read_data_sets("./data/", one_hot=True))

# Run training only when this file is executed as a script, not when imported.
# The original `name` (underscores lost) raised NameError; it must be __name__.
if __name__ == '__main__':
    main()

运行程序后报错:

File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 54, in <module>
main()
File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 51, in main
backward(mnist)
File "C:/Users/98382/PycharmProjects/minst/mnist_backward.py", line 43, in backward
xs, ys = mnist.train_step.next_batch(BATAH_SIZE)
AttributeError: 'Datasets' object has no attribute 'train_step'

  • 写回答
  • 好问题 提建议
  • 关注问题
  • 收藏
  • 邀请回答

2条回答 默认 最新

  • allenjan1988 2018-07-19 13:46
    已采纳

    请把xs, ys = mnist.train_step.next_batch(BATAH_SIZE)修改为xs, ys = mnist.train.next_batch(BATAH_SIZE),就可以正常运行了

    已采纳该答案
    评论
    解决 4 无用
    打赏 举报
  • licgpolo 2018-07-22 07:16

    可以查看 mnist 对象的属性：它并没有 train_step 属性，应改用带有 next_batch() 方法的 mnist.train（即 mnist.train.next_batch(BATAH_SIZE)），程序就可以正常运行了。

    评论
    解决 无用
    打赏 举报

相关推荐 更多相似问题