ali_wangli
ali_wangli
采纳率100%
2017-08-25 02:45

用tensorflow写一个简单的神经网络识别mnist出现问题(python)

 #放入每个批次的数量
batch_size = 200
#计算有多少批次
n_batch = mnist.train.num_examples // batch_size

#定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#创建神经网络中间层
W1 = tf.Variable(tf.zeros([784,100]))
b1 = tf.Variable(tf.zeros([784,100]))
L1 = tf.nn.sigmoid(tf.matmul(x,W1) + b1)
#创建神经网络输出层
W = tf.Variable(tf.zeros([100,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(L1,W) + b)

#定义二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))

#梯度下降
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()

#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict = {x:batch_xs,y:batch_ys})

        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
    print("iter " + str(epoch) + "testing accuracy " + str(acc))

本来没有隐层,想加一个隐层,然后哦出现各种问题,然后有一个报错不知道怎么解决,求教,谢谢!
InvalidArgumentError Traceback (most recent call last)
F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1326 try:
-> 1327 return fn(*args)
1328 except errors.OpError as e:

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1305 feed_dict, fetch_list, target_list,
-> 1306 status, run_metadata)
1307

F:\Program Files\Anaconda\lib\contextlib.py in exit(self, type, value, traceback)
65 try:
---> 66 next(self.gen)
67 except StopIteration:

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status()
465 compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 466 pywrap_tensorflow.TF_GetCode(status))
467 finally:

InvalidArgumentError: Incompatible shapes: [200,10] vs. [784,10]
[[Node: gradients_2/sub_2_grad/BroadcastGradientArgs = BroadcastGradientArgsT=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"]]

During handling of the above exception, another exception occurred:

InvalidArgumentError Traceback (most recent call last)
in ()
35 for batch in range(n_batch):
36 batch_xs,batch_ys = mnist.train.next_batch(batch_size)
---> 37 sess.run(train_step,feed_dict = {x:batch_xs,y:batch_ys})
38
39 acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
893 try:
894 result = self._run(None, fetches, feed_dict, options_ptr,
--> 895 run_metadata_ptr)
896 if run_metadata:
897 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1122 if final_fetches or final_targets or (handle and feed_dict_tensor):
1123 results = self._do_run(handle, final_targets, final_fetches,
-> 1124 feed_dict_tensor, options, run_metadata)
1125 else:
1126 results = []

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1319 if handle is None:
1320 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1321 options, run_metadata)
1322 else:
1323 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1338 except KeyError:
1339 pass
-> 1340 raise type(e)(node_def, op, message)
1341
1342 def _extend_graph(self):

InvalidArgumentError: Incompatible shapes: [200,10] vs. [784,10]

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 复制链接分享
  • 邀请回答

1条回答

  • silent56_th silent56_th 4年前

    b1 = tf.Variable(tf.zeros([784,100]))改成b1 = tf.Variable(tf.zeros([100,]))
    应该可以解决这个报错

    点赞 2 评论 复制链接分享