buaa_dmc 2017-09-12 10:55
浏览 516

求大神帮忙看看这个单隐层BP算法到底错在哪里...

from numpy import *

def sigmoid(inX):
return 1/(1+exp(-inX))

def dsigmoid(inX):
return exp(-inX)/multiply((1+exp(-inX)),(1+exp(-inX)))

def loadDataset(testSet):
dataMat=[]
labelMat=[]
fr=open(testSet)
for line in fr.readlines():
lineArr=line.strip().split()
n=len(lineArr)-1
tmp=[float(lineArr[i]) for i in range(n)]
dataMat.append(tmp)
labelMat.append(float(lineArr[-1]))
m=len(dataMat)
m,n=shape(dataMat)
dataMat=mat(dataMat)
dataMat=dataMat.T
labelMat=mat(labelMat)
return dataMat,labelMat,m,n

def iteration(X,Y,a,times,W1,b1,W2,b2):
n,m=shape(X)
for i in range(times):
#forwardProp
Z1=W1*X+b1
A1=sigmoid(Z1)
Z2=W2*A1+b2
A2=sigmoid(Z2)
#backProp
dZ2=A2-Y
dW2=1.0/m*dZ2*A1.T
db2=1.0/m*sum(dZ2,axis=1)
db2=reshape(1,1)
dZ1=multiply(W2.T*dZ2,dsigmoid(Z1))
dW1=1.0/m*dZ1*X.T
db1=1.0/m*sum(dZ1,axis=1)
db1=reshape(2,1)
W1=W1-a*dW1
b1=b1-a*db1
W2=W2-a*dW2
b2=b2-a*db2
return W1,b1,W2,b2

def BP(testSet,n1,a,iterTimes):
dataMat,labelMat,m,n=loadDataset(testSet)
n0=n
n2=1
W1=mat(random.randn(n1,n0)*0.01)
b1=mat(random.randn(n1,1)*0.01)
W2=mat(random.randn(n2,n1)*0.01)
b2=mat(random.randn(n2,1)*0.01)
#print(W1)
X=dataMat
Y=labelMat
return iteration(X,Y,a,iterTimes,W1,b1,W2,b2)

def test(testSet,W1,b1,W2,b2):
dataMat=[]
fr=open(testSet)
for line in fr.readlines():
lineArr=line.strip().split()
n=len(lineArr)-1
tmp=[float(lineArr[i]) for i in range(n)]
dataMat.append(tmp)
dataMat=mat(dataMat)
dataMat=dataMat.T
X=dataMat
#forwardProp
Z1=W1*X+b1
A1=sigmoid(Z1)
Z2=W2*A1+b2
A2=sigmoid(Z2)
return A2


  • 写回答

0条回答

    报告相同问题?

    悬赏问题

    • ¥15 求差集那个函数有问题,有无佬可以解决
    • ¥15 【提问】基于Invest的水源涵养
    • ¥20 微信网友居然可以通过vx号找到我绑的手机号
    • ¥15 寻一个支付宝扫码远程授权登录的软件助手app
    • ¥15 解riccati方程组
    • ¥15 display:none;样式在嵌套结构中的已设置了display样式的元素上不起作用?
    • ¥15 使用rabbitMQ 消息队列作为url源进行多线程爬取时,总有几个url没有处理的问题。
    • ¥15 Ubuntu在安装序列比对软件STAR时出现报错如何解决
    • ¥50 树莓派安卓APK系统签名
    • ¥65 汇编语言除法溢出问题