使用神经网络解决问题
import sys
sys.path.append('..') # 为了引入父目录的文件而进行的设定
from dataset import spiral
import matplotlib.pyplot as plt
# Load the spiral toy dataset: x holds the 2-D input points, t the one-hot
# class labels. Expected shapes: x (300, 2), t (300, 3); note that t is
# printed under the label 'y'.
x, t = spiral.load_data()
for name, data in (('x', x), ('y', t)):
    print(name, data.shape)
x (300, 2)
y (300, 3)
在上面的例子中,要从ch01目录引入位于其父目录下dataset目录中的spiral.py。因此,上面的代码通过sys.path.append('..')将父目录添加到了import的检索路径中。
# Dump the raw input points followed by the one-hot labels.
for array in (x, t):
    print(array)
[[-0.00000000e+00 0.00000000e+00]
[-9.76986432e-04 9.95216044e-03]
[ 5.12668241e-03 1.93317647e-02]
[-3.86043324e-04 2.99975161e-02]
[ 1.42509650e-02 3.73752591e-02]
[ 9.41914082e-04 4.99911272e-02]
[ 2.25361319e-02 5.56068589e-02]
[ 6.52848904e-03 6.96948982e-02]
[ 2.50649535e-02 7.59720219e-02]
[ 2.03287580e-02 8.76740646e-02]
[ 5.98440862e-02 8.01166983e-02]
[ 6.19050693e-02 9.09272368e-02]
[ 3.22809763e-02 1.15576549e-01]
[ 8.28423530e-02 1.00185551e-01]
[ 1.09856959e-01 8.67839183e-02]
[ 9.33208222e-02 1.17436043e-01]
[ 7.82976217e-02 1.39533087e-01]
[ 1.23994559e-01 1.16298535e-01]
[ 8.06199110e-02 1.60936105e-01]
[ 1.39235917e-01 1.29280158e-01]
[ 1.53599653e-01 1.28090384e-01]
[ 1.38981638e-01 1.57429680e-01]
[ 1.89873275e-01 1.11122183e-01]
[ 1.41160729e-01 1.81586477e-01]
[ 1.50631465e-01 1.86842613e-01]
[ 1.71714639e-01 1.81697778e-01]
[ 2.10050449e-01 1.53227963e-01]
[ 2.32019716e-01 1.38082770e-01]
[ 2.31219125e-01 1.57916802e-01]
[ 2.51798226e-01 1.43866791e-01]
[ 2.59928236e-01 1.49790894e-01]
[ 3.03168799e-01 6.47200053e-02]
[ 2.99776174e-01 1.11956445e-01]
[ 2.71902926e-01 1.86999462e-01]
[ 3.39649101e-01 1.54430706e-02]
[ 3.41955960e-01 7.46064458e-02]
[ 3.56153446e-01 5.24854558e-02]
[ 3.69893826e-01 -8.86325970e-03]
[ 3.78749506e-01 3.08027885e-02]
[ 3.89323347e-01 -2.29636922e-02]
[ 3.97756477e-01 -4.23058548e-02]
[ 3.97790750e-01 -9.93102174e-02]
[ 4.19082447e-01 -2.77471188e-02]
[ 4.26755751e-01 -5.27212418e-02]
[ 4.27342016e-01 -1.04779774e-01]
[ 4.37333586e-01 -1.06015726e-01]
[ 4.54778403e-01 -6.91129800e-02]
[ 4.33582415e-01 -1.81400909e-01]
[ 4.75290298e-01 -6.70755731e-02]
[ 4.36148818e-01 -2.23325341e-01]
[ 2.88929307e-01 -4.08068444e-01]
[ 4.69656878e-01 -1.98802456e-01]
[ 3.86773022e-01 -3.47572480e-01]
[ 5.00808621e-01 -1.73466783e-01]
[ 4.24431168e-01 -3.33853536e-01]
[ 4.24040623e-01 -3.50270681e-01]
[ 2.80512809e-01 -4.84677794e-01]
[ 5.30623098e-01 -2.08180518e-01]
[ 4.11038907e-01 -4.09202904e-01]
[ 4.62137267e-01 -3.66782151e-01]
[ 3.95281189e-01 -4.51389834e-01]
[ 3.67613445e-01 -4.86785738e-01]
[ 4.94556678e-01 -3.73916691e-01]
[ 4.35611632e-01 -4.55129109e-01]
[ 3.28234290e-01 -5.49419922e-01]
[ 2.78014129e-01 -5.87544164e-01]
[ 3.04441838e-01 -5.85589590e-01]
[-3.84353415e-02 -6.68896647e-01]
[ 4.32663510e-01 -5.24597262e-01]
[ 3.67995385e-01 -5.83677477e-01]
[ 1.62513204e-01 -6.80874040e-01]
[-7.01870180e-02 -7.06522316e-01]
[ 1.20922889e-01 -7.09772960e-01]
[ 1.81179805e-01 -7.07159019e-01]
[ 4.37091210e-02 -7.38708002e-01]
[ 1.03491018e-01 -7.42825423e-01]
[-6.40535178e-03 -7.59973007e-01]
[-1.70362786e-01 -7.50917120e-01]
[ 1.37916439e-01 -7.67710268e-01]
[-1.54966573e-02 -7.89847994e-01]
[-4.01784946e-02 -7.98990418e-01]
[-2.98916375e-01 -7.52827338e-01]
[-1.88135777e-01 -7.98125886e-01]
[-1.16656887e-03 -8.29999180e-01]
[-2.86409662e-01 -7.89664173e-01]
[-1.45414274e-01 -8.37469217e-01]
[-3.08244630e-01 -8.02860666e-01]
[-1.13882388e-01 -8.62514233e-01]
[-1.97807480e-01 -8.57480146e-01]
[-7.28153677e-01 -5.11754065e-01]
[-5.28735647e-01 -7.28312169e-01]
[-1.95859389e-01 -8.88672662e-01]
[-6.55599644e-01 -6.45437144e-01]
[-4.36321749e-01 -8.21293694e-01]
[-3.97220716e-01 -8.51948181e-01]
[-7.40823025e-01 -5.94711060e-01]
[-7.83135284e-01 -5.55246906e-01]
[-6.83380750e-01 -6.88397233e-01]
[-6.79385602e-01 -7.06282665e-01]
[-5.87009159e-01 -7.97195238e-01]
[-0.00000000e+00 -0.00000000e+00]
[-8.37020532e-03 -5.47171481e-03]
[-1.83996723e-02 -7.83913638e-03]
[-2.38615572e-02 -1.81831264e-02]
[-2.71639022e-02 -2.93619212e-02]
[-4.67712876e-02 -1.76761608e-02]
[-5.13932286e-02 -3.09634633e-02]
[-6.72785631e-02 -1.93286044e-02]
[-7.58815643e-02 -2.53374861e-02]
[-8.87183570e-02 -1.51345014e-02]
[-8.95920555e-02 -4.44214317e-02]
[-1.05681839e-01 -3.05180093e-02]
[-1.12859458e-01 -4.07767431e-02]
[-1.29959580e-01 -3.24155301e-03]
[-1.39813880e-01 -7.21657356e-03]
[-1.45884653e-01 -3.48951009e-02]
[-1.58860475e-01 -1.90617278e-02]
[-1.69789215e-01 8.46300817e-03]
[-1.79892182e-01 6.22920756e-03]
[-1.82531530e-01 5.27469495e-02]
[-1.96787914e-01 3.57003745e-02]
[-1.78610713e-01 1.10445521e-01]
[-2.14327911e-01 4.96341281e-02]
[-2.23948847e-01 5.24110098e-02]
[-2.39863331e-01 8.09829103e-03]
[-2.45011435e-01 4.96930264e-02]
[-2.33354905e-01 1.14653775e-01]
[-2.50877484e-01 9.98022449e-02]
[-2.77281919e-01 3.89196302e-02]
[-2.70172382e-01 1.05389203e-01]
[-2.92387849e-01 6.71516607e-02]
[-2.73972209e-01 1.45049056e-01]
[-2.85671818e-01 1.44192969e-01]
[-3.21248452e-01 7.54945832e-02]
[-3.07272511e-01 1.45545883e-01]
[-2.39745837e-01 2.54993988e-01]
[-2.48112212e-01 2.60845415e-01]
[-1.51453586e-01 3.37582303e-01]
[-2.67068820e-01 2.70322484e-01]
[-2.89146403e-01 2.61714267e-01]
[-1.38404573e-01 3.75292118e-01]
[-2.47585805e-01 3.26804635e-01]
[-2.18612860e-01 3.58620158e-01]
[-1.25241675e-01 4.11356929e-01]
[-1.95264902e-01 3.94298895e-01]
[-1.46778711e-01 4.25389245e-01]
[-1.04988763e-01 4.47858638e-01]
[-2.68359413e-01 3.85853892e-01]
[-1.93548031e-01 4.39248403e-01]
[-8.77361039e-02 4.82081296e-01]
[-1.39120457e-01 4.80255659e-01]
[-8.31852186e-02 5.03170169e-01]
[-7.48872073e-02 5.14579349e-01]
[-5.70803544e-03 5.29969262e-01]
[-1.58687180e-02 5.39766786e-01]
[-1.49002564e-01 5.29431994e-01]
[-1.98435901e-03 5.59996484e-01]
[-2.08488575e-01 5.30502134e-01]
[ 5.65227669e-02 5.77239272e-01]
[-2.89358604e-02 5.89290010e-01]
[ 1.25355588e-01 5.86758874e-01]
[ 6.26495621e-02 6.06774285e-01]
[ 3.01335207e-01 5.41846005e-01]
[ 2.75775684e-01 5.66434261e-01]
[ 1.32352421e-01 6.26165183e-01]
[ 3.47342639e-01 5.49411586e-01]
[ 3.13540745e-01 5.80768630e-01]
[-1.44067966e-01 6.54327457e-01]
[ 5.02643212e-01 4.57984499e-01]
[ 3.70493915e-01 5.82094716e-01]
[ 3.49928740e-01 6.06258919e-01]
[ 4.93501165e-01 5.10447451e-01]
[ 3.63212538e-01 6.21672464e-01]
[ 6.62295009e-01 3.07026581e-01]
[ 3.18207549e-01 6.68089781e-01]
[ 4.46081933e-01 6.02918659e-01]
[ 6.18103628e-01 4.42207988e-01]
[ 6.07553924e-01 4.73052037e-01]
[ 5.05555512e-01 5.93981165e-01]
[ 2.72157190e-01 7.41640387e-01]
[ 5.70347440e-01 5.60984668e-01]
[ 6.05299121e-01 5.38249918e-01]
[ 6.46809095e-01 5.04021819e-01]
[ 7.10706203e-01 4.28715165e-01]
[ 7.35105615e-01 4.06472305e-01]
[ 6.74187768e-01 5.17659012e-01]
[ 8.14670518e-01 2.75521227e-01]
[ 8.45117378e-01 2.06583196e-01]
[ 6.40506934e-01 6.03449142e-01]
[ 8.66120175e-01 2.04782425e-01]
[ 8.96181358e-01 8.28189190e-02]
[ 8.33525542e-01 3.65150888e-01]
[ 9.18575330e-01 5.11797183e-02]
[ 9.22965423e-01 1.14170172e-01]
[ 9.36555210e-01 8.04011093e-02]
[ 9.08727618e-01 2.76973131e-01]
[ 9.50896208e-01 -1.31895416e-01]
[ 9.17784255e-01 3.13961879e-01]
[ 9.33629259e-01 -2.97886567e-01]
[ 9.88995443e-01 -4.45871499e-02]
[ 0.00000000e+00 -0.00000000e+00]
[ 9.79880762e-03 -1.99583797e-03]
[ 1.99940729e-02 4.86876255e-04]
[ 2.61662525e-02 -1.46740326e-02]
[ 3.84511087e-02 -1.10232589e-02]
[ 4.95890467e-02 -6.39737823e-03]
[ 5.99771637e-02 -1.65524386e-03]
[ 6.21697472e-02 -3.21702119e-02]
[ 7.43436411e-02 -2.95469632e-02]
[ 8.06974118e-02 -3.98488109e-02]
[ 8.14343751e-02 -5.80382853e-02]
[ 9.22001492e-02 -5.99927703e-02]
[ 9.37395802e-02 -7.49192306e-02]
[ 9.60790524e-02 -8.75717745e-02]
[ 1.03261762e-01 -9.45357529e-02]
[ 1.06370805e-01 -1.05760351e-01]
[ 1.37466576e-01 -8.18714884e-02]
[ 1.01082273e-01 -1.36683481e-01]
[ 1.10151505e-01 -1.42360971e-01]
[ 9.16540054e-02 -1.66431798e-01]
[ 1.29709392e-01 -1.52234929e-01]
[ 3.26863586e-02 -2.07440599e-01]
[ 1.37426305e-01 -1.71796422e-01]
[ 1.50106931e-01 -1.74263907e-01]
[ 1.32737490e-01 -1.99951891e-01]
[ 1.34246622e-01 -2.10897711e-01]
[ 9.69068794e-02 -2.41265532e-01]
[ 2.21194549e-02 -2.69092419e-01]
[ 1.03753054e-01 -2.60067883e-01]
[ 3.87351563e-02 -2.87401440e-01]
[ 1.55776942e-01 -2.56385539e-01]
[ 5.05915167e-02 -3.05843912e-01]
[ 4.49940024e-02 -3.16820990e-01]
[-1.74914339e-02 -3.29536113e-01]
[ 4.91708033e-03 -3.39964443e-01]
[-4.84363632e-02 -3.46632253e-01]
[ 1.94660931e-02 -3.59473325e-01]
[-1.00885395e-01 -3.55980529e-01]
[ 2.46037929e-03 -3.79992035e-01]
[-4.68117997e-02 -3.87180391e-01]
[-9.43358132e-02 -3.88716805e-01]
[-3.36669158e-02 -4.08615392e-01]
[-1.88529982e-01 -3.75308468e-01]
[-5.06039869e-02 -4.27011986e-01]
[-1.62425981e-01 -4.08922732e-01]
[-2.12648211e-01 -3.96586357e-01]
[-8.03418780e-02 -4.52929556e-01]
[-1.52934025e-01 -4.44422304e-01]
[-2.65614079e-01 -3.99811407e-01]
[-2.74970348e-01 -4.05575280e-01]
[-2.34573442e-01 -4.41560076e-01]
[-3.78612687e-01 -3.41690551e-01]
[-3.34460262e-01 -3.98166213e-01]
[-2.65222370e-01 -4.58865007e-01]
[-3.32119042e-01 -4.25789786e-01]
[-4.12557377e-01 -3.63725735e-01]
[-4.02002778e-01 -3.89863779e-01]
[-3.04770028e-01 -4.81679593e-01]
[-4.50761110e-01 -3.64985508e-01]
[-5.66743279e-01 -1.64018461e-01]
[-4.15392351e-01 -4.32954033e-01]
[-4.91811803e-01 -3.60861678e-01]
[-5.48909912e-01 -2.88267079e-01]
[-5.87812164e-01 -2.26664641e-01]
[-5.79543657e-01 -2.71531122e-01]
[-5.86563807e-01 -2.80076598e-01]
[-6.00326136e-01 -2.74241738e-01]
[-5.67589542e-01 -3.56008584e-01]
[-6.11252234e-01 -2.97944133e-01]
[-6.62681829e-01 -1.92231095e-01]
[-6.99434809e-01 2.81237977e-02]
[-6.63528162e-01 -2.52646745e-01]
[-7.07437708e-01 -1.33910002e-01]
[-6.95635703e-01 -2.21339037e-01]
[-7.39522482e-01 -2.65800316e-02]
[-7.41761456e-01 1.10860010e-01]
[-7.31852772e-01 2.04918327e-01]
[-7.62821244e-01 1.04898761e-01]
[-7.78110642e-01 5.42570570e-02]
[-7.10909578e-01 3.44539651e-01]
[-7.84879826e-01 1.54801999e-01]
[-7.75097682e-01 2.35209657e-01]
[-8.09110142e-01 1.33194513e-01]
[-7.12117393e-01 4.26367000e-01]
[-7.84729601e-01 2.99665570e-01]
[-8.49662414e-01 -2.39537407e-02]
[-7.11847617e-01 4.82569136e-01]
[-7.47488621e-01 4.45152515e-01]
[-7.42609239e-01 4.72156244e-01]
[-8.43868125e-01 2.82819001e-01]
[-6.50907748e-01 6.21545737e-01]
[-8.11713922e-01 4.11364205e-01]
[-8.46795454e-01 3.59635174e-01]
[-8.35929064e-01 4.07581402e-01]
[-9.22350700e-01 1.81298612e-01]
[-4.56405752e-01 8.33182927e-01]
[-6.30262690e-01 7.24133235e-01]
[-5.75705887e-01 7.80680941e-01]
[-2.06074574e-01 9.58088341e-01]
[-5.97431027e-01 7.89415080e-01]]
[[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[1 0 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]]
此时,x是输入数据, t是监督标签。观察x和t的形状,可知它们各自有300笔样本数据,其中x是二维数据,t是三维数据。另外,t是one-hot向量,对应的正确解标签的类标记为1,其余的标记为0。下面,我们把这些数据绘制在图上,结果如图1-31所示。
# Scatter-plot the data points, one marker per class. The samples are stored
# class by class (see the label dump above), so rows [i*N, (i+1)*N) are class i.
N = 100       # samples per class
CLS_NUM = 3   # number of classes
markers = ['o', 'x', '^']
for i in range(CLS_NUM):
    rows = slice(i * N, (i + 1) * N)
    plt.scatter(x[rows, 0], x[rows, 1], s=40, marker=markers[i])
plt.show()
如图1-31所示,输入是二维数据,类别数是3。观察这个数据集可知,它不能被直线分割。因此,我们需要学习非线性的分割线。那么,我们的神经网络(具有使用非线性的sigmoid激活函数的隐藏层的神经网络)能否正确学习这种非线性模式呢?让我们实验一下。
Tip:数据集划分
因为这个实验相对简单,所以我们不把数据集分成训练数据、验证数据和测试数据。不过,实际任务中会将数据集分为训练数据和测试数据(以及验证数据)来进行学习和评估。
import sys
sys.path.append('..')
import numpy as np
from common.layers import Affine,Sigmoid,SoftmaxWithLoss
class TwoLayerNet:
    """Two-layer fully connected classifier: Affine -> Sigmoid -> Affine.

    ``predict`` returns the raw scores of the last Affine layer; ``forward``
    additionally feeds those scores to a SoftmaxWithLoss layer and returns
    the loss.

    Attributes:
        layers: the hidden Affine, the Sigmoid and the output Affine layer,
            in forward order.
        loss_layer: SoftmaxWithLoss applied on top of the scores.
        params: every layer's ``params`` entries, concatenated in layer order.
        grads: every layer's ``grads`` entries, in the same order as ``params``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the layers with small random weights and zero biases.

        Args:
            input_size: number of input-layer neurons.
            hidden_size: number of hidden-layer neurons.
            output_size: number of output-layer neurons (classes).
        """
        n_in, n_hidden, n_out = input_size, hidden_size, output_size

        # Small random weights (scaled by 0.01) and zero biases. The hidden
        # weights are drawn before the output weights so that a seeded run
        # produces the same initial parameters as before.
        hidden_W = 0.01 * np.random.randn(n_in, n_hidden)   # (input_size, hidden_size)
        hidden_b = np.zeros(n_hidden)
        out_W = 0.01 * np.random.randn(n_hidden, n_out)     # (hidden_size, output_size)
        out_b = np.zeros(n_out)

        self.layers = [
            Affine(hidden_W, hidden_b),
            Sigmoid(),
            Affine(out_W, out_b),
        ]
        self.loss_layer = SoftmaxWithLoss()

        # Flatten every layer's parameters and gradients into two parallel
        # lists so an optimizer can walk them in lockstep. The entries are the
        # layers' own objects (not copies); presumably each layer's backward
        # writes its gradients into them in place — confirm in common.layers.
        self.params = [p for layer in self.layers for p in layer.params]
        self.grads = [g for layer in self.layers for g in layer.grads]

    def predict(self, x):
        """Run ``x`` through every layer and return the raw class scores."""
        out = x
        for layer in self.layers:
            out = layer.forward(out)
        return out

    def forward(self, x, t):
        """Return the loss of the scores for ``x`` against the labels ``t``."""
        return self.loss_layer.forward(self.predict(x), t)

    def backward(self, dout=1):
        """Back-propagate from the loss through every layer in reverse order.

        Returns the value produced by the first layer's ``backward``, i.e. the
        gradient flowing back to the network input.
        """
        grad = self.loss_layer.backward(dout)
        for layer in self.layers[::-1]:
            grad = layer.backward(grad)
        return grad
初始化程序接收3个参数。input_size是输入层的神经元数,hidden_size是隐藏层的神经元数,output_size是输出层的神经元数。在内部实现中,首先用零向量(np.zeros())初始化偏置,再用小的随机数(0.01 * np.random.randn())初始化权重。通过将权重设成小的随机数,学习可以更容易地进行。接着,生成必要的层,并将它们整理到实例变量layers列表中。最后,将这个模型使用到的参数和梯度归纳在一起,这样优化器就可以通过params和grads这两个列表统一更新所有参数。
学习使用的代码
# 学习使用的代码
import sys
sys.path.append('..')
import numpy as np
from common.optimizer import SGD
from dataset import spiral
import matplotlib.pyplot as plt
# from two_layer_net import TwoLayerNet
# Hyperparameters
max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0
eval_interval = 10  # print and record the average loss every this many iterations

# Load the data and build the model and the optimizer
x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# Bookkeeping for training
data_size = len(x)
# Number of mini-batches per epoch; any remainder samples are skipped that epoch.
max_iters = data_size // batch_size
total_loss = 0
loss_count = 0
loss_list = []

for epoch in range(max_epoch):
    # Shuffle once per epoch, then take mini-batches sequentially.
    idx = np.random.permutation(data_size)
    x = x[idx]  # data samples
    t = t[idx]  # data labels

    for iters in range(max_iters):
        # Basic slicing of a 2-D array always yields a 2-D array (possibly with
        # zero rows), so no reshape to 2-D is needed here.
        batch_x = x[iters * batch_size:(iters + 1) * batch_size]
        batch_t = t[iters * batch_size:(iters + 1) * batch_size]

        # Loss, gradients, parameter update.
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)

        total_loss += loss
        loss_count += 1

        # Periodically print and record the average loss since the last report.
        if (iters + 1) % eval_interval == 0:
            avg_loss = total_loss / loss_count
            print('|epoch %d | iter %d / %d | loss %.2f'
                  % (epoch + 1, iters + 1, max_iters, avg_loss))
            loss_list.append(avg_loss)
            total_loss, loss_count = 0, 0
