
【Learn NLP Together】Chapter 3: Solving Problems with Neural Networks

Solving Problems with Neural Networks

import sys
sys.path.append('..')  # make files in the parent directory importable
from dataset import spiral
import matplotlib.pyplot as plt

x, t = spiral.load_data()
print('x', x.shape)  # (300, 2)
print('t', t.shape)  # (300, 3)

x (300, 2)
t (300, 3)

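In the example above, spiral.py in the dataset directory is imported while working inside the ch01 directory. The code therefore uses sys.path.append('..') to add the parent directory to the import search path.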
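In practice, a relative entry such as '..' is interpreted relative to the current working directory when Python searches for a module, so it only works if the notebook or script is started from inside ch01. A minimal sketch, assuming the working directory is ch01 as above, that appends the absolute parent path instead (the variable name parent_dir is illustrative):

```python
import sys
from pathlib import Path

# Resolve the parent of the current working directory to an absolute path,
# so a later change of working directory does not break module lookup.
parent_dir = str(Path.cwd().parent.resolve())

# Avoid adding a duplicate entry when a notebook cell is re-run.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

print(parent_dir)
```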

print(x)
print(t)
[[-0.00000000e+00  0.00000000e+00]
 [-9.76986432e-04  9.95216044e-03]
 [ 5.12668241e-03  1.93317647e-02]
 [-3.86043324e-04  2.99975161e-02]
 [ 1.42509650e-02  3.73752591e-02]
 [ 9.41914082e-04  4.99911272e-02]
 [ 2.25361319e-02  5.56068589e-02]
 [ 6.52848904e-03  6.96948982e-02]
 [ 2.50649535e-02  7.59720219e-02]
 [ 2.03287580e-02  8.76740646e-02]
 [ 5.98440862e-02  8.01166983e-02]
 [ 6.19050693e-02  9.09272368e-02]
 [ 3.22809763e-02  1.15576549e-01]
 [ 8.28423530e-02  1.00185551e-01]
 [ 1.09856959e-01  8.67839183e-02]
 [ 9.33208222e-02  1.17436043e-01]
 [ 7.82976217e-02  1.39533087e-01]
 [ 1.23994559e-01  1.16298535e-01]
 [ 8.06199110e-02  1.60936105e-01]
 [ 1.39235917e-01  1.29280158e-01]
 [ 1.53599653e-01  1.28090384e-01]
 [ 1.38981638e-01  1.57429680e-01]
 [ 1.89873275e-01  1.11122183e-01]
 [ 1.41160729e-01  1.81586477e-01]
 [ 1.50631465e-01  1.86842613e-01]
 [ 1.71714639e-01  1.81697778e-01]
 [ 2.10050449e-01  1.53227963e-01]
 [ 2.32019716e-01  1.38082770e-01]
 [ 2.31219125e-01  1.57916802e-01]
 [ 2.51798226e-01  1.43866791e-01]
 [ 2.59928236e-01  1.49790894e-01]
 [ 3.03168799e-01  6.47200053e-02]
 [ 2.99776174e-01  1.11956445e-01]
 [ 2.71902926e-01  1.86999462e-01]
 [ 3.39649101e-01  1.54430706e-02]
 [ 3.41955960e-01  7.46064458e-02]
 [ 3.56153446e-01  5.24854558e-02]
 [ 3.69893826e-01 -8.86325970e-03]
 [ 3.78749506e-01  3.08027885e-02]
 [ 3.89323347e-01 -2.29636922e-02]
 [ 3.97756477e-01 -4.23058548e-02]
 [ 3.97790750e-01 -9.93102174e-02]
 [ 4.19082447e-01 -2.77471188e-02]
 [ 4.26755751e-01 -5.27212418e-02]
 [ 4.27342016e-01 -1.04779774e-01]
 [ 4.37333586e-01 -1.06015726e-01]
 [ 4.54778403e-01 -6.91129800e-02]
 [ 4.33582415e-01 -1.81400909e-01]
 [ 4.75290298e-01 -6.70755731e-02]
 [ 4.36148818e-01 -2.23325341e-01]
 [ 2.88929307e-01 -4.08068444e-01]
 [ 4.69656878e-01 -1.98802456e-01]
 [ 3.86773022e-01 -3.47572480e-01]
 [ 5.00808621e-01 -1.73466783e-01]
 [ 4.24431168e-01 -3.33853536e-01]
 [ 4.24040623e-01 -3.50270681e-01]
 [ 2.80512809e-01 -4.84677794e-01]
 [ 5.30623098e-01 -2.08180518e-01]
 [ 4.11038907e-01 -4.09202904e-01]
 [ 4.62137267e-01 -3.66782151e-01]
 [ 3.95281189e-01 -4.51389834e-01]
 [ 3.67613445e-01 -4.86785738e-01]
 [ 4.94556678e-01 -3.73916691e-01]
 [ 4.35611632e-01 -4.55129109e-01]
 [ 3.28234290e-01 -5.49419922e-01]
 [ 2.78014129e-01 -5.87544164e-01]
 [ 3.04441838e-01 -5.85589590e-01]
 [-3.84353415e-02 -6.68896647e-01]
 [ 4.32663510e-01 -5.24597262e-01]
 [ 3.67995385e-01 -5.83677477e-01]
 [ 1.62513204e-01 -6.80874040e-01]
 [-7.01870180e-02 -7.06522316e-01]
 [ 1.20922889e-01 -7.09772960e-01]
 [ 1.81179805e-01 -7.07159019e-01]
 [ 4.37091210e-02 -7.38708002e-01]
 [ 1.03491018e-01 -7.42825423e-01]
 [-6.40535178e-03 -7.59973007e-01]
 [-1.70362786e-01 -7.50917120e-01]
 [ 1.37916439e-01 -7.67710268e-01]
 [-1.54966573e-02 -7.89847994e-01]
 [-4.01784946e-02 -7.98990418e-01]
 [-2.98916375e-01 -7.52827338e-01]
 [-1.88135777e-01 -7.98125886e-01]
 [-1.16656887e-03 -8.29999180e-01]
 [-2.86409662e-01 -7.89664173e-01]
 [-1.45414274e-01 -8.37469217e-01]
 [-3.08244630e-01 -8.02860666e-01]
 [-1.13882388e-01 -8.62514233e-01]
 [-1.97807480e-01 -8.57480146e-01]
 [-7.28153677e-01 -5.11754065e-01]
 [-5.28735647e-01 -7.28312169e-01]
 [-1.95859389e-01 -8.88672662e-01]
 [-6.55599644e-01 -6.45437144e-01]
 [-4.36321749e-01 -8.21293694e-01]
 [-3.97220716e-01 -8.51948181e-01]
 [-7.40823025e-01 -5.94711060e-01]
 [-7.83135284e-01 -5.55246906e-01]
 [-6.83380750e-01 -6.88397233e-01]
 [-6.79385602e-01 -7.06282665e-01]
 [-5.87009159e-01 -7.97195238e-01]
 [-0.00000000e+00 -0.00000000e+00]
 [-8.37020532e-03 -5.47171481e-03]
 [-1.83996723e-02 -7.83913638e-03]
 [-2.38615572e-02 -1.81831264e-02]
 [-2.71639022e-02 -2.93619212e-02]
 [-4.67712876e-02 -1.76761608e-02]
 [-5.13932286e-02 -3.09634633e-02]
 [-6.72785631e-02 -1.93286044e-02]
 [-7.58815643e-02 -2.53374861e-02]
 [-8.87183570e-02 -1.51345014e-02]
 [-8.95920555e-02 -4.44214317e-02]
 [-1.05681839e-01 -3.05180093e-02]
 [-1.12859458e-01 -4.07767431e-02]
 [-1.29959580e-01 -3.24155301e-03]
 [-1.39813880e-01 -7.21657356e-03]
 [-1.45884653e-01 -3.48951009e-02]
 [-1.58860475e-01 -1.90617278e-02]
 [-1.69789215e-01  8.46300817e-03]
 [-1.79892182e-01  6.22920756e-03]
 [-1.82531530e-01  5.27469495e-02]
 [-1.96787914e-01  3.57003745e-02]
 [-1.78610713e-01  1.10445521e-01]
 [-2.14327911e-01  4.96341281e-02]
 [-2.23948847e-01  5.24110098e-02]
 [-2.39863331e-01  8.09829103e-03]
 [-2.45011435e-01  4.96930264e-02]
 [-2.33354905e-01  1.14653775e-01]
 [-2.50877484e-01  9.98022449e-02]
 [-2.77281919e-01  3.89196302e-02]
 [-2.70172382e-01  1.05389203e-01]
 [-2.92387849e-01  6.71516607e-02]
 [-2.73972209e-01  1.45049056e-01]
 [-2.85671818e-01  1.44192969e-01]
 [-3.21248452e-01  7.54945832e-02]
 [-3.07272511e-01  1.45545883e-01]
 [-2.39745837e-01  2.54993988e-01]
 [-2.48112212e-01  2.60845415e-01]
 [-1.51453586e-01  3.37582303e-01]
 [-2.67068820e-01  2.70322484e-01]
 [-2.89146403e-01  2.61714267e-01]
 [-1.38404573e-01  3.75292118e-01]
 [-2.47585805e-01  3.26804635e-01]
 [-2.18612860e-01  3.58620158e-01]
 [-1.25241675e-01  4.11356929e-01]
 [-1.95264902e-01  3.94298895e-01]
 [-1.46778711e-01  4.25389245e-01]
 [-1.04988763e-01  4.47858638e-01]
 [-2.68359413e-01  3.85853892e-01]
 [-1.93548031e-01  4.39248403e-01]
 [-8.77361039e-02  4.82081296e-01]
 [-1.39120457e-01  4.80255659e-01]
 [-8.31852186e-02  5.03170169e-01]
 [-7.48872073e-02  5.14579349e-01]
 [-5.70803544e-03  5.29969262e-01]
 [-1.58687180e-02  5.39766786e-01]
 [-1.49002564e-01  5.29431994e-01]
 [-1.98435901e-03  5.59996484e-01]
 [-2.08488575e-01  5.30502134e-01]
 [ 5.65227669e-02  5.77239272e-01]
 [-2.89358604e-02  5.89290010e-01]
 [ 1.25355588e-01  5.86758874e-01]
 [ 6.26495621e-02  6.06774285e-01]
 [ 3.01335207e-01  5.41846005e-01]
 [ 2.75775684e-01  5.66434261e-01]
 [ 1.32352421e-01  6.26165183e-01]
 [ 3.47342639e-01  5.49411586e-01]
 [ 3.13540745e-01  5.80768630e-01]
 [-1.44067966e-01  6.54327457e-01]
 [ 5.02643212e-01  4.57984499e-01]
 [ 3.70493915e-01  5.82094716e-01]
 [ 3.49928740e-01  6.06258919e-01]
 [ 4.93501165e-01  5.10447451e-01]
 [ 3.63212538e-01  6.21672464e-01]
 [ 6.62295009e-01  3.07026581e-01]
 [ 3.18207549e-01  6.68089781e-01]
 [ 4.46081933e-01  6.02918659e-01]
 [ 6.18103628e-01  4.42207988e-01]
 [ 6.07553924e-01  4.73052037e-01]
 [ 5.05555512e-01  5.93981165e-01]
 [ 2.72157190e-01  7.41640387e-01]
 [ 5.70347440e-01  5.60984668e-01]
 [ 6.05299121e-01  5.38249918e-01]
 [ 6.46809095e-01  5.04021819e-01]
 [ 7.10706203e-01  4.28715165e-01]
 [ 7.35105615e-01  4.06472305e-01]
 [ 6.74187768e-01  5.17659012e-01]
 [ 8.14670518e-01  2.75521227e-01]
 [ 8.45117378e-01  2.06583196e-01]
 [ 6.40506934e-01  6.03449142e-01]
 [ 8.66120175e-01  2.04782425e-01]
 [ 8.96181358e-01  8.28189190e-02]
 [ 8.33525542e-01  3.65150888e-01]
 [ 9.18575330e-01  5.11797183e-02]
 [ 9.22965423e-01  1.14170172e-01]
 [ 9.36555210e-01  8.04011093e-02]
 [ 9.08727618e-01  2.76973131e-01]
 [ 9.50896208e-01 -1.31895416e-01]
 [ 9.17784255e-01  3.13961879e-01]
 [ 9.33629259e-01 -2.97886567e-01]
 [ 9.88995443e-01 -4.45871499e-02]
 [ 0.00000000e+00 -0.00000000e+00]
 [ 9.79880762e-03 -1.99583797e-03]
 [ 1.99940729e-02  4.86876255e-04]
 [ 2.61662525e-02 -1.46740326e-02]
 [ 3.84511087e-02 -1.10232589e-02]
 [ 4.95890467e-02 -6.39737823e-03]
 [ 5.99771637e-02 -1.65524386e-03]
 [ 6.21697472e-02 -3.21702119e-02]
 [ 7.43436411e-02 -2.95469632e-02]
 [ 8.06974118e-02 -3.98488109e-02]
 [ 8.14343751e-02 -5.80382853e-02]
 [ 9.22001492e-02 -5.99927703e-02]
 [ 9.37395802e-02 -7.49192306e-02]
 [ 9.60790524e-02 -8.75717745e-02]
 [ 1.03261762e-01 -9.45357529e-02]
 [ 1.06370805e-01 -1.05760351e-01]
 [ 1.37466576e-01 -8.18714884e-02]
 [ 1.01082273e-01 -1.36683481e-01]
 [ 1.10151505e-01 -1.42360971e-01]
 [ 9.16540054e-02 -1.66431798e-01]
 [ 1.29709392e-01 -1.52234929e-01]
 [ 3.26863586e-02 -2.07440599e-01]
 [ 1.37426305e-01 -1.71796422e-01]
 [ 1.50106931e-01 -1.74263907e-01]
 [ 1.32737490e-01 -1.99951891e-01]
 [ 1.34246622e-01 -2.10897711e-01]
 [ 9.69068794e-02 -2.41265532e-01]
 [ 2.21194549e-02 -2.69092419e-01]
 [ 1.03753054e-01 -2.60067883e-01]
 [ 3.87351563e-02 -2.87401440e-01]
 [ 1.55776942e-01 -2.56385539e-01]
 [ 5.05915167e-02 -3.05843912e-01]
 [ 4.49940024e-02 -3.16820990e-01]
 [-1.74914339e-02 -3.29536113e-01]
 [ 4.91708033e-03 -3.39964443e-01]
 [-4.84363632e-02 -3.46632253e-01]
 [ 1.94660931e-02 -3.59473325e-01]
 [-1.00885395e-01 -3.55980529e-01]
 [ 2.46037929e-03 -3.79992035e-01]
 [-4.68117997e-02 -3.87180391e-01]
 [-9.43358132e-02 -3.88716805e-01]
 [-3.36669158e-02 -4.08615392e-01]
 [-1.88529982e-01 -3.75308468e-01]
 [-5.06039869e-02 -4.27011986e-01]
 [-1.62425981e-01 -4.08922732e-01]
 [-2.12648211e-01 -3.96586357e-01]
 [-8.03418780e-02 -4.52929556e-01]
 [-1.52934025e-01 -4.44422304e-01]
 [-2.65614079e-01 -3.99811407e-01]
 [-2.74970348e-01 -4.05575280e-01]
 [-2.34573442e-01 -4.41560076e-01]
 [-3.78612687e-01 -3.41690551e-01]
 [-3.34460262e-01 -3.98166213e-01]
 [-2.65222370e-01 -4.58865007e-01]
 [-3.32119042e-01 -4.25789786e-01]
 [-4.12557377e-01 -3.63725735e-01]
 [-4.02002778e-01 -3.89863779e-01]
 [-3.04770028e-01 -4.81679593e-01]
 [-4.50761110e-01 -3.64985508e-01]
 [-5.66743279e-01 -1.64018461e-01]
 [-4.15392351e-01 -4.32954033e-01]
 [-4.91811803e-01 -3.60861678e-01]
 [-5.48909912e-01 -2.88267079e-01]
 [-5.87812164e-01 -2.26664641e-01]
 [-5.79543657e-01 -2.71531122e-01]
 [-5.86563807e-01 -2.80076598e-01]
 [-6.00326136e-01 -2.74241738e-01]
 [-5.67589542e-01 -3.56008584e-01]
 [-6.11252234e-01 -2.97944133e-01]
 [-6.62681829e-01 -1.92231095e-01]
 [-6.99434809e-01  2.81237977e-02]
 [-6.63528162e-01 -2.52646745e-01]
 [-7.07437708e-01 -1.33910002e-01]
 [-6.95635703e-01 -2.21339037e-01]
 [-7.39522482e-01 -2.65800316e-02]
 [-7.41761456e-01  1.10860010e-01]
 [-7.31852772e-01  2.04918327e-01]
 [-7.62821244e-01  1.04898761e-01]
 [-7.78110642e-01  5.42570570e-02]
 [-7.10909578e-01  3.44539651e-01]
 [-7.84879826e-01  1.54801999e-01]
 [-7.75097682e-01  2.35209657e-01]
 [-8.09110142e-01  1.33194513e-01]
 [-7.12117393e-01  4.26367000e-01]
 [-7.84729601e-01  2.99665570e-01]
 [-8.49662414e-01 -2.39537407e-02]
 [-7.11847617e-01  4.82569136e-01]
 [-7.47488621e-01  4.45152515e-01]
 [-7.42609239e-01  4.72156244e-01]
 [-8.43868125e-01  2.82819001e-01]
 [-6.50907748e-01  6.21545737e-01]
 [-8.11713922e-01  4.11364205e-01]
 [-8.46795454e-01  3.59635174e-01]
 [-8.35929064e-01  4.07581402e-01]
 [-9.22350700e-01  1.81298612e-01]
 [-4.56405752e-01  8.33182927e-01]
 [-6.30262690e-01  7.24133235e-01]
 [-5.75705887e-01  7.80680941e-01]
 [-2.06074574e-01  9.58088341e-01]
 [-5.97431027e-01  7.89415080e-01]]
[[1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [1 0 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 1 0]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]
 [0 0 1]]

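Here, x is the input data and t holds the supervised (teacher) labels. From the shapes of x and t, each contains 300 samples: x is two-dimensional and t is three-dimensional. t is also one-hot encoded: the element for the correct class is 1 and all others are 0. Next, let's plot this data; the result is shown in Figure 1-31.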
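As an aside on the one-hot format described above: one-hot rows can be turned into class indices with argmax, and class indices can be turned back into one-hot rows by indexing an identity matrix. A small sketch with made-up labels (the array labels is hypothetical, not taken from the spiral dataset):

```python
import numpy as np

# Hypothetical class indices for 6 samples across 3 classes
labels = np.array([0, 0, 1, 1, 2, 2])

# Index -> one-hot: select rows of the 3x3 identity matrix
t_onehot = np.eye(3, dtype=int)[labels]
print(t_onehot)

# One-hot -> index: the position of the 1 in each row
recovered = t_onehot.argmax(axis=1)
print(recovered)             # [0 0 1 1 2 2]

# Every one-hot row sums to exactly 1
print(t_onehot.sum(axis=1))  # [1 1 1 1 1 1]
```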

# Plot the data points, one marker per class
N = 100       # samples per class
CLS_NUM = 3   # number of classes
markers = ['o', 'x', '^']
for i in range(CLS_NUM):
    plt.scatter(x[i*N:(i+1)*N, 0], x[i*N:(i+1)*N, 1], s=40, marker=markers[i])
plt.show()

[Figure 1-31: scatter plot of the spiral dataset, one marker per class]
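The shape in the figure follows from how the points are laid out. The sketch below shows one way to generate a comparable three-armed spiral: the distance from the origin grows linearly along each arm, and each class starts at a different angle and winds as it moves outward, with Gaussian noise added to the angle. The function make_spiral and its parameters are hypothetical; the sketch is modeled on the printed data (e.g. the k-th point of each class lies at radius k/100), not copied from dataset/spiral.py.

```python
import numpy as np

def make_spiral(n_per_class=100, n_classes=3, noise=0.2, seed=1984):
    """Hypothetical generator for an interleaved spiral dataset.

    Returns x with shape (n_per_class * n_classes, 2) and one-hot labels t
    with shape (n_per_class * n_classes, n_classes).
    """
    rng = np.random.default_rng(seed)
    x = np.zeros((n_per_class * n_classes, 2))
    t = np.zeros((n_per_class * n_classes, n_classes), dtype=int)
    for j in range(n_classes):
        for i in range(n_per_class):
            rate = i / n_per_class
            radius = rate  # distance from the origin grows linearly
            # Each class is offset by 4 radians and winds as the radius grows
            theta = j * 4.0 + 4.0 * rate + rng.standard_normal() * noise
            idx = n_per_class * j + i
            x[idx] = [radius * np.sin(theta), radius * np.cos(theta)]
            t[idx, j] = 1
    return x, t

x_sp, t_sp = make_spiral()
print(x_sp.shape, t_sp.shape)  # (300, 2) (300, 3)
```

Because the arms of different classes wrap around each other, any straight line through the plane cuts across several arms, which is why a nonlinear boundary is needed.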

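As Figure 1-31 shows, the input is two-dimensional and there are 3 classes. This dataset cannot be separated by straight lines, so we need to learn nonlinear decision boundaries. Can our neural network (one whose hidden layer uses the nonlinear sigmoid activation function) learn this nonlinear pattern correctly? Let's run the experiment.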

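Tip: Splitting the dataset

Because this experiment is relatively simple, we do not split the dataset into training, validation, and test data. In real tasks, however, the dataset is split into training and test data (and usually validation data as well), which are used for learning and evaluation respectively.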
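For reference, a shuffled train/test split can be sketched with NumPy alone. The helper below is hypothetical (scikit-learn has a function with the same name but a different signature); the key point is that one shared permutation is applied to both x and t so that samples and labels stay paired. The stand-in arrays x_demo and t_demo only mimic the spiral shapes.

```python
import numpy as np

def train_test_split(x, t, test_ratio=0.2, seed=0):
    """Shuffle samples and split them into training and test sets (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))  # one shared permutation keeps x and t aligned
    n_test = int(len(x) * test_ratio)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return x[train_idx], t[train_idx], x[test_idx], t[test_idx]

# Stand-in data with the same shapes as the spiral dataset
x_demo = np.random.randn(300, 2)
t_demo = np.eye(3, dtype=int)[np.repeat([0, 1, 2], 100)]

x_train, t_train, x_test, t_test = train_test_split(x_demo, t_demo)
print(x_train.shape, x_test.shape)  # (240, 2) (60, 2)
```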

import sys
sys.path.append('..')
import numpy as np
from common.layers import Affine, Sigmoid, SoftmaxWithLoss

class TwoLayerNet:
    def __init__(self, input_size, hidden_size, output_size):
        I, H, O = input_size, hidden_size, output_size
        # Initialize weights and biases
        W1 = 0.01 * np.random.randn(I, H)  # (N, I) @ (I, H) -> (N, H)
        b1 = np.zeros(H)
        W2 = 0.01 * np.random.randn(H, O)  # (N, H) @ (H, O) -> (N, O)
        b2 = np.zeros(O)
        # Create the layers
        self.layers = [
            Affine(W1, b1),
            Sigmoid(),
            Affine(W2, b2)
        ]
        # Loss function
        self.loss_layer = SoftmaxWithLoss()
        # Collect all weights and gradients into lists
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

    def predict(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def forward(self, x, t):
        score = self.predict(x)
        loss = self.loss_layer.forward(score, t)
        return loss

    def backward(self, dout=1):
        dout = self.loss_layer.backward(dout)
        for layer in reversed(self.layers):  # backpropagate in reverse layer order
            dout = layer.backward(dout)
        return dout

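The initializer takes 3 arguments: input_size is the number of neurons in the input layer, hidden_size the number in the hidden layer, and output_size the number in the output layer. Internally, the biases are initialized with zero vectors (np.zeros()) and the weights with small random values (0.01 * np.random.randn()). Starting from small random weights makes learning proceed more easily. Next, the required layers are created and stored in the list instance variable layers. Finally, all parameters and gradients used by the model are gathered into self.params and self.grads.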
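TwoLayerNet depends on Affine, Sigmoid, and SoftmaxWithLoss from common.layers, whose source is not shown here. The sketch below is an assumption about what compatible layers could look like: each exposes params and grads lists plus forward and backward methods, which is the only interface TwoLayerNet uses. It is not presented as the contents of common.layers.

```python
import numpy as np

class Sigmoid:
    """Elementwise sigmoid; has no parameters."""
    def __init__(self):
        self.params, self.grads = [], []
        self.out = None

    def forward(self, x):
        self.out = 1.0 / (1.0 + np.exp(-x))
        return self.out

    def backward(self, dout):
        # d/dx sigmoid(x) = y * (1 - y), where y is the forward output
        return dout * self.out * (1.0 - self.out)


class Affine:
    """Fully connected layer: y = x @ W + b."""
    def __init__(self, W, b):
        self.params = [W, b]
        self.grads = [np.zeros_like(W), np.zeros_like(b)]
        self.x = None

    def forward(self, x):
        W, b = self.params
        self.x = x
        return x @ W + b

    def backward(self, dout):
        W, b = self.params
        dx = dout @ W.T
        # Write into the existing arrays ([...]) so that references
        # collected in model.grads keep pointing at the current gradients.
        self.grads[0][...] = self.x.T @ dout
        self.grads[1][...] = dout.sum(axis=0)
        return dx


class SoftmaxWithLoss:
    """Softmax followed by mean cross-entropy; t may be one-hot or class indices."""
    def __init__(self):
        self.params, self.grads = [], []
        self.y = None
        self.t = None

    def forward(self, x, t):
        x = x - x.max(axis=1, keepdims=True)  # subtract row max for numerical stability
        exp_x = np.exp(x)
        self.y = exp_x / exp_x.sum(axis=1, keepdims=True)
        # Convert one-hot labels to class indices
        self.t = t.argmax(axis=1) if t.ndim == 2 else t
        batch = x.shape[0]
        return -np.mean(np.log(self.y[np.arange(batch), self.t] + 1e-7))

    def backward(self, dout=1):
        batch = self.y.shape[0]
        dx = self.y.copy()
        dx[np.arange(batch), self.t] -= 1  # gradient of softmax + cross-entropy: y - t
        return dx * dout / batch
```

The in-place gradient writes are the design choice that matters here: TwoLayerNet copies references to each layer's grads arrays once in its initializer, so replacing those arrays with new objects in backward would leave model.grads stale.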

Training code

# Training code
import sys
sys.path.append('..')
import numpy as np
from common.optimizer import SGD
from dataset import spiral
import matplotlib.pyplot as plt
# TwoLayerNet is defined in the previous cell; in a standalone script, import it instead:
# from two_layer_net import TwoLayerNet

# Hyperparameters
max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

# Load the data and create the model and optimizer
x, t = spiral.load_data()
model = TwoLayerNet(input_size=2, hidden_size=hidden_size, output_size=3)
optimizer = SGD(lr=learning_rate)

# Variables used during training
data_size = len(x)
max_iters = data_size // batch_size  # number of mini-batches per epoch
total_loss = 0
loss_count = 0
loss_list = []

for epoch in range(max_epoch):
    # Shuffle the data at the start of each epoch
    idx = np.random.permutation(data_size)
    x = x[idx]  # data samples
    t = t[idx]  # data labels

    for iters in range(max_iters):
        # Slicing rows of 2-D arrays always yields 2-D mini-batches
        batch_x = x[iters * batch_size:(iters + 1) * batch_size]
        batch_t = t[iters * batch_size:(iters + 1) * batch_size]

        # Compute the loss, backpropagate, and update the parameters
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)
        total_loss += loss
        loss_count += 1

        # Periodically report training progress (every 10 iterations)
        if (iters + 1) % 10 == 0:
            avg_loss = total_loss / loss_count
            print('|epoch %d | iter %d / %d | loss %.2f' % (epoch + 1, iters + 1, max_iters, avg_loss))
            loss_list.append(avg_loss)
            total_loss, loss_count = 0, 0
|epoch 1 | iter 10 / 10 | loss 1.13
|epoch 2 | iter 10 / 10 | loss 1.13
|epoch 3 | iter 10 / 10 | loss 1.12
|epoch 4 | iter 10 / 10 | loss 1.12
|epoch 5 | iter 10 / 10 | loss 1.11
|epoch 6 | iter 10 / 10 | loss 1.14
|epoch 7 | iter 10 / 10 | loss 1.16
|epoch 8 | iter 10 / 10 | loss 1.11
|epoch 9 | iter 10 / 10 | loss 1.12
|epoch 10 | iter 10 / 10 | loss 1.13
|epoch 11 | iter 10 / 10 | loss 1.12
|epoch 12 | iter 10 / 10 | loss 1.11
|epoch 13 | iter 10 / 10 | loss 1.09
|epoch 14 | iter 10 / 10 | loss 1.08
|epoch 15 | iter 10 / 10 | loss 1.04
|epoch 16 | iter 10 / 10 | loss 1.03
|epoch 17 | iter 10 / 10 | loss 0.96
|epoch 18 | iter 10 / 10 | loss 0.92
|epoch 19 | iter 10 / 10 | loss 0.92
|epoch 20 | iter 10 / 10 | loss 0.87
|epoch 21 | iter 10 / 10 | loss 0.85
|epoch 22 | iter 10 / 10 | loss 0.82
|epoch 23 | iter 10 / 10 | loss 0.79
|epoch 24 | iter 10 / 10 | loss 0.78
|epoch 25 | iter 10 / 10 | loss 0.82
|epoch 26 | iter 10 / 10 | loss 0.78
|epoch 27 | iter 10 / 10 | loss 0.76
|epoch 28 | iter 10 / 10 | loss 0.76
|epoch 29 | iter 10 / 10 | loss 0.78
|epoch 30 | iter 10 / 10 | loss 0.75
|epoch 31 | iter 10 / 10 | loss 0.78
|epoch 32 | iter 10 / 10 | loss 0.77
|epoch 33 | iter 10 / 10 | loss 0.77
|epoch 34 | iter 10 / 10 | loss 0.78
|epoch 35 | iter 10 / 10 | loss 0.75
|epoch 36 | iter 10 / 10 | loss 0.74
|epoch 37 | iter 10 / 10 | loss 0.76
|epoch 38 | iter 10 / 10 | loss 0.76
|epoch 39 | iter 10 / 10 | loss 0.73
|epoch 40 | iter 10 / 10 | loss 0.75
|epoch 41 | iter 10 / 10 | loss 0.76
|epoch 42 | iter 10 / 10 | loss 0.76
|epoch 43 | iter 10 / 10 | loss 0.76
|epoch 44 | iter 10 / 10 | loss 0.74
|epoch 45 | iter 10 / 10 | loss 0.75
|epoch 46 | iter 10 / 10 | loss 0.73
|epoch 47 | iter 10 / 10 | loss 0.72
|epoch 48 | iter 10 / 10 | loss 0.73
|epoch 49 | iter 10 / 10 | loss 0.72
|epoch 50 | iter 10 / 10 | loss 0.72
|epoch 51 | iter 10 / 10 | loss 0.72
|epoch 52 | iter 10 / 10 | loss 0.72
|epoch 53 | iter 10 / 10 | loss 0.74
|epoch 54 | iter 10 / 10 | loss 0.74
|epoch 55 | iter 10 / 10 | loss 0.72
|epoch 56 | iter 10 / 10 | loss 0.72
|epoch 57 | iter 10 / 10 | loss 0.71
|epoch 58 | iter 10 / 10 | loss 0.70
|epoch 59 | iter 10 / 10 | loss 0.72
|epoch 60 | iter 10 / 10 | loss 0.70
|epoch 61 | iter 10 / 10 | loss 0.71
|epoch 62 | iter 10 / 10 | loss 0.72
|epoch 63 | iter 10 / 10 | loss 0.70
|epoch 64 | iter 10 / 10 | loss 0.71
|epoch 65 | iter 10 / 10 | loss 0.73
|epoch 66 | iter 10 / 10 | loss 0.70
|epoch 67 | iter 10 / 10 | loss 0.71
|epoch 68 | iter 10 / 10 | loss 0.69
|epoch 69 | iter 10 / 10 | loss 0.70
|epoch 70 | iter 10 / 10 | loss 0.71
|epoch 71 | iter 10 / 10 | loss 0.68
|epoch 72 | iter 10 / 10 | loss 0.69
|epoch 73 | iter 10 / 10 | loss 0.67
|epoch 74 | iter 10 / 10 | loss 0.68
|epoch 75 | iter 10 / 10 | loss 0.67
|epoch 76 | iter 10 / 10 | loss 0.66
|epoch 77 | iter 10 / 10 | loss 0.69
|epoch 78 | iter 10 / 10 | loss 0.64
|epoch 79 | iter 10 / 10 | loss 0.68
|epoch 80 | iter 10 / 10 | loss 0.64
|epoch 81 | iter 10 / 10 | loss 0.64
|epoch 82 | iter 10 / 10 | loss 0.66
|epoch 83 | iter 10 / 10 | loss 0.62
|epoch 84 | iter 10 / 10 | loss 0.62
|epoch 85 | iter 10 / 10 | loss 0.61
|epoch 86 | iter 10 / 10 | loss 0.60
|epoch 87 | iter 10 / 10 | loss 0.60
|epoch 88 | iter 10 / 10 | loss 0.61
|epoch 89 | iter 10 / 10 | loss 0.59
|epoch 90 | iter 10 / 10 | loss 0.58
|epoch 91 | iter 10 / 10 | loss 0.56
|epoch 92 | iter 10 / 10 | loss 0.56
|epoch 93 | iter 10 / 10 | loss 0.54
|epoch 94 | iter 10 / 10 | loss 0.53
|epoch 95 | iter 10 / 10 | loss 0.53
|epoch 96 | iter 10 / 10 | loss 0.52
|epoch 97 | iter 10 / 10 | loss 0.51
|epoch 98 | iter 10 / 10 | loss 0.50
|epoch 99 | iter 10 / 10 | loss 0.48
|epoch 100 | iter 10 / 10 | loss 0.48
|epoch 101 | iter 10 / 10 | loss 0.46
|epoch 102 | iter 10 / 10 | loss 0.45
|epoch 103 | iter 10 / 10 | loss 0.45
|epoch 104 | iter 10 / 10 | loss 0.44
|epoch 105 | iter 10 / 10 | loss 0.44
|epoch 106 | iter 10 / 10 | loss 0.41
|epoch 107 | iter 10 / 10 | loss 0.40
|epoch 108 | iter 10 / 10 | loss 0.41
|epoch 109 | iter 10 / 10 | loss 0.40
|epoch 110 | iter 10 / 10 | loss 0.40
|epoch 111 | iter 10 / 10 | loss 0.38
|epoch 112 | iter 10 / 10 | loss 0.38
|epoch 113 | iter 10 / 10 | loss 0.36
|epoch 114 | iter 10 / 10 | loss 0.37
|epoch 115 | iter 10 / 10 | loss 0.35
|epoch 116 | iter 10 / 10 | loss 0.34
|epoch 117 | iter 10 / 10 | loss 0.34
|epoch 118 | iter 10 / 10 | loss 0.34
|epoch 119 | iter 10 / 10 | loss 0.33
|epoch 120 | iter 10 / 10 | loss 0.34
|epoch 121 | iter 10 / 10 | loss 0.32
|epoch 122 | iter 10 / 10 | loss 0.32
|epoch 123 | iter 10 / 10 | loss 0.31
|epoch 124 | iter 10 / 10 | loss 0.31
|epoch 125 | iter 10 / 10 | loss 0.30
|epoch 126 | iter 10 / 10 | loss 0.30
|epoch 127 | iter 10 / 10 | loss 0.28
|epoch 128 | iter 10 / 10 | loss 0.28
|epoch 129 | iter 10 / 10 | loss 0.28
|epoch 130 | iter 10 / 10 | loss 0.28
|epoch 131 | iter 10 / 10 | loss 0.27
|epoch 132 | iter 10 / 10 | loss 0.27
|epoch 133 | iter 10 / 10 | loss 0.27
|epoch 134 | iter 10 / 10 | loss 0.27
|epoch 135 | iter 10 / 10 | loss 0.27
|epoch 136 | iter 10 / 10 | loss 0.26
|epoch 137 | iter 10 / 10 | loss 0.26
|epoch 138 | iter 10 / 10 | loss 0.26
|epoch 139 | iter 10 / 10 | loss 0.25
|epoch 140 | iter 10 / 10 | loss 0.24
|epoch 141 | iter 10 / 10 | loss 0.24
|epoch 142 | iter 10 / 10 | loss 0.25
|epoch 143 | iter 10 / 10 | loss 0.24
|epoch 144 | iter 10 / 10 | loss 0.24
|epoch 145 | iter 10 / 10 | loss 0.23
|epoch 146 | iter 10 / 10 | loss 0.24
|epoch 147 | iter 10 / 10 | loss 0.23
|epoch 148 | iter 10 / 10 | loss 0.23
|epoch 149 | iter 10 / 10 | loss 0.22
|epoch 150 | iter 10 / 10 | loss 0.22
|epoch 151 | iter 10 / 10 | loss 0.22
|epoch 152 | iter 10 / 10 | loss 0.22
|epoch 153 | iter 10 / 10 | loss 0.22
|epoch 154 | iter 10 / 10 | loss 0.22
|epoch 155 | iter 10 / 10 | loss 0.22
|epoch 156 | iter 10 / 10 | loss 0.21
|epoch 157 | iter 10 / 10 | loss 0.21
|epoch 158 | iter 10 / 10 | loss 0.20
|epoch 159 | iter 10 / 10 | loss 0.21
|epoch 160 | iter 10 / 10 | loss 0.20
|epoch 161 | iter 10 / 10 | loss 0.20
|epoch 162 | iter 10 / 10 | loss 0.20
|epoch 163 | iter 10 / 10 | loss 0.21
|epoch 164 | iter 10 / 10 | loss 0.20
|epoch 165 | iter 10 / 10 | loss 0.20
|epoch 166 | iter 10 / 10 | loss 0.19
|epoch 167 | iter 10 / 10 | loss 0.19
|epoch 168 | iter 10 / 10 | loss 0.19
|epoch 169 | iter 10 / 10 | loss 0.19
|epoch 170 | iter 10 / 10 | loss 0.19
|epoch 171 | iter 10 / 10 | loss 0.19
|epoch 172 | iter 10 / 10 | loss 0.18
|epoch 173 | iter 10 / 10 | loss 0.18
|epoch 174 | iter 10 / 10 | loss 0.18
|epoch 175 | iter 10 / 10 | loss 0.18
|epoch 176 | iter 10 / 10 | loss 0.18
|epoch 177 | iter 10 / 10 | loss 0.18
|epoch 178 | iter 10 / 10 | loss 0.18
|epoch 179 | iter 10 / 10 | loss 0.17
|epoch 180 | iter 10 / 10 | loss 0.17
|epoch 181 | iter 10 / 10 | loss 0.18
|epoch 182 | iter 10 / 10 | loss 0.17
|epoch 183 | iter 10 / 10 | loss 0.18
|epoch 184 | iter 10 / 10 | loss 0.17
|epoch 185 | iter 10 / 10 | loss 0.17
|epoch 186 | iter 10 / 10 | loss 0.18
|epoch 187 | iter 10 / 10 | loss 0.17
|epoch 188 | iter 10 / 10 | loss 0.17
|epoch 189 | iter 10 / 10 | loss 0.17
|epoch 190 | iter 10 / 10 | loss 0.17
|epoch 191 | iter 10 / 10 | loss 0.16
|epoch 192 | iter 10 / 10 | loss 0.17
|epoch 193 | iter 10 / 10 | loss 0.16
|epoch 194 | iter 10 / 10 | loss 0.16
|epoch 195 | iter 10 / 10 | loss 0.16
|epoch 196 | iter 10 / 10 | loss 0.16
|epoch 197 | iter 10 / 10 | loss 0.16
|epoch 198 | iter 10 / 10 | loss 0.15
|epoch 199 | iter 10 / 10 | loss 0.16
|epoch 200 | iter 10 / 10 | loss 0.16
|epoch 201 | iter 10 / 10 | loss 0.15
|epoch 202 | iter 10 / 10 | loss 0.16
|epoch 203 | iter 10 / 10 | loss 0.16
|epoch 204 | iter 10 / 10 | loss 0.15
|epoch 205 | iter 10 / 10 | loss 0.16
|epoch 206 | iter 10 / 10 | loss 0.15
|epoch 207 | iter 10 / 10 | loss 0.15
|epoch 208 | iter 10 / 10 | loss 0.15
|epoch 209 | iter 10 / 10 | loss 0.15
|epoch 210 | iter 10 / 10 | loss 0.15
|epoch 211 | iter 10 / 10 | loss 0.15
|epoch 212 | iter 10 / 10 | loss 0.15
|epoch 213 | iter 10 / 10 | loss 0.15
|epoch 214 | iter 10 / 10 | loss 0.15
|epoch 215 | iter 10 / 10 | loss 0.15
|epoch 216 | iter 10 / 10 | loss 0.14
|epoch 217 | iter 10 / 10 | loss 0.14
|epoch 218 | iter 10 / 10 | loss 0.15
|epoch 219 | iter 10 / 10 | loss 0.14
|epoch 220 | iter 10 / 10 | loss 0.14
|epoch 221 | iter 10 / 10 | loss 0.14
|epoch 222 | iter 10 / 10 | loss 0.14
|epoch 223 | iter 10 / 10 | loss 0.14
|epoch 224 | iter 10 / 10 | loss 0.14
|epoch 225 | iter 10 / 10 | loss 0.14
|epoch 226 | iter 10 / 10 | loss 0.14
|epoch 227 | iter 10 / 10 | loss 0.14
|epoch 228 | iter 10 / 10 | loss 0.14
|epoch 229 | iter 10 / 10 | loss 0.13
|epoch 230 | iter 10 / 10 | loss 0.14
|epoch 231 | iter 10 / 10 | loss 0.13
|epoch 232 | iter 10 / 10 | loss 0.14
|epoch 233 | iter 10 / 10 | loss 0.13
|epoch 234 | iter 10 / 10 | loss 0.13
|epoch 235 | iter 10 / 10 | loss 0.13
|epoch 236 | iter 10 / 10 | loss 0.13
|epoch 237 | iter 10 / 10 | loss 0.14
|epoch 238 | iter 10 / 10 | loss 0.13
|epoch 239 | iter 10 / 10 | loss 0.13
|epoch 240 | iter 10 / 10 | loss 0.14
|epoch 241 | iter 10 / 10 | loss 0.13
|epoch 242 | iter 10 / 10 | loss 0.13
|epoch 243 | iter 10 / 10 | loss 0.13
|epoch 244 | iter 10 / 10 | loss 0.13
|epoch 245 | iter 10 / 10 | loss 0.13
|epoch 246 | iter 10 / 10 | loss 0.13
|epoch 247 | iter 10 / 10 | loss 0.13
|epoch 248 | iter 10 / 10 | loss 0.13
|epoch 249 | iter 10 / 10 | loss 0.13
|epoch 250 | iter 10 / 10 | loss 0.13
|epoch 251 | iter 10 / 10 | loss 0.13
|epoch 252 | iter 10 / 10 | loss 0.12
|epoch 253 | iter 10 / 10 | loss 0.12
|epoch 254 | iter 10 / 10 | loss 0.12
|epoch 255 | iter 10 / 10 | loss 0.12
|epoch 256 | iter 10 / 10 | loss 0.12
|epoch 257 | iter 10 / 10 | loss 0.12
|epoch 258 | iter 10 / 10 | loss 0.12
|epoch 259 | iter 10 / 10 | loss 0.13
|epoch 260 | iter 10 / 10 | loss 0.12
|epoch 261 | iter 10 / 10 | loss 0.13
|epoch 262 | iter 10 / 10 | loss 0.12
|epoch 263 | iter 10 / 10 | loss 0.12
|epoch 264 | iter 10 / 10 | loss 0.13
|epoch 265 | iter 10 / 10 | loss 0.12
|epoch 266 | iter 10 / 10 | loss 0.12
|epoch 267 | iter 10 / 10 | loss 0.12
|epoch 268 | iter 10 / 10 | loss 0.12
|epoch 269 | iter 10 / 10 | loss 0.11
|epoch 270 | iter 10 / 10 | loss 0.12
|epoch 271 | iter 10 / 10 | loss 0.12
|epoch 272 | iter 10 / 10 | loss 0.12
|epoch 273 | iter 10 / 10 | loss 0.12
|epoch 274 | iter 10 / 10 | loss 0.12
|epoch 275 | iter 10 / 10 | loss 0.11
|epoch 276 | iter 10 / 10 | loss 0.12
|epoch 277 | iter 10 / 10 | loss 0.12
|epoch 278 | iter 10 / 10 | loss 0.11
|epoch 279 | iter 10 / 10 | loss 0.11
|epoch 280 | iter 10 / 10 | loss 0.11
|epoch 281 | iter 10 / 10 | loss 0.11
|epoch 282 | iter 10 / 10 | loss 0.12
|epoch 283 | iter 10 / 10 | loss 0.11
|epoch 284 | iter 10 / 10 | loss 0.11
|epoch 285 | iter 10 / 10 | loss 0.11
|epoch 286 | iter 10 / 10 | loss 0.11
|epoch 287 | iter 10 / 10 | loss 0.11
|epoch 288 | iter 10 / 10 | loss 0.12
|epoch 289 | iter 10 / 10 | loss 0.11
|epoch 290 | iter 10 / 10 | loss 0.11
|epoch 291 | iter 10 / 10 | loss 0.11
|epoch 292 | iter 10 / 10 | loss 0.11
|epoch 293 | iter 10 / 10 | loss 0.11
|epoch 294 | iter 10 / 10 | loss 0.11
|epoch 295 | iter 10 / 10 | loss 0.12
|epoch 296 | iter 10 / 10 | loss 0.11
|epoch 297 | iter 10 / 10 | loss 0.12
|epoch 298 | iter 10 / 10 | loss 0.11
|epoch 299 | iter 10 / 10 | loss 0.11
|epoch 300 | iter 10 / 10 | loss 0.11
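The training loop calls optimizer.update(model.params, model.grads) on the SGD class from common.optimizer, whose source is not shown. A minimal sketch of a compatible class, assuming plain stochastic gradient descent (W ← W − lr · dW) and only the update(params, grads) interface used above:

```python
import numpy as np

class SGD:
    """Plain stochastic gradient descent: W <- W - lr * dW."""
    def __init__(self, lr=0.01):
        self.lr = lr

    def update(self, params, grads):
        for param, grad in zip(params, grads):
            # In-place subtraction, so the layers, which hold references
            # to the same arrays, see the updated values.
            param -= self.lr * grad
```

Updating in place is what lets model.params and the layers share the same arrays; rebinding with param = param - lr * grad would only change a local name.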

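Tip: epoch

An epoch is a unit of training: one epoch corresponds to the model having "seen" all of the training data once (one full pass over the dataset). Here we train for 300 epochs.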
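With the hyperparameters used in the training code, the number of mini-batches per epoch and the total number of parameter updates follow directly. The 310-sample case at the end is a hypothetical illustration of what floor division does when the data size is not a multiple of the batch size.

```python
data_size = 300   # samples in the spiral dataset
batch_size = 30
max_epoch = 300

iters_per_epoch = data_size // batch_size    # mini-batches per pass over the data
total_updates = iters_per_epoch * max_epoch  # parameter updates over the whole run
print(iters_per_epoch, total_updates)        # 10 3000

# Hypothetical: with 310 samples, floor division drops the remainder each epoch
leftover = 310 % batch_size
print(310 // batch_size, leftover)           # 10 10
```

Since the data is reshuffled every epoch, the samples left out of a given epoch change from one epoch to the next.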

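Tip: Shuffling the data

During training, mini-batches need to be chosen at random. Here we shuffle the data once per epoch and then take mini-batches from the shuffled data in order, starting from the beginning. The shuffling (more precisely, shuffling of the data indices) uses np.random.permutation(). Given an argument N, it returns a random permutation of the integers 0 to N-1, as in the following example.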

import numpy as np
np.random.permutation(10)
array([5, 1, 8, 4, 9, 7, 0, 2, 6, 3])
np.random.permutation(10)
array([3, 4, 2, 7, 8, 5, 6, 0, 9, 1])

As shown, calling np.random.permutation() randomly shuffles the data indices.
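The shuffle-then-slice pattern from the training loop can be packaged as a generator. This is a hypothetical helper, not part of the original code: it applies one permutation to both inputs and labels, then yields consecutive mini-batches and, like the `//`-based loop, drops an incomplete final batch. The toy arrays x_toy and t_toy use labels equal to the inputs so pairing is easy to check.

```python
import numpy as np

def iterate_minibatches(x, t, batch_size, rng):
    """Yield shuffled mini-batches for one epoch (hypothetical helper)."""
    idx = rng.permutation(len(x))  # shuffle indices, not the arrays themselves
    x, t = x[idx], t[idx]          # the same index array keeps samples and labels paired
    for start in range(0, len(x) - batch_size + 1, batch_size):
        yield x[start:start + batch_size], t[start:start + batch_size]

rng = np.random.default_rng(0)
x_toy = np.arange(10).reshape(10, 1)  # toy inputs 0..9
t_toy = np.arange(10)                 # toy labels equal to the inputs
for batch_x, batch_t in iterate_minibatches(x_toy, t_toy, batch_size=5, rng=rng):
    print(batch_x.ravel(), batch_t)
```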

loss_list
[1.1256062166823237,
 1.1255202354489933,
 1.1162613752115285,
 1.1162867078413503,
 1.1123000112951948,
 1.1384639824108038,
 1.1590961883070312,
 1.1086316143023154,
 1.1173305676924539,
 1.1287957712269248,
 1.1168438089353867,
 1.108338779101816,
 1.0876149200499459,
 1.076681386581935,
 1.0442376735950387,
 1.0345782626337772,
 0.9572932039643971,
 0.918385321087945,
 0.9241491096212101,
 0.8685139076509195,
 0.849380704784154,
 0.8171629191788113,
 0.7924414711357766,
 0.7826646392986113,
 0.8235432039035636,
 0.7754573601774306,
 0.7557857636797779,
 0.7644773546985875,
 0.783489908441849,
 0.7507895610696304,
 0.7773067036165259,
 0.7650839562418821,
 0.7727897179944694,
 0.7819402998382252,
 0.7479802970891092,
 0.7449918634368045,
 0.7560347486336814,
 0.762136567723541,
 0.7308895411004578,
 0.7530268898576871,
 0.7598416342022494,
 0.7594443798911804,
 0.7609245612331736,
 0.7385235003122192,
 0.7483287079215573,
 0.732212322256757,
 0.7226947264484566,
 0.7329453633807874,
 0.7228591003722222,
 0.7225109819294906,
 0.7151355271138069,
 0.7195462325887142,
 0.7375188235491987,
 0.7361580823220941,
 0.7224648808897995,
 0.7182891638960814,
 0.7074840271194414,
 0.7004036126944774,
 0.7172821745385198,
 0.7014167269154442,
 0.7139798052248814,
 0.7158390973991744,
 0.70241458479173,
 0.7147655630827306,
 0.7258385981107364,
 0.6991952628756113,
 0.7149037812469097,
 0.6923793329208154,
 0.6950715496129248,
 0.7051743201139319,
 0.6818892201279282,
 0.6931081236194652,
 0.6678843529136961,
 0.6795690012596001,
 0.6696918781908965,
 0.6601032915888689,
 0.6948944014779308,
 0.644871400994823,
 0.6797970357896568,
 0.6389928717097755,
 0.6352100394457484,
 0.6642182679001878,
 0.6194764020308783,
 0.6229674573767083,
 0.6125107706517634,
 0.5971911133520729,
 0.5988198466130645,
 0.6083885307835829,
 0.5881132092407987,
 0.5773251645421092,
 0.5573829192639743,
 0.5604590214223,
 0.5448990532547215,
 0.5286015850570895,
 0.5299314123468276,
 0.5163065862636679,
 0.5124308579159358,
 0.5016440362124143,
 0.4844736823098785,
 0.48000592169629375,
 0.4609934880592183,
 0.45402197600669847,
 0.45357141775905935,
 0.4429929465269392,
 0.43615541714274964,
 0.41049348011889714,
 0.4031641890220195,
 0.4076112682225584,
 0.4036238903780629,
 0.39951469523134286,
 0.37921674077822315,
 0.3750471397806644,
 0.36229498494049545,
 0.3675710439808242,
 0.34616510526107874,
 0.344898000290116,
 0.33528933507001024,
 0.33807418815033274,
 0.3283438143916972,
 0.33519658210523795,
 0.31583944005785075,
 0.3174067835957746,
 0.30558094499258315,
 0.31306743220376837,
 0.3001822128388422,
 0.298741892547424,
 0.28380294707062764,
 0.2811840756517869,
 0.28183411256599333,
 0.2777145993871101,
 0.27238825698634855,
 0.26707712316069976,
 0.27094530362207153,
 0.27057245386599776,
 0.27030885972727675,
 0.26437260107230176,
 0.2637830554224593,
 0.2560501993977538,
 0.2521541449844564,
 0.2429111293445622,
 0.23945524042724556,
 0.24542262746858579,
 0.2366753717699007,
 0.2403193106192818,
 0.23361288810496253,
 0.23650963179911216,
 0.23025407656786476,
 0.22932871901830215,
 0.2240169314658727,
 0.22394676535578756,
 0.22310757814713048,
 0.22264097958236592,
 0.21697683471580703,
 0.21755951288880998,
 0.21507058662677148,
 0.20835755321509958,
 0.20913544783239285,
 0.20480456033012645,
 0.20860699412184833,
 0.19815557489605248,
 0.20435571889234488,
 0.19724020746228565,
 0.21198898405557265,
 0.19990055286420488,
 0.19670798433788012,
 0.19493732434411776,
 0.18861275602513147,
 0.1871358291328788,
 0.18592571707387812,
 0.18837376894472496,
 0.18999546718467414,
 0.18273072019428158,
 0.18441370204922142,
 0.1833058740625836,
 0.1802345548091979,
 0.17511064852464456,
 0.17767416217643664,
 0.1766872863613579,
 0.17293847154024714,
 0.17328551329855016,
 0.17653792738499427,
 0.1728261261811141,
 0.17626893090372214,
 0.17434628901735752,
 0.16949071085989292,
 0.17755351335716013,
 0.17285695698153453,
 0.16648514397547962,
 0.16594153132977602,
 0.16931139951263377,
 0.16301106121602332,
 0.16958890572080493,
 0.161906824311852,
 0.15939537909039864,
 0.16286032857961993,
 0.1612419736005516,
 0.16364914579280607,
 0.1527678767991844,
 0.15970328019708874,
 0.15747782851517947,
 0.1546769422097933,
 0.1569377065521722,
 0.1554765278672734,
 0.15394834058798243,
 0.15758367150781283,
 0.15348488267419602,
 0.15446293102249137,
 0.1476797718663996,
 0.15141489293752314,
 0.1493564627042096,
 0.1517051939074469,
 0.14755917786826186,
 0.14785145742055836,
 0.14515980626797986,
 0.14678399383604507,
 0.1438836332008838,
 0.14259768661059344,
 0.1455518479651135,
 0.13922422236572785,
 0.13730944343188073,
 0.14410598841352643,
 0.1405077822816634,
 0.14224984249809627,
 0.14254640666896867,
 0.13667202381979077,
 0.1365901190048486,
 0.14251665205070402,
 0.1360865718323748,
 0.1290996692132135,
 0.13883107120690055,
 0.13142337073882102,
 0.13594951747222708,
 0.13429273376232015,
 0.13144069789855423,
 0.13148269177878147,
 0.12860673738419168,
 0.135392048601573,
 0.13223101774593124,
 0.13134628823538033,
 0.1378159605341089,
 0.13367286871605075,
 0.12860547570667225,
 0.12775015207856294,
 0.1300224354664546,
 0.13145974901706844,
 0.12796605893873647,
 0.12643439113307112,
 0.13201041906408498,
 0.12779009853793596,
 0.1304032884799826,
 0.13348456295442618,
 0.12453963049125043,
 0.12132675691111519,
 0.12286008452064738,
 0.12267013423577518,
 0.12312488940516728,
 0.12127839505918629,
 0.12467328584597182,
 0.13195709803216454,
 0.1202665723651631,
 0.12500316552637772,
 0.12062063523861206,
 0.1181079868686632,
 0.12681552271035718,
 0.11807400896605273,
 0.12014005000495154,
 0.11993006486808415,
 0.11732517384704762,
 0.1136047759790468,
 0.11888250117808816,
 0.11633511040515612,
 0.11825391183559267,
 0.11623703569727495,
 0.12322514116418926,
 0.11326334261276325,
 0.11580125295481199,
 0.12418527047451162,
 0.11289494501453376,
 0.11077561346473126,
 0.11431217925067252,
 0.11427378474471335,
 0.11731530918730074,
 0.1117398474559641,
 0.113824373282757,
 0.11354569974260224,
 0.11082773737319436,
 0.11235645391470028,
 0.1169893924945046,
 0.11394781289859857,
 0.11472703869653125,
 0.11195179439782446,
 0.1125271442536377,
 0.11019011632052719,
 0.11147947028025722,
 0.11578316334144989,
 0.110222826957397,
 0.11966849814776921,
 0.11160268411696476,
 0.10668719692739766,
 0.10854954795724314]
# Plot the loss curve
plt.plot(loss_list)
plt.xlabel('Iterations (x10)')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.grid()
plt.show()

[Figure: Loss Curve]

Trainer Class

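As mentioned before, this book trains neural networks many times, and each time that requires training code like the code above. Writing the same code over and over is tedious, so the training logic is packaged as a Trainer class. Internally, Trainer is almost identical to the code just shown, with a few new features added; their usage is explained in detail when they are needed.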
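To make the idea of packaging the training loop into a class concrete, here is a minimal sketch of a Trainer-like class. It is not the book's Trainer implementation. The names MiniTrainer, SimpleSGD and ToyLinearModel are hypothetical. The sketch also assumes the model exposes forward(x, t) returning a scalar loss, backward(), and lists params / grads that the optimizer updates in place. Each epoch it reshuffles the data and walks through it in mini-batches. At each evaluation point it records and prints the average loss accumulated since the previous one.

```python
import time
import numpy as np


class SimpleSGD:
    """Stand-in SGD optimizer: param <- param - lr * grad, updated in place."""
    def __init__(self, lr=0.01):
        self.lr = lr

    def update(self, params, grads):
        for param, grad in zip(params, grads):
            param -= self.lr * grad


class MiniTrainer:
    """Hypothetical minimal trainer illustrating the idea behind Trainer."""
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.loss_list = []

    def fit(self, x, t, max_epoch=10, batch_size=32, eval_interval=20):
        data_size = len(x)
        max_iters = data_size // batch_size  # leftover samples are dropped
        total_loss, loss_count = 0.0, 0
        start_time = time.time()
        for epoch in range(max_epoch):
            idx = np.random.permutation(data_size)  # reshuffle every epoch
            x, t = x[idx], t[idx]
            for iters in range(max_iters):
                batch_x = x[iters * batch_size:(iters + 1) * batch_size]
                batch_t = t[iters * batch_size:(iters + 1) * batch_size]
                # forward pass -> gradients -> parameter update
                loss = self.model.forward(batch_x, batch_t)
                self.model.backward()
                self.optimizer.update(self.model.params, self.model.grads)
                total_loss += loss
                loss_count += 1
                # record the average loss since the last evaluation
                if iters % eval_interval == 0:
                    avg_loss = total_loss / loss_count
                    elapsed = time.time() - start_time
                    print('| epoch %d |  iter %d / %d | time %d[s] | loss %.2f'
                          % (epoch + 1, iters + 1, max_iters, elapsed, avg_loss))
                    self.loss_list.append(float(avg_loss))
                    total_loss, loss_count = 0.0, 0


class ToyLinearModel:
    """Hypothetical toy model: mean squared error of x @ W against t."""
    def __init__(self, in_dim, out_dim):
        self.params = [np.zeros((in_dim, out_dim))]
        self.grads = [np.zeros((in_dim, out_dim))]

    def forward(self, x, t):
        self.x = x
        self.diff = x @ self.params[0] - t
        return float(np.mean(self.diff ** 2))

    def backward(self):
        # gradient of mean(diff**2) with respect to W, written into grads in place
        self.grads[0][...] = 2.0 * self.x.T @ self.diff / self.diff.size


# Usage: 300 samples with batch_size=30 gives 10 iterations per epoch
np.random.seed(0)
x = np.random.randn(300, 2)
t = x @ np.array([[1.0], [-2.0]])
trainer = MiniTrainer(ToyLinearModel(2, 1), SimpleSGD(lr=0.1))
trainer.fit(x, t, max_epoch=20, batch_size=30, eval_interval=10)
```

Because the check is iters % eval_interval == 0, the first record covers a single iteration, and later records cover eval_interval iterations each. With 10 iterations per epoch and eval_interval=10, the sketch logs once per epoch, at "iter 1 / 10".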

The class's initializer takes a neural network (the model) and an optimizer, as follows.

# model = TwoLayerNet(...)
# optimizer = SGD(lr=1.0)
# trainer = Trainer(model,optimizer)

Then call the fit() method to start training. The parameters of fit() are listed in Table 1-1.
[Table 1-1: Parameters of the fit() method]
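The interaction between max_epoch and batch_size can be sketched as follows. The sketch assumes that fit() shuffles the data every epoch and drops any leftover samples that do not fill a whole batch. That is a common convention, not something the table states. iterate_minibatches is a hypothetical helper. With 300 samples and batch_size=30 it yields 10 batches, that is, 10 iterations per epoch.

```python
import numpy as np


def iterate_minibatches(x, t, batch_size, rng):
    """Yield one epoch of shuffled (x, t) mini-batches (hypothetical helper).

    data_size // batch_size batches are produced; leftover samples are dropped.
    """
    data_size = len(x)
    max_iters = data_size // batch_size
    idx = rng.permutation(data_size)
    for i in range(max_iters):
        batch_idx = idx[i * batch_size:(i + 1) * batch_size]
        yield x[batch_idx], t[batch_idx]


rng = np.random.default_rng(0)
x = np.arange(600, dtype=float).reshape(300, 2)  # same shape as spiral x: (300, 2)
t = np.eye(3)[np.arange(300) % 3]                # one-hot labels, shape (300, 3)
batches = list(iterate_minibatches(x, t, batch_size=30, rng=rng))
print(len(batches))          # 10 iterations per epoch
print(batches[0][0].shape)   # (30, 2)
```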

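The Trainer class also has a plot() method, which plots the losses recorded by fit() (more precisely, the average loss evaluated every eval_interval iterations). Code that trains with the Trainer class looks like this: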

import sys
sys.path.append('..')
from common.optimizer import SGD
from common.trainer import Trainer
from dataset import spiral
from two_layer_net import TwoLayerNet  # the two-layer network implemented earlier in this chapter

max_epoch = 300
batch_size = 30
hidden_size = 10
learning_rate = 1.0

x,t = spiral.load_data()

model = TwoLayerNet(input_size=2,hidden_size=hidden_size,output_size=3)

optimizer = SGD(lr=learning_rate)
trainer = Trainer(model,optimizer)
trainer.fit(x,t,max_epoch,batch_size,eval_interval=10)
trainer.plot()




| epoch 1 |  iter 1 / 10 | time 0[s] | loss 1.10
| epoch 2 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 3 |  iter 1 / 10 | time 0[s] | loss 1.13
| epoch 4 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 5 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 6 |  iter 1 / 10 | time 0[s] | loss 1.10
| epoch 7 |  iter 1 / 10 | time 0[s] | loss 1.14
| epoch 8 |  iter 1 / 10 | time 0[s] | loss 1.16
| epoch 9 |  iter 1 / 10 | time 0[s] | loss 1.11
| epoch 10 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 11 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 12 |  iter 1 / 10 | time 0[s] | loss 1.12
| epoch 13 |  iter 1 / 10 | time 0[s] | loss 1.10
| epoch 14 |  iter 1 / 10 | time 0[s] | loss 1.09
| epoch 15 |  iter 1 / 10 | time 0[s] | loss 1.08
| epoch 16 |  iter 1 / 10 | time 0[s] | loss 1.04
| epoch 17 |  iter 1 / 10 | time 0[s] | loss 1.03
| epoch 18 |  iter 1 / 10 | time 0[s] | loss 0.94
| epoch 19 |  iter 1 / 10 | time 0[s] | loss 0.92
| epoch 20 |  iter 1 / 10 | time 0[s] | loss 0.92
| epoch 21 |  iter 1 / 10 | time 0[s] | loss 0.87
| epoch 22 |  iter 1 / 10 | time 0[s] | loss 0.85
| epoch 23 |  iter 1 / 10 | time 0[s] | loss 0.80
| epoch 24 |  iter 1 / 10 | time 0[s] | loss 0.79
| epoch 25 |  iter 1 / 10 | time 0[s] | loss 0.78
| epoch 26 |  iter 1 / 10 | time 0[s] | loss 0.83
| epoch 27 |  iter 1 / 10 | time 0[s] | loss 0.77
| epoch 28 |  iter 1 / 10 | time 0[s] | loss 0.76
| epoch 29 |  iter 1 / 10 | time 0[s] | loss 0.77
| epoch 30 |  iter 1 / 10 | time 0[s] | loss 0.76
| epoch 31 |  iter 1 / 10 | time 0[s] | loss 0.77
| epoch 32 |  iter 1 / 10 | time 0[s] | loss 0.75
| epoch 33 |  iter 1 / 10 | time 0[s] | loss 0.78
| epoch 34 |  iter 1 / 10 | time 0[s] | loss 0.77
| epoch 35 |  iter 1 / 10 | time 0[s] | loss 0.78
| epoch 36 |  iter 1 / 10 | time 0[s] | loss 0.74
| epoch 37 |  iter 1 / 10 | time 0[s] | loss 0.75
| epoch 38 |  iter 1 / 10 | time 0[s] | loss 0.77
| epoch 39 |  iter 1 / 10 | time 0[s] | loss 0.75
| epoch 40 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 41 |  iter 1 / 10 | time 0[s] | loss 0.75
| epoch 42 |  iter 1 / 10 | time 0[s] | loss 0.76
| epoch 43 |  iter 1 / 10 | time 0[s] | loss 0.79
| epoch 44 |  iter 1 / 10 | time 0[s] | loss 0.74
| epoch 45 |  iter 1 / 10 | time 0[s] | loss 0.75
| epoch 46 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 47 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 48 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 49 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 50 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 51 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 52 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 53 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 54 |  iter 1 / 10 | time 0[s] | loss 0.74
| epoch 55 |  iter 1 / 10 | time 0[s] | loss 0.74
| epoch 56 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 57 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 58 |  iter 1 / 10 | time 0[s] | loss 0.69
| epoch 59 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 60 |  iter 1 / 10 | time 0[s] | loss 0.70
| epoch 61 |  iter 1 / 10 | time 0[s] | loss 0.69
| epoch 62 |  iter 1 / 10 | time 0[s] | loss 0.71
| epoch 63 |  iter 1 / 10 | time 0[s] | loss 0.70
| epoch 64 |  iter 1 / 10 | time 0[s] | loss 0.71
| epoch 65 |  iter 1 / 10 | time 0[s] | loss 0.72
| epoch 66 |  iter 1 / 10 | time 0[s] | loss 0.71
| epoch 67 |  iter 1 / 10 | time 0[s] | loss 0.71
| epoch 68 |  iter 1 / 10 | time 0[s] | loss 0.71
| epoch 69 |  iter 1 / 10 | time 0[s] | loss 0.70
| epoch 70 |  iter 1 / 10 | time 0[s] | loss 0.68
| epoch 71 |  iter 1 / 10 | time 0[s] | loss 0.73
| epoch 72 |  iter 1 / 10 | time 0[s] | loss 0.66
| epoch 73 |  iter 1 / 10 | time 0[s] | loss 0.69
| epoch 74 |  iter 1 / 10 | time 0[s] | loss 0.66
| epoch 75 |  iter 1 / 10 | time 0[s] | loss 0.70
| epoch 76 |  iter 1 / 10 | time 0[s] | loss 0.65
| epoch 77 |  iter 1 / 10 | time 0[s] | loss 0.67
| epoch 78 |  iter 1 / 10 | time 0[s] | loss 0.70
| epoch 79 |  iter 1 / 10 | time 0[s] | loss 0.63
| epoch 80 |  iter 1 / 10 | time 0[s] | loss 0.66
| epoch 81 |  iter 1 / 10 | time 0[s] | loss 0.65
| epoch 82 |  iter 1 / 10 | time 0[s] | loss 0.66
| epoch 83 |  iter 1 / 10 | time 0[s] | loss 0.64
| epoch 84 |  iter 1 / 10 | time 0[s] | loss 0.62
| epoch 85 |  iter 1 / 10 | time 0[s] | loss 0.62
| epoch 86 |  iter 1 / 10 | time 0[s] | loss 0.63
| epoch 87 |  iter 1 / 10 | time 0[s] | loss 0.59
| epoch 88 |  iter 1 / 10 | time 0[s] | loss 0.58
| epoch 89 |  iter 1 / 10 | time 0[s] | loss 0.61
| epoch 90 |  iter 1 / 10 | time 0[s] | loss 0.59
| epoch 91 |  iter 1 / 10 | time 0[s] | loss 0.58
| epoch 92 |  iter 1 / 10 | time 0[s] | loss 0.57
| epoch 93 |  iter 1 / 10 | time 0[s] | loss 0.55
| epoch 94 |  iter 1 / 10 | time 0[s] | loss 0.54
| epoch 95 |  iter 1 / 10 | time 0[s] | loss 0.53
| epoch 96 |  iter 1 / 10 | time 0[s] | loss 0.54
| epoch 97 |  iter 1 / 10 | time 0[s] | loss 0.51
| epoch 98 |  iter 1 / 10 | time 0[s] | loss 0.51
| epoch 99 |  iter 1 / 10 | time 0[s] | loss 0.50
| epoch 100 |  iter 1 / 10 | time 0[s] | loss 0.47
| epoch 101 |  iter 1 / 10 | time 0[s] | loss 0.49
| epoch 102 |  iter 1 / 10 | time 0[s] | loss 0.46
| epoch 103 |  iter 1 / 10 | time 0[s] | loss 0.44
| epoch 104 |  iter 1 / 10 | time 0[s] | loss 0.47
| epoch 105 |  iter 1 / 10 | time 0[s] | loss 0.44
| epoch 106 |  iter 1 / 10 | time 0[s] | loss 0.43
| epoch 107 |  iter 1 / 10 | time 0[s] | loss 0.43
| epoch 108 |  iter 1 / 10 | time 0[s] | loss 0.39
| epoch 109 |  iter 1 / 10 | time 0[s] | loss 0.40
| epoch 110 |  iter 1 / 10 | time 0[s] | loss 0.41
| epoch 111 |  iter 1 / 10 | time 0[s] | loss 0.38
| epoch 112 |  iter 1 / 10 | time 0[s] | loss 0.38
| epoch 113 |  iter 1 / 10 | time 0[s] | loss 0.38
| epoch 114 |  iter 1 / 10 | time 0[s] | loss 0.37
| epoch 115 |  iter 1 / 10 | time 0[s] | loss 0.36
| epoch 116 |  iter 1 / 10 | time 0[s] | loss 0.34
| epoch 117 |  iter 1 / 10 | time 0[s] | loss 0.35
| epoch 118 |  iter 1 / 10 | time 0[s] | loss 0.33
| epoch 119 |  iter 1 / 10 | time 0[s] | loss 0.35
| epoch 120 |  iter 1 / 10 | time 0[s] | loss 0.33
| epoch 121 |  iter 1 / 10 | time 0[s] | loss 0.33
| epoch 122 |  iter 1 / 10 | time 0[s] | loss 0.32
| epoch 123 |  iter 1 / 10 | time 0[s] | loss 0.31
| epoch 124 |  iter 1 / 10 | time 0[s] | loss 0.31
| epoch 125 |  iter 1 / 10 | time 0[s] | loss 0.31
| epoch 126 |  iter 1 / 10 | time 0[s] | loss 0.30
| epoch 127 |  iter 1 / 10 | time 0[s] | loss 0.30
| epoch 128 |  iter 1 / 10 | time 0[s] | loss 0.27
| epoch 129 |  iter 1 / 10 | time 0[s] | loss 0.30
| epoch 130 |  iter 1 / 10 | time 0[s] | loss 0.28
| epoch 131 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 132 |  iter 1 / 10 | time 0[s] | loss 0.27
| epoch 133 |  iter 1 / 10 | time 0[s] | loss 0.27
| epoch 134 |  iter 1 / 10 | time 0[s] | loss 0.28
| epoch 135 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 136 |  iter 1 / 10 | time 0[s] | loss 0.28
| epoch 137 |  iter 1 / 10 | time 0[s] | loss 0.25
| epoch 138 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 139 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 140 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 141 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 142 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 143 |  iter 1 / 10 | time 0[s] | loss 0.26
| epoch 144 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 145 |  iter 1 / 10 | time 0[s] | loss 0.24
| epoch 146 |  iter 1 / 10 | time 0[s] | loss 0.24
| epoch 147 |  iter 1 / 10 | time 0[s] | loss 0.25
| epoch 148 |  iter 1 / 10 | time 0[s] | loss 0.21
| epoch 149 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 150 |  iter 1 / 10 | time 0[s] | loss 0.22
| epoch 151 |  iter 1 / 10 | time 0[s] | loss 0.22
| epoch 152 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 153 |  iter 1 / 10 | time 0[s] | loss 0.23
| epoch 154 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 155 |  iter 1 / 10 | time 0[s] | loss 0.22
| epoch 156 |  iter 1 / 10 | time 0[s] | loss 0.21
| epoch 157 |  iter 1 / 10 | time 0[s] | loss 0.21
| epoch 158 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 159 |  iter 1 / 10 | time 0[s] | loss 0.21
| epoch 160 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 161 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 162 |  iter 1 / 10 | time 0[s] | loss 0.22
| epoch 163 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 164 |  iter 1 / 10 | time 0[s] | loss 0.21
| epoch 165 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 166 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 167 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 168 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 169 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 170 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 171 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 172 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 173 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 174 |  iter 1 / 10 | time 0[s] | loss 0.20
| epoch 175 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 176 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 177 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 178 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 179 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 180 |  iter 1 / 10 | time 0[s] | loss 0.19
| epoch 181 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 182 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 183 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 184 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 185 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 186 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 187 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 188 |  iter 1 / 10 | time 0[s] | loss 0.18
| epoch 189 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 190 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 191 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 192 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 193 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 194 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 195 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 196 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 197 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 198 |  iter 1 / 10 | time 0[s] | loss 0.17
| epoch 199 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 200 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 201 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 202 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 203 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 204 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 205 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 206 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 207 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 208 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 209 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 210 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 211 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 212 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 213 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 214 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 215 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 216 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 217 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 218 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 219 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 220 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 221 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 222 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 223 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 224 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 225 |  iter 1 / 10 | time 0[s] | loss 0.16
| epoch 226 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 227 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 228 |  iter 1 / 10 | time 0[s] | loss 0.15
| epoch 229 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 230 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 231 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 232 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 233 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 234 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 235 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 236 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 237 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 238 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 239 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 240 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 241 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 242 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 243 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 244 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 245 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 246 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 247 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 248 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 249 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 250 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 251 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 252 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 253 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 254 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 255 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 256 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 257 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 258 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 259 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 260 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 261 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 262 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 263 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 264 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 265 |  iter 1 / 10 | time 0[s] | loss 0.14
| epoch 266 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 267 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 268 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 269 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 270 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 271 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 272 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 273 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 274 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 275 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 276 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 277 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 278 |  iter 1 / 10 | time 0[s] | loss 0.13
| epoch 279 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 280 |  iter 1 / 10 | time 0[s] | loss 0.10
| epoch 281 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 282 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 283 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 284 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 285 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 286 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 287 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 288 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 289 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 290 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 291 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 292 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 293 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 294 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 295 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 296 |  iter 1 / 10 | time 0[s] | loss 0.12
| epoch 297 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 298 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 299 |  iter 1 / 10 | time 0[s] | loss 0.11
| epoch 300 |  iter 1 / 10 | time 0[s] | loss 0.11

[Figure: loss curve drawn by trainer.plot()]

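Running this code trains the neural network in the same way as before. Handing the training code shown earlier over to the Trainer class makes the code more concise. From here on, this book uses the Trainer class for training.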
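The averaged loss that plot() draws can be sketched as follows: split the per-iteration losses into consecutive chunks of eval_interval and take each chunk's mean. average_every is a hypothetical helper, and the per-iteration losses below are made-up numbers for illustration. The resulting array is what one would pass to plt.plot, reading the x axis as "iterations (x eval_interval)".

```python
import numpy as np


def average_every(losses, eval_interval):
    """Mean of each consecutive chunk of eval_interval losses (hypothetical helper).

    A trailing partial chunk is ignored.
    """
    n = len(losses) // eval_interval * eval_interval
    chunks = np.asarray(losses[:n], dtype=float).reshape(-1, eval_interval)
    return chunks.mean(axis=1)


per_iter_losses = [1.2, 1.0, 0.8, 0.6, 0.5, 0.4, 0.3]  # made-up values
avg = average_every(per_iter_losses, eval_interval=2)
# avg has 3 points; the last value (0.3) does not fill a whole chunk
print(avg)
```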

Tip - Speeding Up Computation

Install NVIDIA CUDA and use CuPy to run computations on the GPU for higher speed.

https://baijiahao.baidu.com/s?id=1781877547348944869&wfr=spider&for=pc
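One common way to act on this tip is to choose the array module once and write the rest of the code against it, since CuPy mirrors much of the NumPy API. The sketch below uses a hypothetical GPU flag and falls back to NumPy when CuPy is not installed. cupy.asnumpy copies a GPU array back to host memory. If CuPy imports but no usable GPU or CUDA setup is present, the array operations will still fail at run time.

```python
GPU = True  # hypothetical flag: try to use the GPU through CuPy

if GPU:
    try:
        import cupy as xp   # needs an NVIDIA GPU, CUDA, and a matching CuPy build
    except ImportError:
        import numpy as xp  # CuPy not installed: fall back to NumPy on the CPU
else:
    import numpy as xp

# The same array code runs on either backend
a = xp.arange(6, dtype=xp.float32).reshape(2, 3)
b = xp.ones((3, 2), dtype=xp.float32)
c = a @ b  # computed on the GPU when xp is cupy

# Copy the result back to a NumPy array before printing or plotting
c_host = xp.asnumpy(c) if xp.__name__ == "cupy" else c
print(c_host)
```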
