错过水仙季 2022-10-17 21:23 采纳率: 71.4%
浏览 27
已结题

Use a.any() or a.all() 加了all,还是报一样的错

问题遇到的现象和发生背景
用代码块功能插入代码,请勿粘贴截图 
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
# import matplotlib.pyplot as plt

# x_train = pd.read_csv('D:\pycharm文件\轴承\轴承数据集/trdata.txt', delimiter=' ', header=None)
# data1 = pd.read_csv('D:\pycharm文件\轴承\轴承数据集/trlabel.txt', delimiter=' ')
# y_train = data1.iloc[0:4, :].values.T
# x_test = pd.read_csv('D:\pycharm文件\轴承\轴承数据集/txdata.txt', delimiter=' ', header=None)
# data2 = pd.read_csv('D:\pycharm文件\轴承\轴承数据集/txlabel.txt', delimiter=' ')
# y_test = data2.iloc[0:4, :].values.T
x_train = np.loadtxt('D:\pycharm文件\轴承\轴承数据集/trdata.txt', delimiter=' ')
y_ytrain = np.loadtxt('D:\pycharm文件\轴承\轴承数据集/trlabel.txt', delimiter=' ')
y_train = y_ytrain[0:4, 0:]
x_test = np.loadtxt('D:\pycharm文件\轴承\轴承数据集/txdata.txt', delimiter=' ')
y_ytest = np.loadtxt('D:\pycharm文件\轴承\轴承数据集/txlabel.txt', delimiter=' ')
y_test = y_ytest[0:4, 0:]


# def sigmoid(x):
#     return 1/(1+np.exp(-x))
def sigmoid(x):
    if x >= 0:      #对sigmoid函数的优化,避免了出现极大的数据溢出
        return 1.0/(1+np.exp(-x))
    else:
        return np.exp(x)/(1+np.exp(x))


def dsigmoid(y):
    return y * (1-y)


n, m = np.shape(x_train)
n_x = 400
n_h1 = 600
n_h2 = 200
n_h3 = 50
n_y = 4
w1 = np.random.randn(n_h1, n_x)*0.1   # 600*400
b1 = np.zeros((n_h1, 1))              # 600*1

w2 = np.random.randn(n_h2, n_h1)*0.1  # 200*600
b2 = np.zeros((n_h2, 1))              # 200*1

w3 = np.random.randn(n_h3, n_h2)*0.1  # 50*200
b3 = np.zeros((n_h3, 1))              # 50*1

w4 = np.random.randn(n_y, n_h3)*0.1  # 4*50
b4 = np.zeros((n_y, 1))              # 4*1

def forward(x_train, w1, b1, w2, b2, w3, b3, w4, b4):
    z1 = np.dot(w1, x_train) + b1  # 600*400 * 400*8000 =600 * 8000
    a1 = sigmoid(z1)  # 600 * 8000
    z2 = np.dot(w2, a1) + b2  # 200*600 * 600*8000 = 200*8000
    a2 = sigmoid(z2)  # 200*8000
    z3 = np.dot(w3, a2) + b3  # 50*200 * 200*8000 = 50*8000
    a3 = sigmoid(z3)  # 50*8000
    z4 = np.dot(w4, a3) + b4  # 4*50 * 50*8000 = 4*8000
    a4 = z4           # 4*8000
    return z1, z2, z3, z4, a1, a2, a3, a4


def costfuction(a4, y_train):
    n, m = np.shape(x_train)
    error = np.sum(0.5 * (a4 - y_train.T) ** 2)/m
    return error


