buaa_dmc 2017-09-12 10:55
浏览 516

求大神帮忙看看这个单隐层BP算法到底错在哪里...

from numpy import *

def sigmoid(inX):
return 1/(1+exp(-inX))

def dsigmoid(inX):
return exp(-inX)/multiply((1+exp(-inX)),(1+exp(-inX)))

def loadDataset(testSet):
dataMat=[]
labelMat=[]
fr=open(testSet)
for line in fr.readlines():
lineArr=line.strip().split()
n=len(lineArr)-1
tmp=[float(lineArr[i]) for i in range(n)]
dataMat.append(tmp)
labelMat.append(float(lineArr[-1]))
m=len(dataMat)
m,n=shape(dataMat)
dataMat=mat(dataMat)
dataMat=dataMat.T
labelMat=mat(labelMat)
return dataMat,labelMat,m,n

def iteration(X,Y,a,times,W1,b1,W2,b2):
n,m=shape(X)
for i in range(times):
#forwardProp
Z1=W1*X+b1
A1=sigmoid(Z1)
Z2=W2*A1+b2
A2=sigmoid(Z2)
#backProp
dZ2=A2-Y
dW2=1.0/m*dZ2*A1.T
db2=1.0/m*sum(dZ2,axis=1)
db2=reshape(1,1)
dZ1=multiply(W2.T*dZ2,dsigmoid(Z1))
dW1=1.0/m*dZ1*X.T
db1=1.0/m*sum(dZ1,axis=1)
db1=reshape(2,1)
W1=W1-a*dW1
b1=b1-a*db1
W2=W2-a*dW2
b2=b2-a*db2
return W1,b1,W2,b2

def BP(testSet,n1,a,iterTimes):
dataMat,labelMat,m,n=loadDataset(testSet)
n0=n
n2=1
W1=mat(random.randn(n1,n0)*0.01)
b1=mat(random.randn(n1,1)*0.01)
W2=mat(random.randn(n2,n1)*0.01)
b2=mat(random.randn(n2,1)*0.01)
#print(W1)
X=dataMat
Y=labelMat
return iteration(X,Y,a,iterTimes,W1,b1,W2,b2)

def test(testSet,W1,b1,W2,b2):
dataMat=[]
fr=open(testSet)
for line in fr.readlines():
lineArr=line.strip().split()
n=len(lineArr)-1
tmp=[float(lineArr[i]) for i in range(n)]
dataMat.append(tmp)
dataMat=mat(dataMat)
dataMat=dataMat.T
X=dataMat
#forwardProp
Z1=W1*X+b1
A1=sigmoid(Z1)
Z2=W2*A1+b2
A2=sigmoid(Z2)
return A2


  • 写回答

0条回答

    报告相同问题?

    悬赏问题

    • ¥50 成都蓉城足球俱乐部小程序抢票
    • ¥15 yolov7训练自己的数据集
    • ¥15 esp8266与51单片机连接问题(标签-单片机|关键词-串口)(相关搜索:51单片机|单片机|测试代码)
    • ¥15 电力市场出清matlab yalmip kkt 双层优化问题
    • ¥30 ros小车路径规划实现不了,如何解决?(操作系统-ubuntu)
    • ¥20 matlab yalmip kkt 双层优化问题
    • ¥15 如何在3D高斯飞溅的渲染的场景中获得一个可控的旋转物体
    • ¥88 实在没有想法,需要个思路
    • ¥15 MATLAB报错输入参数太多
    • ¥15 python中合并修改日期相同的CSV文件并按照修改日期的名字命名文件