Bach_Orc 2019-06-10 12:49 采纳率: 100%
浏览 483
已采纳

求mnist多数字识别,修改完成我的代码

实现多个 数字的识别 如 图片说明

实现方法: 把五个数字的图 拼成一个 再进行 训练 和 测试

学校讲的自己看的都一知半解,我都不知道 我一个大二的,没什么基础的学生是怎么选上做这个研究的。。马上due就快到了,求大神在我代码基础上帮我完成。。

import tensorflow as tf
import numpy as np
from numpy import array
import os
import cv2
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = tf.keras.utils.normalize(x_train, axis=1)

x_test = tf.keras.utils.normalize(x_test, axis=1)

import random

countain = 0

myxtrain = []
myytrain = []

while countain < 5:
a_1 = random.randint(0, 10000)
a_2 = random.randint(0, 10000)
a_3 = random.randint(0, 10000)
a_4 = random.randint(0, 10000)
a_5 = random.randint(0, 10000)

a = np.concatenate((x_train[a_1], x_train[a_2], x_train[a_3], x_train[a_4], x_train[a_5]), axis=1)
myxtrain.append(a)
labelx = []
s1 = str(y_train[a_1])
s2 = str(y_train[a_2])
s3 = str(y_train[a_3])
s4 = str(y_train[a_4])
s5 = str(y_train[a_5])

labelx.append(s1)
labelx.append(s2)
labelx.append(s3)
labelx.append(s4)
labelx.append(s5)

s = '  '.join(labelx)
print('----', s)
# cv2.imwrite(os.path.join(path1, s + '.jpg'), a)
# cv2.waitKey(0)
myytrain.append(s)
countain += 1

x_train = array(myxtrain)
y_train = array(myytrain)

from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD

model = tf.keras.models.Sequential()

model.add(Conv2D(5, (3, 3), activation='relu', input_shape=(28, 140, 2)))
model.add(Conv2D(5, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))

sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

model.fit(x_train, y_train, batch_size=1, epochs=10)


  • 写回答

3条回答 默认 最新

  • dabocaiqq 2019-06-10 12:57
    关注

    关键代码如下:

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("mnist_set",one_hot=True)
    #数据集相关常数
    INPUT_NODE = 784
    OUTPU_NODE = 10
    #配置神经网络参数
    LAYER1_NODE = 500
    BATCH_SIZE = 100
    LEARNING_RATE_BASE = 0.8 #基础学习率
    LEARNING_RATE_DECAY = 0.99#学习衰减率
    REGULARIZATION_RATE = 0.0001#正则的惩罚系数
    MOVE_AVG_RATE = 0.99 #滑动平均衰减率
    TRAIN_STEPS = 30000
    
    
    def inference(input_tensor,weights1,biases1,weight2,biases2,avg_class=None):
        #当没有提供滑动平均类时,直接使用当前值
        if avg_class == None:
            #计算隐藏层的前向传播结果,使用RELU激活函数
            layer1 = tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1)
            #返回输出层的前向传播
            return tf.matmul(layer1,weight2)+ biases2
        else:
            #前向传播之前,用avg——class计算出变量的滑动平均值
            layer1 = tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1))+avg_class.average(biases1))
            return tf.matmul(layer1,avg_class.average(weight2))+ avg_class.average(biases2)
    #模型的训练过程
    def train(mnist):
        x = tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input')
        y_ = tf.placeholder(tf.float32,[None,OUTPU_NODE],name="y_input")
    
        #隐藏层参数
        w1 = tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1))
        b1 = tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))
        #输出层参数
        w2 = tf.Variable(tf.random_normal([LAYER1_NODE,OUTPU_NODE],stddev=0.1))
        b2 = tf.Variable(tf.constant(0.1,shape=[OUTPU_NODE]))
        y = inference(x,w1,b1,w2,b2)
        #定义存储训练轮数的变量,设为不可训练
        global_step = tf.Variable(0,trainable=False)
        #初始化滑动平均类
        variable_averages = tf.train.ExponentialMovingAverage(MOVE_AVG_RATE,global_step)
        #在神经网络的所有参数变量上使用滑动平均
        variable_averages_op = variable_averages.apply(tf.trainable_variables())
        averages_y = inference(x,w1,b1,w2,b2,avg_class=variable_averages)
        #计算交叉熵损失
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
        #计算交叉熵平均值
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        #计算L2正则的损失函数
        regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
        regularization = regularizer(w1)+ regularizer(w2)
        #总的损失
        loss = cross_entropy_mean + regularization
    
        #设置指数衰减的学习率
        training_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY)
        train_step = tf.train.GradientDescentOptimizer(training_rate).minimize(loss,global_step)
        #更新滑动平均值
        with tf.control_dependencies([train_step,variable_averages_op]):
            train_op = tf.no_op(name='train')
    
        #验证前向传播结果是否正确
        correct_prediction = tf.equal(tf.argmax(averages_y,1),tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
        #初始会话,开始训练
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            #准备验证数据和测试数据
            validate_feed = {x:mnist.validation.images,y_: mnist.validation.labels}
            test_feed = {x:mnist.test.images,y_: mnist.test.labels}
            #迭代训练神经网络
            for i in range(TRAIN_STEPS):
                if i % 1000 == 0:
                    validate_acc = sess.run(accuracy,feed_dict=validate_feed)
                    print("After %s training steps ,validation accuracy is %s"%(i,validate_acc))
                xs,ys = mnist.train.next_batch(BATCH_SIZE)
                sess.run(train_op,feed_dict={x:xs,y_:ys})
            #训练结束后,在测试集上验证准确率
            test_acc =  sess.run(accuracy,test_feed)
            print("After %s training steps ,test accuracy is %s"%(TRAIN_STEPS,test_acc))
    
    def main(argv=None):
        mnist = input_data.read_data_sets("mnist_set",one_hot=True)
        train(mnist)
    if __name__ == '__main__':
        #TF 提供了一个主程序入口,tf.app.run会自动调用上面的main()
        tf.app.run()
    

    完整代码和训练数据模型要等采纳以后给你。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

悬赏问题

  • ¥20 双层网络上信息-疾病传播
  • ¥50 paddlepaddle pinn
  • ¥20 idea运行测试代码报错问题
  • ¥15 网络监控:网络故障告警通知
  • ¥15 django项目运行报编码错误
  • ¥15 请问这个是什么意思?
  • ¥15 STM32驱动继电器
  • ¥15 Windows server update services
  • ¥15 关于#c语言#的问题:我现在在做一个墨水屏设计,2.9英寸的小屏怎么换4.2英寸大屏
  • ¥15 模糊pid与pid仿真结果几乎一样