qq_34644971 2018-11-10 04:51 采纳率: 50%
浏览 1468

使用RNN进行手写数字识别,为什么正确率总是无法提高

我使用最简单RNN进行mnist手写数字的识别,为什么交叉商总是无法降低呢。完整代码如下。

import tensorflow as tf

from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/home/as/mnist_dataset', one_hot=True)
n_steps = 28
n_inputs = 28
n_neurons = 100
x = tf.placeholder(tf.float32,[None,n_steps,n_inputs])
action_one_hot = tf.placeholder(tf.float32,[None,10])

basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.nn.dynamic_rnn(basic_cell,x,dtype=tf.float32)
y0 = fully_connected(states,100,activation_fn=tf.nn.relu)
y = fully_connected(y0,10)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=action_one_hot, logits=y)
mean_loss = tf.reduce_mean(cross_entropy)
trian_op = tf.train.AdamOptimizer(0.001).minimize(mean_loss)

with tf.Session() as sess:
    for i in range(10000):
        sess.run(tf.global_variables_initializer())
        x1,y1 = mnist.train.next_batch(1000)
        x1 = x1.reshape((-1,n_steps,n_inputs))
        sess.run(trian_op,feed_dict={x:x1,action_one_hot:y1})
        if i%200==0:
            print(sess.run(mean_loss,feed_dict={x:x1,action_one_hot:y1}))

就是在每200步输出一下交叉商,但是这个交叉商总是无法下降。

  • 写回答

1条回答

  • threenewbee 2018-11-10 07:53
    关注

    RNN做文字识别没有什么优势,建议你用CNN。

    评论

报告相同问题?

悬赏问题

  • ¥15 【提问】基于Invest的水源涵养
  • ¥20 微信网友居然可以通过vx号找到我绑的手机号
  • ¥15 spring后端vue前端
  • ¥15 寻一个支付宝扫码远程授权登录的软件助手app
  • ¥15 解riccati方程组
  • ¥15 display:none;样式在嵌套结构中的已设置了display样式的元素上不起作用?
  • ¥15 使用rabbitMQ 消息队列作为url源进行多线程爬取时,总有几个url没有处理的问题。
  • ¥15 Ubuntu在安装序列比对软件STAR时出现报错如何解决
  • ¥50 树莓派安卓APK系统签名
  • ¥65 汇编语言除法溢出问题