|epoch 1 | iter 10 / 10 | loss 1.13
|epoch 2 | iter 10 / 10 | loss 1.13
|epoch 3 | iter 10 / 10 | loss 1.12
|epoch 4 | iter 10 / 10 | loss 1.12
|epoch 5 | iter 10 / 10 | loss 1.11
|epoch 6 | iter 10 / 10 | loss 1.14
|epoch 7 | iter 10 / 10 | loss 1.16
|epoch 8 | iter 10 / 10 | loss 1.11
|epoch 9 | iter 10 / 10 | loss 1.12
|epoch 10 | iter 10 / 10 | loss 1.13
|epoch 11 | iter 10 / 10 | loss 1.12
|epoch 12 | iter 10 / 10 | loss 1.11
|epoch 13 | iter 10 / 10 | loss 1.09
|epoch 14 | iter 10 / 10 | loss 1.08
|epoch 15 | iter 10 / 10 | loss 1.04
|epoch 16 | iter 10 / 10 | loss 1.03
|epoch 17 | iter 10 / 10 | loss 0.96
|epoch 18 | iter 10 / 10 | loss 0.92
|epoch 19 | iter 10 / 10 | loss 0.92
|epoch 20 | iter 10 / 10 | loss 0.87
|epoch 21 | iter 10 / 10 | loss 0.85
|epoch 22 | iter 10 / 10 | loss 0.82
|epoch 23 | iter 10 / 10 | loss 0.79
|epoch 24 | iter 10 / 10 | loss 0.78
|epoch 25 | iter 10 / 10 | loss 0.82
|epoch 26 | iter 10 / 10 | loss 0.78
|epoch 27 | iter 10 / 10 | loss 0.76
|epoch 28 | iter 10 / 10 | loss 0.76
|epoch 29 | iter 10 / 10 | loss 0.78
|epoch 30 | iter 10 / 10 | loss 0.75
|epoch 31 | iter 10 / 10 | loss 0.78
|epoch 32 | iter 10 / 10 | loss 0.77
|epoch 33 | iter 10 / 10 | loss 0.77
|epoch 34 | iter 10 / 10 | loss 0.78
|epoch 35 | iter 10 / 10 | loss 0.75
|epoch 36 | iter 10 / 10 | loss 0.74
|epoch 37 | iter 10 / 10 | loss 0.76
|epoch 38 | iter 10 / 10 | loss 0.76
|epoch 39 | iter 10 / 10 | loss 0.73
|epoch 40 | iter 10 / 10 | loss 0.75
|epoch 41 | iter 10 / 10 | loss 0.76
|epoch 42 | iter 10 / 10 | loss 0.76
|epoch 43 | iter 10 / 10 | loss 0.76
|epoch 44 | iter 10 / 10 | loss 0.74
|epoch 45 | iter 10 / 10 | loss 0.75
|epoch 46 | iter 10 / 10 | loss 0.73
|epoch 47 | iter 10 / 10 | loss 0.72
|epoch 48 | iter 10 / 10 | loss 0.73
|epoch 49 | iter 10 / 10 | loss 0.72
|epoch 50 | iter 10 / 10 | loss 0.72
|epoch 51 | iter 10 / 10 | loss 0.72
|epoch 52 | iter 10 / 10 | loss 0.72
|epoch 53 | iter 10 / 10 | loss 0.74
|epoch 54 | iter 10 / 10 | loss 0.74
|epoch 55 | iter 10 / 10 | loss 0.72
|epoch 56 | iter 10 / 10 | loss 0.72
|epoch 57 | iter 10 / 10 | loss 0.71
|epoch 58 | iter 10 / 10 | loss 0.70
|epoch 59 | iter 10 / 10 | loss 0.72
|epoch 60 | iter 10 / 10 | loss 0.70
|epoch 61 | iter 10 / 10 | loss 0.71
|epoch 62 | iter 10 / 10 | loss 0.72
|epoch 63 | iter 10 / 10 | loss 0.70
|epoch 64 | iter 10 / 10 | loss 0.71
|epoch 65 | iter 10 / 10 | loss 0.73
|epoch 66 | iter 10 / 10 | loss 0.70
|epoch 67 | iter 10 / 10 | loss 0.71
|epoch 68 | iter 10 / 10 | loss 0.69
|epoch 69 | iter 10 / 10 | loss 0.70
|epoch 70 | iter 10 / 10 | loss 0.71
|epoch 71 | iter 10 / 10 | loss 0.68
|epoch 72 | iter 10 / 10 | loss 0.69
|epoch 73 | iter 10 / 10 | loss 0.67
|epoch 74 | iter 10 / 10 | loss 0.68
|epoch 75 | iter 10 / 10 | loss 0.67
|epoch 76 | iter 10 / 10 | loss 0.66
|epoch 77 | iter 10 / 10 | loss 0.69
|epoch 78 | iter 10 / 10 | loss 0.64
|epoch 79 | iter 10 / 10 | loss 0.68
|epoch 80 | iter 10 / 10 | loss 0.64
|epoch 81 | iter 10 / 10 | loss 0.64
|epoch 82 | iter 10 / 10 | loss 0.66
|epoch 83 | iter 10 / 10 | loss 0.62
|epoch 84 | iter 10 / 10 | loss 0.62
|epoch 85 | iter 10 / 10 | loss 0.61
|epoch 86 | iter 10 / 10 | loss 0.60
|epoch 87 | iter 10 / 10 | loss 0.60
|epoch 88 | iter 10 / 10 | loss 0.61
|epoch 89 | iter 10 / 10 | loss 0.59
|epoch 90 | iter 10 / 10 | loss 0.58
|epoch 91 | iter 10 / 10 | loss 0.56
|epoch 92 | iter 10 / 10 | loss 0.56
|epoch 93 | iter 10 / 10 | loss 0.54
|epoch 94 | iter 10 / 10 | loss 0.53
|epoch 95 | iter 10 / 10 | loss 0.53
|epoch 96 | iter 10 / 10 | loss 0.52
|epoch 97 | iter 10 / 10 | loss 0.51
|epoch 98 | iter 10 / 10 | loss 0.50
|epoch 99 | iter 10 / 10 | loss 0.48
|epoch 100 | iter 10 / 10 | loss 0.48
|epoch 101 | iter 10 / 10 | loss 0.46
|epoch 102 | iter 10 / 10 | loss 0.45
|epoch 103 | iter 10 / 10 | loss 0.45
|epoch 104 | iter 10 / 10 | loss 0.44
|epoch 105 | iter 10 / 10 | loss 0.44
|epoch 106 | iter 10 / 10 | loss 0.41
|epoch 107 | iter 10 / 10 | loss 0.40
|epoch 108 | iter 10 / 10 | loss 0.41
|epoch 109 | iter 10 / 10 | loss 0.40
|epoch 110 | iter 10 / 10 | loss 0.40
|epoch 111 | iter 10 / 10 | loss 0.38
|epoch 112 | iter 10 / 10 | loss 0.38
|epoch 113 | iter 10 / 10 | loss 0.36
|epoch 114 | iter 10 / 10 | loss 0.37
|epoch 115 | iter 10 / 10 | loss 0.35
|epoch 116 | iter 10 / 10 | loss 0.34
|epoch 117 | iter 10 / 10 | loss 0.34
|epoch 118 | iter 10 / 10 | loss 0.34
|epoch 119 | iter 10 / 10 | loss 0.33
|epoch 120 | iter 10 / 10 | loss 0.34
|epoch 121 | iter 10 / 10 | loss 0.32
|epoch 122 | iter 10 / 10 | loss 0.32
|epoch 123 | iter 10 / 10 | loss 0.31
|epoch 124 | iter 10 / 10 | loss 0.31
|epoch 125 | iter 10 / 10 | loss 0.30
|epoch 126 | iter 10 / 10 | loss 0.30
|epoch 127 | iter 10 / 10 | loss 0.28
|epoch 128 | iter 10 / 10 | loss 0.28
|epoch 129 | iter 10 / 10 | loss 0.28
|epoch 130 | iter 10 / 10 | loss 0.28
|epoch 131 | iter 10 / 10 | loss 0.27
|epoch 132 | iter 10 / 10 | loss 0.27
|epoch 133 | iter 10 / 10 | loss 0.27
|epoch 134 | iter 10 / 10 | loss 0.27
|epoch 135 | iter 10 / 10 | loss 0.27
|epoch 136 | iter 10 / 10 | loss 0.26
|epoch 137 | iter 10 / 10 | loss 0.26
|epoch 138 | iter 10 / 10 | loss 0.26
|epoch 139 | iter 10 / 10 | loss 0.25
|epoch 140 | iter 10 / 10 | loss 0.24
|epoch 141 | iter 10 / 10 | loss 0.24
|epoch 142 | iter 10 / 10 | loss 0.25
|epoch 143 | iter 10 / 10 | loss 0.24
|epoch 144 | iter 10 / 10 | loss 0.24
|epoch 145 | iter 10 / 10 | loss 0.23
|epoch 146 | iter 10 / 10 | loss 0.24
|epoch 147 | iter 10 / 10 | loss 0.23
|epoch 148 | iter 10 / 10 | loss 0.23
|epoch 149 | iter 10 / 10 | loss 0.22
|epoch 150 | iter 10 / 10 | loss 0.22
|epoch 151 | iter 10 / 10 | loss 0.22
|epoch 152 | iter 10 / 10 | loss 0.22
|epoch 153 | iter 10 / 10 | loss 0.22
|epoch 154 | iter 10 / 10 | loss 0.22
|epoch 155 | iter 10 / 10 | loss 0.22
|epoch 156 | iter 10 / 10 | loss 0.21
|epoch 157 | iter 10 / 10 | loss 0.21
|epoch 158 | iter 10 / 10 | loss 0.20
|epoch 159 | iter 10 / 10 | loss 0.21
|epoch 160 | iter 10 / 10 | loss 0.20
|epoch 161 | iter 10 / 10 | loss 0.20
|epoch 162 | iter 10 / 10 | loss 0.20
|epoch 163 | iter 10 / 10 | loss 0.21
|epoch 164 | iter 10 / 10 | loss 0.20
|epoch 165 | iter 10 / 10 | loss 0.20
|epoch 166 | iter 10 / 10 | loss 0.19
|epoch 167 | iter 10 / 10 | loss 0.19
|epoch 168 | iter 10 / 10 | loss 0.19
|epoch 169 | iter 10 / 10 | loss 0.19
|epoch 170 | iter 10 / 10 | loss 0.19
|epoch 171 | iter 10 / 10 | loss 0.19
|epoch 172 | iter 10 / 10 | loss 0.18
|epoch 173 | iter 10 / 10 | loss 0.18
|epoch 174 | iter 10 / 10 | loss 0.18
|epoch 175 | iter 10 / 10 | loss 0.18
|epoch 176 | iter 10 / 10 | loss 0.18
|epoch 177 | iter 10 / 10 | loss 0.18
|epoch 178 | iter 10 / 10 | loss 0.18
|epoch 179 | iter 10 / 10 | loss 0.17
|epoch 180 | iter 10 / 10 | loss 0.17
|epoch 181 | iter 10 / 10 | loss 0.18
|epoch 182 | iter 10 / 10 | loss 0.17
|epoch 183 | iter 10 / 10 | loss 0.18
|epoch 184 | iter 10 / 10 | loss 0.17
|epoch 185 | iter 10 / 10 | loss 0.17
|epoch 186 | iter 10 / 10 | loss 0.18
|epoch 187 | iter 10 / 10 | loss 0.17
|epoch 188 | iter 10 / 10 | loss 0.17
|epoch 189 | iter 10 / 10 | loss 0.17
|epoch 190 | iter 10 / 10 | loss 0.17
|epoch 191 | iter 10 / 10 | loss 0.16
|epoch 192 | iter 10 / 10 | loss 0.17
|epoch 193 | iter 10 / 10 | loss 0.16
|epoch 194 | iter 10 / 10 | loss 0.16
|epoch 195 | iter 10 / 10 | loss 0.16
|epoch 196 | iter 10 / 10 | loss 0.16
|epoch 197 | iter 10 / 10 | loss 0.16
|epoch 198 | iter 10 / 10 | loss 0.15
|epoch 199 | iter 10 / 10 | loss 0.16
|epoch 200 | iter 10 / 10 | loss 0.16
|epoch 201 | iter 10 / 10 | loss 0.15
|epoch 202 | iter 10 / 10 | loss 0.16
|epoch 203 | iter 10 / 10 | loss 0.16
|epoch 204 | iter 10 / 10 | loss 0.15
|epoch 205 | iter 10 / 10 | loss 0.16
|epoch 206 | iter 10 / 10 | loss 0.15
|epoch 207 | iter 10 / 10 | loss 0.15
|epoch 208 | iter 10 / 10 | loss 0.15
|epoch 209 | iter 10 / 10 | loss 0.15
|epoch 210 | iter 10 / 10 | loss 0.15
|epoch 211 | iter 10 / 10 | loss 0.15
|epoch 212 | iter 10 / 10 | loss 0.15
|epoch 213 | iter 10 / 10 | loss 0.15
|epoch 214 | iter 10 / 10 | loss 0.15
|epoch 215 | iter 10 / 10 | loss 0.15
|epoch 216 | iter 10 / 10 | loss 0.14
|epoch 217 | iter 10 / 10 | loss 0.14
|epoch 218 | iter 10 / 10 | loss 0.15
|epoch 219 | iter 10 / 10 | loss 0.14
|epoch 220 | iter 10 / 10 | loss 0.14
|epoch 221 | iter 10 / 10 | loss 0.14
|epoch 222 | iter 10 / 10 | loss 0.14
|epoch 223 | iter 10 / 10 | loss 0.14
|epoch 224 | iter 10 / 10 | loss 0.14
|epoch 225 | iter 10 / 10 | loss 0.14
|epoch 226 | iter 10 / 10 | loss 0.14
|epoch 227 | iter 10 / 10 | loss 0.14
|epoch 228 | iter 10 / 10 | loss 0.14
|epoch 229 | iter 10 / 10 | loss 0.13
|epoch 230 | iter 10 / 10 | loss 0.14
|epoch 231 | iter 10 / 10 | loss 0.13
|epoch 232 | iter 10 / 10 | loss 0.14
|epoch 233 | iter 10 / 10 | loss 0.13
|epoch 234 | iter 10 / 10 | loss 0.13
|epoch 235 | iter 10 / 10 | loss 0.13
|epoch 236 | iter 10 / 10 | loss 0.13
|epoch 237 | iter 10 / 10 | loss 0.14
|epoch 238 | iter 10 / 10 | loss 0.13
|epoch 239 | iter 10 / 10 | loss 0.13
|epoch 240 | iter 10 / 10 | loss 0.14
|epoch 241 | iter 10 / 10 | loss 0.13
|epoch 242 | iter 10 / 10 | loss 0.13
|epoch 243 | iter 10 / 10 | loss 0.13
|epoch 244 | iter 10 / 10 | loss 0.13
|epoch 245 | iter 10 / 10 | loss 0.13
|epoch 246 | iter 10 / 10 | loss 0.13
|epoch 247 | iter 10 / 10 | loss 0.13
|epoch 248 | iter 10 / 10 | loss 0.13
|epoch 249 | iter 10 / 10 | loss 0.13
|epoch 250 | iter 10 / 10 | loss 0.13
|epoch 251 | iter 10 / 10 | loss 0.13
|epoch 252 | iter 10 / 10 | loss 0.12
|epoch 253 | iter 10 / 10 | loss 0.12
|epoch 254 | iter 10 / 10 | loss 0.12
|epoch 255 | iter 10 / 10 | loss 0.12
|epoch 256 | iter 10 / 10 | loss 0.12
|epoch 257 | iter 10 / 10 | loss 0.12
|epoch 258 | iter 10 / 10 | loss 0.12
|epoch 259 | iter 10 / 10 | loss 0.13
|epoch 260 | iter 10 / 10 | loss 0.12
|epoch 261 | iter 10 / 10 | loss 0.13
|epoch 262 | iter 10 / 10 | loss 0.12
|epoch 263 | iter 10 / 10 | loss 0.12
|epoch 264 | iter 10 / 10 | loss 0.13
|epoch 265 | iter 10 / 10 | loss 0.12
|epoch 266 | iter 10 / 10 | loss 0.12
|epoch 267 | iter 10 / 10 | loss 0.12
|epoch 268 | iter 10 / 10 | loss 0.12
|epoch 269 | iter 10 / 10 | loss 0.11
|epoch 270 | iter 10 / 10 | loss 0.12
|epoch 271 | iter 10 / 10 | loss 0.12
|epoch 272 | iter 10 / 10 | loss 0.12
|epoch 273 | iter 10 / 10 | loss 0.12
|epoch 274 | iter 10 / 10 | loss 0.12
|epoch 275 | iter 10 / 10 | loss 0.11
|epoch 276 | iter 10 / 10 | loss 0.12
|epoch 277 | iter 10 / 10 | loss 0.12
|epoch 278 | iter 10 / 10 | loss 0.11
|epoch 279 | iter 10 / 10 | loss 0.11
|epoch 280 | iter 10 / 10 | loss 0.11
|epoch 281 | iter 10 / 10 | loss 0.11
|epoch 282 | iter 10 / 10 | loss 0.12
|epoch 283 | iter 10 / 10 | loss 0.11
|epoch 284 | iter 10 / 10 | loss 0.11
|epoch 285 | iter 10 / 10 | loss 0.11
|epoch 286 | iter 10 / 10 | loss 0.11
|epoch 287 | iter 10 / 10 | loss 0.11
|epoch 288 | iter 10 / 10 | loss 0.12
|epoch 289 | iter 10 / 10 | loss 0.11
|epoch 290 | iter 10 / 10 | loss 0.11
|epoch 291 | iter 10 / 10 | loss 0.11
|epoch 292 | iter 10 / 10 | loss 0.11
|epoch 293 | iter 10 / 10 | loss 0.11
|epoch 294 | iter 10 / 10 | loss 0.11
|epoch 295 | iter 10 / 10 | loss 0.12
|epoch 296 | iter 10 / 10 | loss 0.11
|epoch 297 | iter 10 / 10 | loss 0.12
|epoch 298 | iter 10 / 10 | loss 0.11
|epoch 299 | iter 10 / 10 | loss 0.11
|epoch 300 | iter 10 / 10 | loss 0.11
Tip:epoch
Epoch表示学习的单位。1个epoch相当于模型“看过”一遍所有的学习数据(遍历数据集)。这里我们进行300个epoch的学习。
Tip:数据打乱
在进行学习时,需要随机选择数据作为mini-batch。这里,我们以epoch为单位打乱数据,对于打乱后的数据,按顺序从头开始抽取数据。数据的打乱(准确地说,是数据索引的打乱)使用np.random.permutation()方法。给定参数N,该方法可以返回0到N-1的随机序列,其实际的使用示例如下所示。
import numpy as np
np.random.permutation(10)  # a random ordering of the integers 0..9
array([5, 1, 8, 4, 9, 7, 0, 2, 6, 3])
np.random.permutation(10)  # calling again gives a different ordering
array([3, 4, 2, 7, 8, 5, 6, 0, 9, 1])
像这样,调用np.random.permutation()可以随机打乱数据的索引。
# Inspect the recorded curve: one average loss per 10 training iterations
# (300 entries for this run).
loss_list
[1.1256062166823237,
1.1255202354489933,
1.1162613752115285,
1.1162867078413503,
1.1123000112951948,
1.1384639824108038,
1.1590961883070312,
1.1086316143023154,
1.1173305676924539,
1.1287957712269248,
1.1168438089353867,
1.108338779101816,
1.0876149200499459,
1.076681386581935,
1.0442376735950387,
1.0345782626337772,
0.9572932039643971,
0.918385321087945,
0.9241491096212101,
0.8685139076509195,
0.849380704784154,
0.8171629191788113,
0.7924414711357766,
0.7826646392986113,
0.8235432039035636,
0.7754573601774306,
0.7557857636797779,
0.7644773546985875,
0.783489908441849,
0.7507895610696304,
0.7773067036165259,
0.7650839562418821,
0.7727897179944694,
0.7819402998382252,
0.7479802970891092,
0.7449918634368045,
0.7560347486336814,
0.762136567723541,
0.7308895411004578,
0.7530268898576871,
0.7598416342022494,
0.7594443798911804,
0.7609245612331736,
0.7385235003122192,
0.7483287079215573,
0.732212322256757,
0.7226947264484566,
0.7329453633807874,
0.7228591003722222,
0.7225109819294906,
0.7151355271138069,
0.7195462325887142,
0.7375188235491987,
0.7361580823220941,
0.7224648808897995,
0.7182891638960814,
0.7074840271194414,
0.7004036126944774,
0.7172821745385198,
0.7014167269154442,
0.7139798052248814,
0.7158390973991744,
0.70241458479173,
0.7147655630827306,
0.7258385981107364,
0.6991952628756113,
0.7149037812469097,
0.6923793329208154,
0.6950715496129248,
0.7051743201139319,
0.6818892201279282,
0.6931081236194652,
0.6678843529136961,
0.6795690012596001,
0.6696918781908965,
0.6601032915888689,
0.6948944014779308,
0.644871400994823,
0.6797970357896568,
0.6389928717097755,
0.6352100394457484,
0.6642182679001878,
0.6194764020308783,
0.6229674573767083,
0.6125107706517634,
0.5971911133520729,
0.5988198466130645,
0.6083885307835829,
0.5881132092407987,
0.5773251645421092,
0.5573829192639743,
0.5604590214223,
0.5448990532547215,
0.5286015850570895,
0.5299314123468276,
0.5163065862636679,
0.5124308579159358,
0.5016440362124143,
0.4844736823098785,
0.48000592169629375,
0.4609934880592183,
0.45402197600669847,
0.45357141775905935,
0.4429929465269392,
0.43615541714274964,
0.41049348011889714,
0.4031641890220195,
0.4076112682225584,
0.4036238903780629,
0.39951469523134286,
0.37921674077822315,
0.3750471397806644,
0.36229498494049545,
0.3675710439808242,
0.34616510526107874,
0.344898000290116,
0.33528933507001024,
0.33807418815033274,
0.3283438143916972,
0.33519658210523795,
0.31583944005785075,
0.3174067835957746,
0.30558094499258315,
0.31306743220376837,
0.3001822128388422,
0.298741892547424,
0.28380294707062764,
0.2811840756517869,
0.28183411256599333,
0.2777145993871101,
0.27238825698634855,
0.26707712316069976,
0.27094530362207153,
0.27057245386599776,
0.27030885972727675,
0.26437260107230176,
0.2637830554224593,
0.2560501993977538,
0.2521541449844564,
0.2429111293445622,
0.23945524042724556,
0.24542262746858579,
0.2366753717699007,
0.2403193106192818,
0.23361288810496253,
0.23650963179911216,
0.23025407656786476,
0.22932871901830215,
0.2240169314658727,
0.22394676535578756,
0.22310757814713048,
0.22264097958236592,
0.21697683471580703,
0.21755951288880998,
0.21507058662677148,
0.20835755321509958,
0.20913544783239285,
0.20480456033012645,
0.20860699412184833,
0.19815557489605248,
0.20435571889234488,
0.19724020746228565,
0.21198898405557265,
0.19990055286420488,
0.19670798433788012,
0.19493732434411776,
0.18861275602513147,
0.1871358291328788,
0.18592571707387812,
0.18837376894472496,
0.18999546718467414,
0.18273072019428158,
0.18441370204922142,
0.1833058740625836,
0.1802345548091979,
0.17511064852464456,
0.17767416217643664,
0.1766872863613579,
0.17293847154024714,
0.17328551329855016,
0.17653792738499427,
0.1728261261811141,
0.17626893090372214,
0.17434628901735752,
0.16949071085989292,
0.17755351335716013,
0.17285695698153453,
0.16648514397547962,
0.16594153132977602,
0.16931139951263377,
0.16301106121602332,
0.16958890572080493,
0.161906824311852,
0.15939537909039864,
0.16286032857961993,
0.1612419736005516,
0.16364914579280607,
0.1527678767991844,
0.15970328019708874,
0.15747782851517947,
0.1546769422097933,
0.1569377065521722,
0.1554765278672734,
0.15394834058798243,
0.15758367150781283,
0.15348488267419602,
0.15446293102249137,
0.1476797718663996,
0.15141489293752314,
0.1493564627042096,
0.1517051939074469,
0.14755917786826186,
0.14785145742055836,
0.14515980626797986,
0.14678399383604507,
0.1438836332008838,
0.14259768661059344,
0.1455518479651135,
0.13922422236572785,
0.13730944343188073,
0.14410598841352643,
0.1405077822816634,
0.14224984249809627,
0.14254640666896867,
0.13667202381979077,
0.1365901190048486,
0.14251665205070402,
0.1360865718323748,
0.1290996692132135,
0.13883107120690055,
0.13142337073882102,
0.13594951747222708,
0.13429273376232015,
0.13144069789855423,
0.13148269177878147,
0.12860673738419168,
0.135392048601573,
0.13223101774593124,
0.13134628823538033,
0.1378159605341089,
0.13367286871605075,
0.12860547570667225,
0.12775015207856294,
0.1300224354664546,
0.13145974901706844,
0.12796605893873647,
0.12643439113307112,
0.13201041906408498,
0.12779009853793596,
0.1304032884799826,
0.13348456295442618,
0.12453963049125043,
0.12132675691111519,
0.12286008452064738,
0.12267013423577518,
0.12312488940516728,
0.12127839505918629,
0.12467328584597182,
0.13195709803216454,
0.1202665723651631,
0.12500316552637772,
0.12062063523861206,
0.1181079868686632,
0.12681552271035718,
0.11807400896605273,
0.12014005000495154,
0.11993006486808415,
0.11732517384704762,
0.1136047759790468,
0.11888250117808816,
0.11633511040515612,
0.11825391183559267,
0.11623703569727495,
0.12322514116418926,
0.11326334261276325,
0.11580125295481199,
0.12418527047451162,
0.11289494501453376,
0.11077561346473126,
0.11431217925067252,
0.11427378474471335,
0.11731530918730074,
0.1117398474559641,
0.113824373282757,
0.11354569974260224,
0.11082773737319436,
0.11235645391470028,
0.1169893924945046,
0.11394781289859857,
0.11472703869653125,
0.11195179439782446,
0.1125271442536377,
0.11019011632052719,
0.11147947028025722,
0.11578316334144989,
0.110222826957397,
0.11966849814776921,
0.11160268411696476,
0.10668719692739766,
0.10854954795724314]
# Plot the learning curve. The training loop appends one averaged loss to
# loss_list every 10 iterations, so each x-axis step is 10 iterations. With
# 300 samples and batch_size 30, that happens to be one point per epoch; it is
# not "every 10 epochs" as the previous label claimed.
plt.plot(loss_list)
plt.xlabel('iterations (x10)')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.grid()
plt.show()
Trainer类
如前所述,本书中有很多机会执行神经网络的学习。为此,就需要编写前面那样的学习用的代码。然而,每次都写相同的代码太无聊了,因此我们将进行学习的类作为Trainer类提供出来。Trainer类的内部实现和刚才的源代码几乎相同,只是添加了一些新的功能而已,我们在需要的时候再详细说明其用法。
这个类的初始化程序接收神经网络(模型)和优化器,具体如下所示。
# model = TwoLayerNet(...)
# optimizer = SGD(lr=1.0)
# trainer = Trainer(model,optimizer)
然后,调用fit()方法开始学习。fit()方法的参数如表1-1所示(表格未在此处给出;下面的代码中用到的参数是x、t、max_epoch、batch_size和eval_interval)。
另外,Trainer类有plot()方法,它将fit()方法记录的损失(准确地说,是按照eval_interval评价的平均损失)在图上画出来。使用Trainer类进行学习的代码如下所示。
import sys
# Put the parent directory on the import path so that the sibling
# packages `common` and `dataset` can be imported from here.
sys.path.append('..')

