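Backpropagation uses the chain rule of differentiation to work out how the error depends on each weight and bias. Gradient descent then adjusts the weights and biases step by step in the direction that reduces the error, until the error stops decreasing. In general this reaches a local minimum of the error, not necessarily the global one.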
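To make the chain-rule step concrete, here is a minimal sketch for a single sigmoid neuron y = sigmoid(w*x + b) with squared error E = 1/2 (t - y)^2. It is an illustration under assumptions: the values of w, b, x, t and the learning rate lr are hypothetical, and the analytic gradients are checked against central finite differences.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def loss_and_grads(w, b, x, t):
    """Squared error of one sigmoid neuron and its chain-rule gradients."""
    z = w * x + b
    y = sigmoid(z)
    E = 0.5 * (t - y) ** 2
    dE_dy = -(t - y)          # derivative of 1/2 (t - y)^2 with respect to y
    dy_dz = y * (1 - y)       # sigmoid'(z), written in terms of the output y
    dE_dz = dE_dy * dy_dz     # chain rule
    return E, dE_dz * x, dE_dz * 1.0   # dz/dw = x, dz/db = 1


# Hypothetical values chosen only for illustration
w, b, x, t = 0.3, 0.1, 0.35, 0.5
E, gw, gb = loss_and_grads(w, b, x, t)

# Central finite differences to confirm the chain-rule gradients
eps = 1e-6
num_gw = (loss_and_grads(w + eps, b, x, t)[0] - loss_and_grads(w - eps, b, x, t)[0]) / (2 * eps)
num_gb = (loss_and_grads(w, b + eps, x, t)[0] - loss_and_grads(w, b - eps, x, t)[0]) / (2 * eps)

# One gradient-descent step (hypothetical learning rate lr)
lr = 1.0
w_new, b_new = w - lr * gw, b - lr * gb
E_new = loss_and_grads(w_new, b_new, x, t)[0]
print(E, E_new)
```

The bias gradient is the same chain-rule product with dz/db = 1, which is how biases would be updated in a network that has them.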
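Below is a Python implementation for a small 3-layer fully connected network: 2 input units, 2 hidden units and 1 output unit, all using the sigmoid activation and no bias terms.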
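For this network (sigmoid hidden and output units, squared error, no biases), applying the chain rule gives the expressions below. The symbols are introduced here for illustration and map to the code's variables as x = l0, h = l1, y = l2, t = output, W^{(0)} = w0, W^{(1)} = w1, δ₂ = l2_delta, δ₁ = l1_delta; net₁ = W^{(0)}x and net₂ = W^{(1)}h are the pre-activation inputs, σ'(z) = σ(z)(1 − σ(z)), and η is a learning rate (the code effectively uses η = 1).

```latex
\begin{aligned}
\mathbf{h} &= \sigma\!\left(W^{(0)}\mathbf{x}\right), \qquad
y = \sigma\!\left(W^{(1)}\mathbf{h}\right), \qquad
E = \tfrac{1}{2}(t - y)^2,\\
\delta_2 &= -\frac{\partial E}{\partial\, \mathrm{net}_2} = (t - y)\, y\,(1 - y),\\
\boldsymbol{\delta}_1 &= -\frac{\partial E}{\partial\, \mathbf{net}_1}
  = \left(W^{(1)\top}\delta_2\right) \odot \mathbf{h} \odot (\mathbf{1} - \mathbf{h}),\\
\frac{\partial E}{\partial W^{(1)}} &= -\delta_2\, \mathbf{h}^{\top}, \qquad
\frac{\partial E}{\partial W^{(0)}} = -\boldsymbol{\delta}_1\, \mathbf{x}^{\top},\\
W^{(1)} &\leftarrow W^{(1)} + \eta\, \delta_2\, \mathbf{h}^{\top}, \qquad
W^{(0)} \leftarrow W^{(0)} + \eta\, \boldsymbol{\delta}_1\, \mathbf{x}^{\top}.
\end{aligned}
```

The shapes make the products explicit: δ₁ is 2×1 and xᵀ is 1×2, so ∂E/∂W^{(0)} is 2×2 like W^{(0)}, and row i is scaled by its own δ₁,ᵢ.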
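```python
import numpy as np


def nonlin(x, deriv=False):
    """Sigmoid activation. With deriv=True, x must already be a sigmoid
    output s, and the derivative s * (1 - s) is returned."""
    if deriv:
        return x * (1 - x)
    return 1 / (1 + np.exp(-x))


input_layer = np.array([[0.35], [0.9]])   # input column vector, shape (2, 1)
print(input_layer)
output = np.array([[0.5]])                # target value, shape (1, 1)
print(output)
w0 = np.array([[0.1, 0.8], [0.4, 0.6]])   # input -> hidden weights, shape (2, 2)
print(w0)
w1 = np.array([[0.3, 0.9]])               # hidden -> output weights, shape (1, 2)
print(w1)

for j in range(100):
    # Forward pass
    l0 = input_layer                      # (2, 1)
    l1 = nonlin(np.dot(w0, l0))           # hidden activations, (2, 1)
    l2 = nonlin(np.dot(w1, l1))           # network output, (1, 1)

    # Squared error E = 1/2 * (target - output)^2
    l2_error = output - l2
    error = 1 / 2.0 * (output - l2) ** 2
    print("error")
    print(error)

    # Backward pass (chain rule): l2_delta = -dE/d(net input of the output unit)
    l2_delta = l2_error * nonlin(l2, deriv=True)     # (1, 1)
    # Send the output delta back through w1: (2, 1) x (1, 1) -> (2, 1)
    l1_error = w1.T.dot(l2_delta)
    l1_delta = l1_error * nonlin(l1, deriv=True)     # (2, 1)

    # Gradient-descent updates with an implicit learning rate of 1
    w1 += l2_delta.dot(l1.T)              # (1, 1) x (1, 2) -> (1, 2)
    w0 += l1_delta.dot(l0.T)              # (2, 1) x (1, 2) -> (2, 2)
    print('w0')
    print(w0)
    print('w1')
    print(w1)
```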
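One way to confirm that a backward pass computes the right gradients is to compare them with finite differences of the error. The sketch below does this for the same 2-2-1 sigmoid network and initial weights; the helper names (forward, loss, backprop_grads, numeric_grad) are hypothetical and not part of the original code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(w0, w1, x):
    h = sigmoid(w0 @ x)       # hidden layer, (2, 1)
    y = sigmoid(w1 @ h)       # output, (1, 1)
    return h, y


def loss(w0, w1, x, t):
    _, y = forward(w0, w1, x)
    return 0.5 * float(np.sum((t - y) ** 2))


def backprop_grads(w0, w1, x, t):
    """Analytic gradients dE/dw0 and dE/dw1 obtained with the chain rule."""
    h, y = forward(w0, w1, x)
    delta2 = (t - y) * y * (1 - y)          # -dE/d(net2), (1, 1)
    delta1 = (w1.T @ delta2) * h * (1 - h)  # -dE/d(net1), (2, 1)
    return -delta1 @ x.T, -delta2 @ h.T     # shapes (2, 2) and (1, 2)


def numeric_grad(f, w, eps=1e-6):
    """Central finite differences of the scalar f() with respect to each entry of w."""
    g = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        old = w[idx]
        w[idx] = old + eps
        f_plus = f()
        w[idx] = old - eps
        f_minus = f()
        w[idx] = old                        # restore the original entry
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g


x = np.array([[0.35], [0.9]])
t = np.array([[0.5]])
w0 = np.array([[0.1, 0.8], [0.4, 0.6]])
w1 = np.array([[0.3, 0.9]])

g0, g1 = backprop_grads(w0, w1, x, t)
n0 = numeric_grad(lambda: loss(w0, w1, x, t), w0)
n1 = numeric_grad(lambda: loss(w0, w1, x, t), w1)
print(np.max(np.abs(g0 - n0)), np.max(np.abs(g1 - n1)))
```

Because each hidden unit has its own delta, the two rows of the w0 gradient differ; an update that adds the same vector to both rows cannot match this check.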
[Figure: the fully connected network used by the algorithm (2 inputs, 2 hidden units, 1 output)]
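As a concrete check of this network, the first forward pass can be worked through with the initial values from the code. This is a small sketch that only repeats the forward computation; the variable names net1 and net2 are introduced here for illustration.

```python
import numpy as np

x = np.array([[0.35], [0.9]])             # input_layer
w0 = np.array([[0.1, 0.8], [0.4, 0.6]])   # input -> hidden weights
w1 = np.array([[0.3, 0.9]])               # hidden -> output weights
t = np.array([[0.5]])                     # target

net1 = w0 @ x                             # 0.1*0.35 + 0.8*0.9 and 0.4*0.35 + 0.6*0.9
h = 1 / (1 + np.exp(-net1))               # hidden activations
net2 = w1 @ h
y = 1 / (1 + np.exp(-net2))               # network output
E = 0.5 * (t - y) ** 2                    # squared error before any training
print(net1.ravel(), h.ravel(), y.item(), E.item())
```

This gives net1 = [0.755, 0.68], hidden activations of roughly 0.680 and 0.664, an output of about 0.690 and an initial error of about 0.0181.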
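Part of the printed output follows. Each group is one iteration: the error computed before the update, then w0 and w1 after it. This log comes from the listing with the first-layer update written as `w0 += l0.T.dot(l1_delta)`, which adds one row vector to every row of w0 through broadcasting; that is why both rows of w0 change by exactly the same amounts: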
error
[[8.22616248e-06]]
w0
[[0.09815656 0.64379488]
[0.39815656 0.44379488]]
w1
[[-0.29141427 0.32375987]]
error
[[7.40440534e-06]]
w0
[[0.09823763 0.64370482]
[0.39823763 0.44370482]]
w1
[[-0.29203841 0.32315234]]
error
[[6.66477319e-06]]
w0
[[0.09831471 0.64361953]
[0.39831471 0.44361953]]
w1
[[-0.29263055 0.32257597]]
error
[[5.99905368e-06]]
w0
[[0.09838799 0.64353875]
[0.39838799 0.44353875]]
w1
[[-0.29319234 0.32202914]]
error
[[5.39985637e-06]]
w0
[[0.09845764 0.64346224]
[0.39845764 0.44346224]]
w1
[[-0.29372533 0.32151034]]
error
[[4.86053034e-06]]
w0
[[0.09852385 0.64338977]
[0.39852385 0.44338977]]
w1
[[-0.29423099 0.32101815]]
error
[[4.37509018e-06]]
w0
[[0.09858678 0.64332111]
[0.39858678 0.44332111]]
w1
[[-0.29471073 0.32055118]]
error
[[3.93814926e-06]]
w0
[[0.09864658 0.64325607]
[0.39864658 0.44325607]]
w1
[[-0.29516588 0.32010815]]
error
[[3.5448598e-06]]
w0
[[0.0987034 0.64319445]
[0.3987034 0.44319445]]
w1
[[-0.29559771 0.31968782]]
error
[[3.19085892e-06]]
w0
[[0.09875739 0.64313606]
[0.39875739 0.44313606]]
w1
[[-0.2960074 0.31928904]]
error
[[2.87222006e-06]]
w0
[[0.09880868 0.64308073]
[0.39880868 0.44308073]]
w1
[[-0.29639609 0.3189107 ]]
error
[[2.58540935e-06]]
w0
[[0.09885741 0.6430283 ]
[0.39885741 0.4430283 ]]
w1
[[-0.29676486 0.31855175]]
error
[[2.32724628e-06]]
w0
[[0.0989037 0.64297861]
[0.3989037 0.44297861]]
w1
[[-0.29711474 0.31821119]]
error
[[2.09486834e-06]]
w0
[[0.09894768 0.64293151]
[0.39894768 0.44293151]]
w1
[[-0.29744668 0.31788809]]
error
[[1.88569918e-06]]
w0
[[0.09898944 0.64288688]
[0.39898944 0.44288688]]
w1
[[-0.29776161 0.31758154]]
error
[[1.69742002e-06]]
w0
[[0.09902911 0.64284457]
[0.39902911 0.44284457]]
w1
[[-0.29806041 0.31729071]]
error
[[1.52794387e-06]]
w0
[[0.09906678 0.64280446]
[0.39906678 0.44280446]]
w1
[[-0.2983439 0.31701477]]
error
[[1.37539232e-06]]
w0
[[0.09910256 0.64276645]
[0.39910256 0.44276645]]
w1
[[-0.29861286 0.31675297]]
error
[[1.23807471e-06]]
w0
[[0.09913654 0.64273041]
[0.39913654 0.44273041]]
w1
[[-0.29886803 0.31650459]]
error
[[1.11446931e-06]]
w0
[[0.0991688 0.64269624]
[0.3991688 0.44269624]]
w1
[[-0.29911014 0.31626894]]
error
[[1.0032065e-06]]
w0
[[0.09919944 0.64266385]
[0.39919944 0.44266385]]
w1
[[-0.29933984 0.31604536]]
error
[[9.03053486e-07]]
w0
[[0.09922853 0.64263314]
[0.39922853 0.44263314]]
w1
[[-0.29955777 0.31583323]]
error
[[8.12900662e-07]]
w0
[[0.09925614 0.64260402]
[0.39925614 0.44260402]]
w1
[[-0.29976454 0.31563197]]
error
[[7.31749283e-07]]
w0
[[0.09928237 0.64257641]
[0.39928237 0.44257641]]
w1
[[-0.29996071 0.31544102]]
error
[[6.58700387e-07]]
w0
[[0.09930726 0.64255023]
[0.39930726 0.44255023]]
w1
[[-0.30014683 0.31525986]]
error
[[5.92944818e-07]]
w0
[[0.0993309 0.6425254]
[0.3993309 0.4425254]]
w1
[[-0.30032342 0.31508797]]
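This is only part of the output; over these iterations the error falls from about 8.2e-06 to 5.9e-07. That is the Python implementation of the BP (backpropagation) algorithm.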
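The same training can be packaged as a function that returns the error history, which makes it easy to confirm that the error keeps decreasing. This is a sketch, not the original code: the function name train and the lr and epochs parameters are hypothetical additions, with lr=1.0 and epochs=100 matching the loop above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train(x, t, w0, w1, lr=1.0, epochs=100):
    """Full-batch gradient descent for the 2-2-1 sigmoid network.
    lr and epochs are illustrative parameters, not from the original code."""
    w0, w1 = w0.copy(), w1.copy()      # leave the caller's arrays untouched
    history = []
    for _ in range(epochs):
        h = sigmoid(w0 @ x)            # forward pass, hidden layer
        y = sigmoid(w1 @ h)            # forward pass, output
        history.append(0.5 * float(np.sum((t - y) ** 2)))
        delta2 = (t - y) * y * (1 - y)             # output delta
        delta1 = (w1.T @ delta2) * h * (1 - h)     # hidden deltas
        w1 += lr * delta2 @ h.T
        w0 += lr * delta1 @ x.T
    return w0, w1, history


x = np.array([[0.35], [0.9]])
t = np.array([[0.5]])
w0_init = np.array([[0.1, 0.8], [0.4, 0.6]])
w1_init = np.array([[0.3, 0.9]])
w0, w1, history = train(x, t, w0_init, w1_init, lr=1.0, epochs=100)
print(history[0], history[-1])
```

Returning the history instead of printing every weight matrix keeps the output short; a smaller lr would make each step safer but slower.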