错过水仙季 2022-10-13 20:15 采纳率: 71.4%
浏览 11
已结题

将基于迭代次数进行神经网络的截止改为通过设置预值大小来使神经网络截止

将基于迭代次数进行神经网络的截止改为通过设置预值大小来使神经网络截止,比如说损失值达到多少的时候让它跳出循环不在训练了。下面代码改如何调整

from sklearn import preprocessing
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
data_set = pd.read_csv('C:\\Users\\pc.000\\Desktop\\数据集\\iris.csv', delimiter=',', header=None)
Train_Y = data_set.iloc[0:150, 4:].values.T
Train_Y = np.array(Train_Y).reshape([-1, 1])
Encoder = preprocessing.OneHotEncoder()
Encoder.fit(Train_Y)
Train_Y = Encoder.transform(Train_Y).toarray()
Train_Y = np.asarray(Train_Y, dtype=np.int32)
x_train = data_set.iloc[0:120, 0:4].values.T
x_test = data_set.iloc[120:151, 0:4].values.T
y_train = Train_Y[0:120]
y_test = Train_Y[120:151]



def sigmiod(x):
    return 1/(1+np.exp(-x))


def dsigmiod(y):
    return y*(1-y)


n, m = np.shape(x_train)
n_x = 4
n_h1 = 10
n_h2 = 5
n_y = 3
np.random.seed(2)
w1 = np.random.randn(n_h1, n_x)
b1 = np.zeros((n_h1, 1))
w2 = np.random.randn(n_h2, n_h1)
b2 = np.zeros((n_h2, 1))
w3 = np.random.randn(n_y, n_h2)
b3 = np.zeros((n_y, 1))


def forward(w1, x_train, b1, w2, b2, w3, b3):
    z1 = np.dot(w1, x_train) + b1
    a1 = sigmiod(z1)
    z2 = np.dot(w2, a1) + b2
    a2 = sigmiod(z2)
    z3 = np.dot(w3, a2) + b3
    a3 = z3
    return z1, a1, z2, a2, z3, a3


def costfuction(a3, y_train):
    error =np.sum(0.5*(a3 - y_train.T)**2)
    return error


def backward(a3, y_train, a2, w3, a1, w2, x_train):
    dz3 = a3 - y_train.T  # 1 * 90
    dw3 = np.dot(dz3, a2.T)  # 1 * 90 * 90 * 5 = 1 * 5
    db3 = np.sum(dz3, axis=1, keepdims=True)
    dz2 = np.dot(w3.T, dz3) * dsigmiod(a2)
    dw2 = np.dot(dz2, a1.T)
    db2 = np.sum(dz2, axis=1, keepdims=True)
    dz1 = np.dot(w2.T, dz2) * dsigmiod(a1)
    dw1 = np.dot(dz1, x_train.T)
    db1 = np.sum(dz1, axis=1, keepdims=True)
    return dz3, dw3, db3, dz2, dw2, db2, dz1, dw1, db1


alpha = 0.0007
number = 15000
for i in range(1, number+1):
    z1, a1, z2, a2, z3, a3 = forward(w1, x_train, b1, w2, b2, w3, b3)
    error = np.sum(0.5*(a3 - y_train.T)**2)/m
    dz3, dw3, db3, dz2, dw2, db2, dz1, dw1, db1 = backward(a3, y_train, a2, w3, a1, w2, x_train)
    w1 = w1 - alpha * dw1
    w2 = w2 - alpha * dw2
    w3 = w3 - alpha * dw3
    b1 = b1 - alpha * db1
    b2 = b2 - alpha * db2
    b3 = b3 - alpha * db3

    if i % 1000 == 0:
        print(i)
    plt.plot(i, error, 'ro')
plt.show()

z1_test, a1_test, z2_test, a2_test, z3_test, a3_test = forward(w1, x_test, b1, w2, b2, w3, b3)
n_test, m_test = np.shape(x_test)
error_test = np.sum(0.5 * (a3_test - y_test.T) ** 2) / m_test
print(error_test)

c = 0
b = np.rint(a3_test.T)  ##  对矩阵取整
n_2, m_2 = np.shape(a3_test)
for i in range(0, m_2):
    c1 = b[i, :]    ##  对两个矩阵相对应的位置作比较
    c2 = y_test[i, :]
    if (c1 == c2).all():  ##  不加all有3个tf,if函数识别不了三个,必须加all才能明白这个一整对矩阵是否相同
        c = c + 1
acc = c / m_test * 100
print('正确了', c, "个")
print('准确率:%.2f%%' % acc)


  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 10月21日
    • 创建了问题 10月13日

    悬赏问题

    • ¥15 高缺失率数据如何选择填充方式
    • ¥50 potsgresql15备份问题
    • ¥15 Mac系统vs code使用phpstudy如何配置debug来调试php
    • ¥15 目前主流的音乐软件,像网易云音乐,QQ音乐他们的前端和后台部分是用的什么技术实现的?求解!
    • ¥60 pb数据库修改与连接
    • ¥15 spss统计中二分类变量和有序变量的相关性分析可以用kendall相关分析吗?
    • ¥15 拟通过pc下指令到安卓系统,如果追求响应速度,尽可能无延迟,是不是用安卓模拟器会优于实体的安卓手机?如果是,可以快多少毫秒?
    • ¥20 神经网络Sequential name=sequential, built=False
    • ¥16 Qphython 用xlrd读取excel报错
    • ¥15 单片机学习顺序问题!!