qq_31630271 2018-11-27 12:42 采纳率: 0%
浏览 1317

关于Tensorflow的DNN分类器

用Tensorflow写了一个简易的DNN网络(输入,一个隐层,输出),用作分类,数据集选用的是UCI 的iris数据集
激活函数使用softmax loss函数使用对数似然 以便最后的结果是一个概率解,选概率最大的分类的结果
目前的问题是预测结果出现问题,用测试数据测试显示结果如下

图片说明

刚刚入门...希望大家指点一下,谢谢辣!

 #coding:utf-8
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn import preprocessing
from sklearn.model_selection import cross_val_score

BATCH_SIZE = 30

iris = pd.read_csv('F:\dataset\iris\dataset.data', sep=',', header=None)

'''
# 查看导入的数据
print("Dataset Lenght:: ", len(iris))
print("Dataset Shape:: ", iris.shape)
print("Dataset:: ")
print(iris.head(150))
'''

#将每条数据划分为样本值和标签值
X = iris.values[:, 0:4]
Y = iris.values[:, 4]

# 整理一下标签数据
# Iris-setosa       ---> 0
# Iris-versicolor   ---> 1
# Iris-virginica    ---> 2
for i in range(len(Y)):
    if Y[i] == 'Iris-setosa':
        Y[i] = 0
    elif Y[i] == 'Iris-versicolor':
        Y[i] = 1
    elif Y[i] == 'Iris-virginica':
        Y[i] = 2

# 划分训练集与测试集
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=10)

#对数据集X与Y进行shape整理,让第一个参数为-1表示整理X成n行2列,整理Y成n行1列
X_train = np.vstack(X_train).reshape(-1, 4)
Y_train = np.vstack(Y_train).reshape(-1, 1)
X_test = np.vstack(X_test).reshape(-1, 4)
Y_test = np.vstack(Y_test).reshape(-1, 1)

'''
print(X_train)
print(Y_train)
print(X_test)
print(Y_test)
'''

#定义神经网络的输入,参数和输出,定义前向传播过程
def get_weight(shape):
    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    return w

def get_bias(shape):
    b = tf.Variable(tf.constant(0.01, shape=shape))
    return b

x = tf.placeholder(tf.float32, shape=(None, 4))
yi = tf.placeholder(tf.float32, shape=(None, 1))

def BP_Model():
    w1 = get_weight([4, 10])  # 第一个隐藏层,10个神经元,4个输入
    b1 = get_bias([10])
    y1 = tf.nn.softmax(tf.matmul(x, w1) + b1)  # 注意维度

    w2 = get_weight([10, 3])  # 输出层,3个神经元,10个输入
    b2 = get_bias([3])
    y = tf.nn.softmax(tf.matmul(y1, w2) + b2)

    return y

def train():
    # 生成计算图
    y = BP_Model()
    # 定义损失函数
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.arg_max(yi, 1))
    loss_cem = tf.reduce_mean(ce)
    # 定义反向传播方法,正则化
    train_step = tf.train.AdamOptimizer(0.001).minimize(loss_cem)
    # 定义保存器
    saver = tf.train.Saver(tf.global_variables())
    #生成会话
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        Steps = 5000
        for i in range(Steps):
            start = (i * BATCH_SIZE) % 300
            end = start + BATCH_SIZE
            reslut = sess.run(train_step, feed_dict={x: X_train[start:end], yi: Y_train[start:end]})
            if i % 100 == 0:
                loss_val = sess.run(loss_cem, feed_dict={x: X_train, yi: Y_train})
                print("step: ", i, "loss: ", loss_val)
        print("保存模型: ", saver.save(sess, './model_iris/bp_model.model'))
    tf.summary.FileWriter("logs/", sess.graph)
#train()

def prediction():
    # 生成计算图
    y = BP_Model()
    # 定义损失函数
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.arg_max(yi, 1))
    loss_cem = tf.reduce_mean(ce)
    # 定义保存器
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        saver.restore(sess, './model_iris/bp_model.model')
        result = sess.run(y, feed_dict={x: X_test})
        loss_val = sess.run(loss_cem, feed_dict={x: X_test, yi: Y_test})
        print("result :", result)
        print("loss :", loss_val)
        result_set = sess.run(tf.argmax(result, axis=1))
        print("predict result: ", result_set)
        print("real result: ", Y_test.reshape(1, -1))

#prediction()

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-10-25 19:26
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

悬赏问题

  • ¥100 有人会搭建GPT-J-6B框架吗?有偿
  • ¥15 求差集那个函数有问题,有无佬可以解决
  • ¥15 【提问】基于Invest的水源涵养
  • ¥20 微信网友居然可以通过vx号找到我绑的手机号
  • ¥15 寻一个支付宝扫码远程授权登录的软件助手app
  • ¥15 解riccati方程组
  • ¥15 display:none;样式在嵌套结构中的已设置了display样式的元素上不起作用?
  • ¥15 使用rabbitMQ 消息队列作为url源进行多线程爬取时,总有几个url没有处理的问题。
  • ¥15 Ubuntu在安装序列比对软件STAR时出现报错如何解决
  • ¥50 树莓派安卓APK系统签名