Bootstrap

李宏毅机器学习2020课后作业笔记 【hw1】 pm2.5值的预测

李宏毅机器学习2020课后作业 ML2020spring - hw1

1. 问题描述

给定一年内台湾某市的空气质量观测数据(每个月20天,每小时记录一次,每天共记录24个小时),目的是来预测某时刻的pm2.5值。训练数据是12个月,每个月20天的空气质量观测数据,每天共有18个特征(包括pm2.5),每天共24个小时的数据,也就是说每天共18 * 24个数据值。测试数据是前9个小时的数据,用来预测第10个小时的pm2.5值。

2. 数据预处理

  1. p a n d a s pandas pandas读取 t r a i n . c s v train.csv train.csv文件,先读取成 D a t a F r a m e DataFrame DataFrame类型,再转成 n u m p y . n d a r r a y numpy.ndarray numpy.ndarray类型
  2. 按照作业要求,所有 N R NR NR都置为0
  3. 再把每个月的数据拼成一排,用字典存储, k e y key key就是月份, v a l u e value value就是每个月的数据
  4. 做标准化, n p . m e a n np.mean np.mean n p . s t d np.std np.std
  5. 分割数据集,80% + 20%

3. 模型构建

这个作业构建的是线性回归模型, l o s s loss loss采用 R M S E RMSE RMSE R M S E = 1 N ∑ i = 0 N ( y p r e d − y i ) 2 RMSE = \sqrt{\frac{1}{N}\sum_{i=0}^{N}(y ^{pred} - y^i )^2} RMSE=N1i=0N(ypredyi)2

梯度更新采用 A d a g r a d Adagrad Adagrad优化器, A d a g r a d Adagrad Adagrad相比于普通的梯度下降拥有更好的效果, w t + 1 = w t − η σ t g t w^{t+1} = w^t - \frac{ \eta} {\sigma^t}g^t wt+1=wtσtηgt其中 σ t = ∑ i = 0 n ( g i ) 2 \sigma^t=\sqrt{\sum_{i=0}^{n}(g^i)^2} σt=i=0n(gi)2 核心代码如下所示

loss_list = []
for epoch in range(iter_time):
    # RMSE
    loss = np.sqrt( np.sum(np.power(np.dot(x, w) - y, 2) ) )
    loss_list.append(loss)
    # 100轮打印一个loss出来看看
    if epoch % 100 == 0:
        print(str(epoch) + ":" + str(loss))
    # gradient of the loss 2x(xw - y)
    gradient = 2 * np.dot(x.transpose(), np.dot(x, w) - y)
    adagrad += gradient ** 2
    # Adagrad iteration formula
    w -= learning_rate * gradient / np.sqrt(adagrad + eps)
# 模型存储
np.save('weight.npy', w)

4. 模型评估

也是一样,做一些预处理,比如读成 n p . n d a r r a y np.ndarray np.ndarray N R NR NR补0,标准化之类的,最后调用训练好的模型做预测。

w = np.load('weight.npy')
ans_y = np.dot(test_x, w)
ans_y

最后将结果写入到csv。

with open('submit.csv', mode='w', newline='') as submit_file:
    csv_writer = csv.writer(submit_file)
    header = ['id', 'value']
    print(header)
    
    csv_writer.writerow(header)
    for i in range(240):
        row = ['id_' + str(i), ans_y[i][0]]
        csv_writer.writerow(row)
        print(row)

最后将结果上传到kaggle对应作业提交区域。

5. 个人总结

  1. 数据的预处理很重要,hw1主要是将数据进行了合并、标准化等操作
  2. A d a g r a d Adagrad Adagrad实战
  3. 学会了如何将数据写会到 c s v csv csv文件