ali_wangli 2017-08-25 02:45 采纳率: 100%

# 用tensorflow写一个简单的神经网络识别mnist出现问题（python）

`````` #放入每个批次的数量
batch_size = 200
#计算有多少批次
n_batch = mnist.train.num_examples // batch_size

#定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#创建神经网络中间层
W1 = tf.Variable(tf.zeros([784,100]))
b1 = tf.Variable(tf.zeros([784,100]))
L1 = tf.nn.sigmoid(tf.matmul(x,W1) + b1)
#创建神经网络输出层
W = tf.Variable(tf.zeros([100,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(L1,W) + b)

#定义二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))

#梯度下降

init = tf.global_variables_initializer()

#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(n_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict = {x:batch_xs,y:batch_ys})

acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("iter " + str(epoch) + "testing accuracy " + str(acc))
``````

InvalidArgumentError Traceback (most recent call last)
F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1326 try:
-> 1327 return fn(*args)
1328 except errors.OpError as e:

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1305 feed_dict, fetch_list, target_list,
1307

F:\Program Files\Anaconda\lib\contextlib.py in exit(self, type, value, traceback)
65 try:
---> 66 next(self.gen)
67 except StopIteration:

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status()
465 compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 466 pywrap_tensorflow.TF_GetCode(status))
467 finally:

InvalidArgumentError: Incompatible shapes: [200,10] vs. [784,10]

During handling of the above exception, another exception occurred:

InvalidArgumentError Traceback (most recent call last)
in ()
35 for batch in range(n_batch):
36 batch_xs,batch_ys = mnist.train.next_batch(batch_size)
---> 37 sess.run(train_step,feed_dict = {x:batch_xs,y:batch_ys})
38
39 acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
893 try:
894 result = self._run(None, fetches, feed_dict, options_ptr,

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1122 if final_fetches or final_targets or (handle and feed_dict_tensor):
1123 results = self._do_run(handle, final_targets, final_fetches,
1125 else:
1126 results = []

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1319 if handle is None:
1320 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
1322 else:
1323 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

F:\Program Files\Anaconda\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1338 except KeyError:
1339 pass
-> 1340 raise type(e)(node_def, op, message)
1341
1342 def _extend_graph(self):

InvalidArgumentError: Incompatible shapes: [200,10] vs. [784,10]

• 写回答

#### 1条回答默认 最新

• silent56_th 2017-08-25 12:48
关注

b1 = tf.Variable(tf.zeros([784,100]))改成b1 = tf.Variable(tf.zeros([100,]))
应该可以解决这个报错

本回答被题主选为最佳回答 , 对您是否有帮助呢?
评论

#### 悬赏问题

• ¥15 vue引入sdk后的回调问题
• ¥15 求一个智能家居控制的代码
• ¥15 ad软件 pcb布线pcb规则约束编辑器where the object matpcb布线pcb规则约束编辑器where the object matchs怎么没有+15v只有no net
• ¥15 虚拟机vmnet8 nat模式可以ping通主机，主机也能ping通虚拟机，但是vmnet8一直未识别怎么解决，其次诊断结果就是默认网关不可用
• ¥20 求各位能用我能理解的话回答超级简单的一些问题
• ¥15 yolov5双目识别输出坐标代码报错
• ¥15 这个代码有什么语法错误
• ¥15 给予STM32按键中断与串口通信
• ¥15 使用QT实现can通信
• ¥15 关于sp验证的一些东西，求告知如何解决，