错过水仙季 2022-10-17 13:24 采纳率: 71.4%
浏览 16
已结题

亲们,看一看我这代码算不算是一个自编码器

问题遇到的现象和发生背景
用代码块功能插入代码,请勿粘贴截图
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
data_set = pd.read_csv('C:\\Users\\pc.000\\Desktop\\数据集\\iris.csv', delimiter=',', header=None)
x_train = data_set.iloc[0:100, 0:4].values.T
x_test = data_set.iloc[100:151, 0:4].values.T
# y_train = Train_Y[0:100]
# y_test = Train_Y[100:151]


def sigmiod(x):
    return 1/(1+np.exp(-x))


def dsigmiod(y):
    return y*(1-y)


n, m = np.shape(x_train)
n_x = 4
n_h = 2
n_y = 4
np.random.seed(2)
w1 = np.random.randn(n_h, n_x)*0.1
b1 = np.zeros((n_h, 1))
w2 = np.random.randn(n_y, n_h)*0.1
b2 = np.zeros((n_y, 1))


def forward(w1, x_train, b1, w2, b2):
    z1 = np.dot(w1, x_train) + b1
    a1 = sigmiod(z1)
    z2 = np.dot(w2, a1) + b2
    a2 = z2
    return z1, a1, z2, a2


def costfuction(a2, x_train):
    error =np.sum(0.5*(a2 - x_train)**2)/m
    return error


def backward(y_train, a2, a1, w2, x_train):
    dz2 = a2 - y_train  # 1 * 90
    dw2 = np.dot(dz2, a1.T)  # 1 * 90 * 90 * 5 = 1 * 5
    db2 = np.sum(dz2, axis=1, keepdims=True)/m
    dz1 = np.dot(w2.T, dz2) * dsigmiod(a1)
    dw1 = np.dot(dz1, x_train.T)
    db1 = np.sum(dz1, axis=1, keepdims=True)/m
    return dz2, dw2, db2, dz1, dw1, db1


alpha = 0.0001
number = 15000
for i in range(1, number+1):
    z1, a1, z2, a2 = forward(w1, x_train, b1, w2, b2)
    error = np.sum(0.5*(a2 - x_train)**2)/m
    dz2, dw2, db2, dz1, dw1, db1 = backward(x_train, a2, a1, w2, x_train)
    w1 = w1 - alpha * dw1
    w2 = w2 - alpha * dw2
    b1 = b1 - alpha * db1
    b2 = b2 - alpha * db2

    if i % 1000 == 0:
        print("迭代次数", i)
        print("误差为", error)


z1_test, a1_test, z2_test, a2_test = forward(w1, x_test, b1, w2, b2)
n_test, m_test = np.shape(x_test)
error_test = np.sum(0.5 * (a2_test - x_test) ** 2) / m_test
print("测试的误差为", error_test)

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 10月25日
    • 创建了问题 10月17日

    悬赏问题

    • ¥15 ogg dd trandata 报错
    • ¥15 高缺失率数据如何选择填充方式
    • ¥50 potsgresql15备份问题
    • ¥15 Mac系统vs code使用phpstudy如何配置debug来调试php
    • ¥15 目前主流的音乐软件,像网易云音乐,QQ音乐他们的前端和后台部分是用的什么技术实现的?求解!
    • ¥60 pb数据库修改与连接
    • ¥15 spss统计中二分类变量和有序变量的相关性分析可以用kendall相关分析吗?
    • ¥15 拟通过pc下指令到安卓系统,如果追求响应速度,尽可能无延迟,是不是用安卓模拟器会优于实体的安卓手机?如果是,可以快多少毫秒?
    • ¥20 神经网络Sequential name=sequential, built=False
    • ¥16 Qphython 用xlrd读取excel报错