def backforward(y_train, a4, a3, a2, a1, w4, w3, w2):
    dz4 = a4 - y_train.T  # 4*8000
    dw4 = np.dot(dz4, a3.T)  # 4*8000 * 8000*50 = 4*50
    db4 = np.sum(dz4, axis=1, keepdims=True)/m  # 4*8000
    dz3 = np.dot(w4.T, dz4) * sigmoid(a3)     # 50*4 * 4*8000 = 50*8000
    dw3 = np.dot(dz3, a2.T)  # 50*8000 * 8000*200 = 50*200
    db3 = np.sum(dz3, axis=1, keepdims=True)/m  # 50*8000
    dz2 = np.dot(w3.T, dz3) * sigmoid(a2)     # 200*50 * 50*8000 = 200*8000
    dw2 = np.dot(dz2, a1.T)  # 200*8000 * 8000*600 = 200*600
    db2 = np.sum(dz2, axis=1, keepdims=True)/m  # 200*8000
    dz1 = np.dot(w2.T, dz2) * sigmoid(a1)     # 600*200 * 200*8000 = 600*8000
    dw1 = np.dot(dz1, x_train.T)              # 600*8000 * 8000*400 = 600*400
    db1 = np.sum(dz1, axis=1, keepdims=True)/m  # 600*8000
    return dz4, dw4, db4, dz3, dw3, db3, dz2, dw2, db2, dz1, dw1, db1


alpha = 0.01
number = 5000
for i in range(1, number+1):
    z1, z2, z3, z4, a1, a2, a3, a4 = forward(x_train, w1, b1, w2, b2, w3, b3, w4, b4)
    error = np.sum(0.5 * (a4 - y_train.T) ** 2) / m
    dz4, dw4, db4, dz3, dw3, db3, dz2, dw2, db2, dz1, dw1, db1 = backforward(y_train, a4, a3, a2, a1, w4, w3, w2)
    w1 = w1 - alpha * dw1
    w2 = w2 - alpha * dw2
    w3 = w3 - alpha * dw3
    w4 = w4 - alpha * dw4
    b1 = b1 - alpha * db1
    b2 = b2 - alpha * db2
    b3 = b3 - alpha * db3
    b4 = b4 - alpha * db4

    if (i % 1000 == 0):
        print(i)
#     plt.plot(i, error, 'ro')
# plt.show()


z1_test, z2_test, z3_test, z4_test, a1_test, a2_test, a3_test, a4_test = forward(x_test, w1, b1, w2, b2, w3, b3, w4, b4)
print(a4_test)
n_test, m_test = np.shape(x_test)
error_test = np.sum(0.5 * (a4_test - y_test.T) ** 2) / m_test
print(error_test)


c = 0
b = np.rint(a4_test.T)  ##  对矩阵取整
n_2, m_2 = np.shape(a4_test.T)
for i in range(0, m_2):
    c1 = b[i, :]    ##  对两个矩阵相对应的位置作比较
    c2 = y_test[i, :]
    if (c1 == c2).any():  ##  不加all有3个tf,if函数识别不了三个,必须加all才能明白这个一整对矩阵是否相同
        c = c + 1
acc = c / m_test * 100
print('正确了', c, "个")
print('准确率:%.2f%%' % acc)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我已经加了all(),为什么还是报这个错误
  • 写回答

1条回答 默认 最新

  • 吴天德少侠 2022-10-18 13:05
    关注

    你这个是2个numpy数组比较相等啊,不是这样比较吧

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 5月18日
  • 已采纳回答 5月18日
  • 创建了问题 10月17日

悬赏问题

  • ¥15 Mac系统vs code使用phpstudy如何配置debug来调试php
  • ¥15 目前主流的音乐软件,像网易云音乐,QQ音乐他们的前端和后台部分是用的什么技术实现的?求解!
  • ¥60 pb数据库修改与连接
  • ¥15 spss统计中二分类变量和有序变量的相关性分析可以用kendall相关分析吗?
  • ¥15 拟通过pc下指令到安卓系统,如果追求响应速度,尽可能无延迟,是不是用安卓模拟器会优于实体的安卓手机?如果是,可以快多少毫秒?
  • ¥20 神经网络Sequential name=sequential, built=False
  • ¥16 Qphython 用xlrd读取excel报错
  • ¥15 单片机学习顺序问题!!
  • ¥15 ikuai客户端多拨vpn,重启总是有个别重拨不上
  • ¥20 关于#anlogic#sdram#的问题,如何解决?(关键词-performance)