from common.optimizer import SGD
from common.trainer import Trainer
from dataset import spiral

# NOTE: the import below is intentionally left disabled; this cell assumes
# TwoLayerNet has already been defined in an earlier cell of the notebook.
# from two_layer_net import TwoLayerNet

# ---- hyperparameters -------------------------------------------------------
max_epoch = 300      # number of passes over the data set
batch_size = 30      # samples per mini-batch
hidden_size = 10     # width of the hidden layer
learning_rate = 1.0  # SGD step size

# ---- data, model, optimizer ------------------------------------------------
# Spiral data: x are the 2-D input points, t the 3-class targets
# (printed earlier as shapes (300, 2) and (300, 3)).
x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# ---- training --------------------------------------------------------------
# Trainer wraps the mini-batch training loop shown previously.
trainer = Trainer(model, optimizer)
# eval_interval controls how often fit() reports the averaged loss; with
# 300 samples and batch_size 30 there are 10 iterations per epoch, and the
# log produced here shows one line per epoch.
trainer.fit(x, t, max_epoch, batch_size, eval_interval=10)
trainer.plot()
| epoch 1 | iter 1 / 10 | time 0[s] | loss 1.10
| epoch 2 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 3 | iter 1 / 10 | time 0[s] | loss 1.13
| epoch 4 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 5 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 6 | iter 1 / 10 | time 0[s] | loss 1.10
| epoch 7 | iter 1 / 10 | time 0[s] | loss 1.14
| epoch 8 | iter 1 / 10 | time 0[s] | loss 1.16
| epoch 9 | iter 1 / 10 | time 0[s] | loss 1.11
| epoch 10 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 11 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 12 | iter 1 / 10 | time 0[s] | loss 1.12
| epoch 13 | iter 1 / 10 | time 0[s] | loss 1.10
| epoch 14 | iter 1 / 10 | time 0[s] | loss 1.09
| epoch 15 | iter 1 / 10 | time 0[s] | loss 1.08
| epoch 16 | iter 1 / 10 | time 0[s] | loss 1.04
| epoch 17 | iter 1 / 10 | time 0[s] | loss 1.03
| epoch 18 | iter 1 / 10 | time 0[s] | loss 0.94
| epoch 19 | iter 1 / 10 | time 0[s] | loss 0.92
| epoch 20 | iter 1 / 10 | time 0[s] | loss 0.92
| epoch 21 | iter 1 / 10 | time 0[s] | loss 0.87
| epoch 22 | iter 1 / 10 | time 0[s] | loss 0.85
| epoch 23 | iter 1 / 10 | time 0[s] | loss 0.80
| epoch 24 | iter 1 / 10 | time 0[s] | loss 0.79
| epoch 25 | iter 1 / 10 | time 0[s] | loss 0.78
| epoch 26 | iter 1 / 10 | time 0[s] | loss 0.83
| epoch 27 | iter 1 / 10 | time 0[s] | loss 0.77
| epoch 28 | iter 1 / 10 | time 0[s] | loss 0.76
| epoch 29 | iter 1 / 10 | time 0[s] | loss 0.77
| epoch 30 | iter 1 / 10 | time 0[s] | loss 0.76
| epoch 31 | iter 1 / 10 | time 0[s] | loss 0.77
| epoch 32 | iter 1 / 10 | time 0[s] | loss 0.75
| epoch 33 | iter 1 / 10 | time 0[s] | loss 0.78
| epoch 34 | iter 1 / 10 | time 0[s] | loss 0.77
| epoch 35 | iter 1 / 10 | time 0[s] | loss 0.78
| epoch 36 | iter 1 / 10 | time 0[s] | loss 0.74
| epoch 37 | iter 1 / 10 | time 0[s] | loss 0.75
| epoch 38 | iter 1 / 10 | time 0[s] | loss 0.77
| epoch 39 | iter 1 / 10 | time 0[s] | loss 0.75
| epoch 40 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 41 | iter 1 / 10 | time 0[s] | loss 0.75
| epoch 42 | iter 1 / 10 | time 0[s] | loss 0.76
| epoch 43 | iter 1 / 10 | time 0[s] | loss 0.79
| epoch 44 | iter 1 / 10 | time 0[s] | loss 0.74
| epoch 45 | iter 1 / 10 | time 0[s] | loss 0.75
| epoch 46 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 47 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 48 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 49 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 50 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 51 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 52 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 53 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 54 | iter 1 / 10 | time 0[s] | loss 0.74
| epoch 55 | iter 1 / 10 | time 0[s] | loss 0.74
| epoch 56 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 57 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 58 | iter 1 / 10 | time 0[s] | loss 0.69
| epoch 59 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 60 | iter 1 / 10 | time 0[s] | loss 0.70
| epoch 61 | iter 1 / 10 | time 0[s] | loss 0.69
| epoch 62 | iter 1 / 10 | time 0[s] | loss 0.71
| epoch 63 | iter 1 / 10 | time 0[s] | loss 0.70
| epoch 64 | iter 1 / 10 | time 0[s] | loss 0.71
| epoch 65 | iter 1 / 10 | time 0[s] | loss 0.72
| epoch 66 | iter 1 / 10 | time 0[s] | loss 0.71
| epoch 67 | iter 1 / 10 | time 0[s] | loss 0.71
| epoch 68 | iter 1 / 10 | time 0[s] | loss 0.71
| epoch 69 | iter 1 / 10 | time 0[s] | loss 0.70
| epoch 70 | iter 1 / 10 | time 0[s] | loss 0.68
| epoch 71 | iter 1 / 10 | time 0[s] | loss 0.73
| epoch 72 | iter 1 / 10 | time 0[s] | loss 0.66
| epoch 73 | iter 1 / 10 | time 0[s] | loss 0.69
| epoch 74 | iter 1 / 10 | time 0[s] | loss 0.66
| epoch 75 | iter 1 / 10 | time 0[s] | loss 0.70
| epoch 76 | iter 1 / 10 | time 0[s] | loss 0.65
| epoch 77 | iter 1 / 10 | time 0[s] | loss 0.67
| epoch 78 | iter 1 / 10 | time 0[s] | loss 0.70
| epoch 79 | iter 1 / 10 | time 0[s] | loss 0.63
| epoch 80 | iter 1 / 10 | time 0[s] | loss 0.66
| epoch 81 | iter 1 / 10 | time 0[s] | loss 0.65
| epoch 82 | iter 1 / 10 | time 0[s] | loss 0.66
| epoch 83 | iter 1 / 10 | time 0[s] | loss 0.64
| epoch 84 | iter 1 / 10 | time 0[s] | loss 0.62
| epoch 85 | iter 1 / 10 | time 0[s] | loss 0.62
| epoch 86 | iter 1 / 10 | time 0[s] | loss 0.63
| epoch 87 | iter 1 / 10 | time 0[s] | loss 0.59
| epoch 88 | iter 1 / 10 | time 0[s] | loss 0.58
| epoch 89 | iter 1 / 10 | time 0[s] | loss 0.61
| epoch 90 | iter 1 / 10 | time 0[s] | loss 0.59
| epoch 91 | iter 1 / 10 | time 0[s] | loss 0.58
| epoch 92 | iter 1 / 10 | time 0[s] | loss 0.57
| epoch 93 | iter 1 / 10 | time 0[s] | loss 0.55
| epoch 94 | iter 1 / 10 | time 0[s] | loss 0.54
| epoch 95 | iter 1 / 10 | time 0[s] | loss 0.53
| epoch 96 | iter 1 / 10 | time 0[s] | loss 0.54
| epoch 97 | iter 1 / 10 | time 0[s] | loss 0.51
| epoch 98 | iter 1 / 10 | time 0[s] | loss 0.51
| epoch 99 | iter 1 / 10 | time 0[s] | loss 0.50
| epoch 100 | iter 1 / 10 | time 0[s] | loss 0.47
| epoch 101 | iter 1 / 10 | time 0[s] | loss 0.49
| epoch 102 | iter 1 / 10 | time 0[s] | loss 0.46
| epoch 103 | iter 1 / 10 | time 0[s] | loss 0.44
| epoch 104 | iter 1 / 10 | time 0[s] | loss 0.47
| epoch 105 | iter 1 / 10 | time 0[s] | loss 0.44
| epoch 106 | iter 1 / 10 | time 0[s] | loss 0.43
| epoch 107 | iter 1 / 10 | time 0[s] | loss 0.43
| epoch 108 | iter 1 / 10 | time 0[s] | loss 0.39
| epoch 109 | iter 1 / 10 | time 0[s] | loss 0.40
| epoch 110 | iter 1 / 10 | time 0[s] | loss 0.41
| epoch 111 | iter 1 / 10 | time 0[s] | loss 0.38
| epoch 112 | iter 1 / 10 | time 0[s] | loss 0.38
| epoch 113 | iter 1 / 10 | time 0[s] | loss 0.38
| epoch 114 | iter 1 / 10 | time 0[s] | loss 0.37
| epoch 115 | iter 1 / 10 | time 0[s] | loss 0.36
| epoch 116 | iter 1 / 10 | time 0[s] | loss 0.34
| epoch 117 | iter 1 / 10 | time 0[s] | loss 0.35
| epoch 118 | iter 1 / 10 | time 0[s] | loss 0.33
| epoch 119 | iter 1 / 10 | time 0[s] | loss 0.35
| epoch 120 | iter 1 / 10 | time 0[s] | loss 0.33
| epoch 121 | iter 1 / 10 | time 0[s] | loss 0.33
| epoch 122 | iter 1 / 10 | time 0[s] | loss 0.32
| epoch 123 | iter 1 / 10 | time 0[s] | loss 0.31
| epoch 124 | iter 1 / 10 | time 0[s] | loss 0.31
| epoch 125 | iter 1 / 10 | time 0[s] | loss 0.31
| epoch 126 | iter 1 / 10 | time 0[s] | loss 0.30
| epoch 127 | iter 1 / 10 | time 0[s] | loss 0.30
| epoch 128 | iter 1 / 10 | time 0[s] | loss 0.27
| epoch 129 | iter 1 / 10 | time 0[s] | loss 0.30
| epoch 130 | iter 1 / 10 | time 0[s] | loss 0.28
| epoch 131 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 132 | iter 1 / 10 | time 0[s] | loss 0.27
| epoch 133 | iter 1 / 10 | time 0[s] | loss 0.27
| epoch 134 | iter 1 / 10 | time 0[s] | loss 0.28
| epoch 135 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 136 | iter 1 / 10 | time 0[s] | loss 0.28
| epoch 137 | iter 1 / 10 | time 0[s] | loss 0.25
| epoch 138 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 139 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 140 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 141 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 142 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 143 | iter 1 / 10 | time 0[s] | loss 0.26
| epoch 144 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 145 | iter 1 / 10 | time 0[s] | loss 0.24
| epoch 146 | iter 1 / 10 | time 0[s] | loss 0.24
| epoch 147 | iter 1 / 10 | time 0[s] | loss 0.25
| epoch 148 | iter 1 / 10 | time 0[s] | loss 0.21
| epoch 149 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 150 | iter 1 / 10 | time 0[s] | loss 0.22
| epoch 151 | iter 1 / 10 | time 0[s] | loss 0.22
| epoch 152 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 153 | iter 1 / 10 | time 0[s] | loss 0.23
| epoch 154 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 155 | iter 1 / 10 | time 0[s] | loss 0.22
| epoch 156 | iter 1 / 10 | time 0[s] | loss 0.21
| epoch 157 | iter 1 / 10 | time 0[s] | loss 0.21
| epoch 158 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 159 | iter 1 / 10 | time 0[s] | loss 0.21
| epoch 160 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 161 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 162 | iter 1 / 10 | time 0[s] | loss 0.22
| epoch 163 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 164 | iter 1 / 10 | time 0[s] | loss 0.21
| epoch 165 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 166 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 167 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 168 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 169 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 170 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 171 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 172 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 173 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 174 | iter 1 / 10 | time 0[s] | loss 0.20
| epoch 175 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 176 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 177 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 178 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 179 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 180 | iter 1 / 10 | time 0[s] | loss 0.19
| epoch 181 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 182 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 183 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 184 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 185 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 186 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 187 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 188 | iter 1 / 10 | time 0[s] | loss 0.18
| epoch 189 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 190 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 191 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 192 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 193 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 194 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 195 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 196 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 197 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 198 | iter 1 / 10 | time 0[s] | loss 0.17
| epoch 199 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 200 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 201 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 202 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 203 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 204 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 205 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 206 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 207 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 208 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 209 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 210 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 211 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 212 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 213 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 214 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 215 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 216 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 217 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 218 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 219 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 220 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 221 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 222 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 223 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 224 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 225 | iter 1 / 10 | time 0[s] | loss 0.16
| epoch 226 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 227 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 228 | iter 1 / 10 | time 0[s] | loss 0.15
| epoch 229 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 230 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 231 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 232 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 233 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 234 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 235 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 236 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 237 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 238 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 239 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 240 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 241 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 242 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 243 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 244 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 245 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 246 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 247 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 248 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 249 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 250 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 251 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 252 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 253 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 254 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 255 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 256 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 257 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 258 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 259 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 260 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 261 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 262 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 263 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 264 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 265 | iter 1 / 10 | time 0[s] | loss 0.14
| epoch 266 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 267 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 268 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 269 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 270 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 271 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 272 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 273 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 274 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 275 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 276 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 277 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 278 | iter 1 / 10 | time 0[s] | loss 0.13
| epoch 279 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 280 | iter 1 / 10 | time 0[s] | loss 0.10
| epoch 281 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 282 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 283 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 284 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 285 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 286 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 287 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 288 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 289 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 290 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 291 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 292 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 293 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 294 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 295 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 296 | iter 1 / 10 | time 0[s] | loss 0.12
| epoch 297 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 298 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 299 | iter 1 / 10 | time 0[s] | loss 0.11
| epoch 300 | iter 1 / 10 | time 0[s] | loss 0.11
执行这段代码,会进行和之前一样的神经网络的学习。通过将之前展示的学习用的代码交给Trainer类负责,代码变简洁了。本书今后都将使用Trainer类进行学习。
Tip:高速化计算
安装NVIDIA CUDA Toolkit,并使用CuPy调用GPU进行高速计算。CuPy提供了与NumPy基本相同的API,因此只需少量修改即可将基于NumPy的代码放到GPU上运行。
https://baijiahao.baidu.com/s?id=1781877547348944869&wfr=spider&for=pc