Bootstrap

机器学习模型评估与改进:网格化调参(grid search)


在交叉验证部分我们知道了如何科学地评估算法模型的泛化能力,那么我们可以进一步看看,如何对模型进行调参,以达到改进模型效果的目的。

首先,在调参之前,必须对算法参数的意义有清楚的理解。对那些重要参数“调试”是一个比较trick的任务,但却又必不可少。好在scikit learn提供了很多工具来辅助这个让人头疼的过程。其中,最常用的方法就是网格化搜索(grid search),也就是对所有可能的参数取值和不同参数取值的组合逐一尝试,最后确定最佳参数。

以SVM(RBF核)为例, 有两个非常重要的参数:

  1. kernel bandwidth : gamma
  2. 正则化参数: C

为了得到参数gamma,C的最佳设置,我们依次从[0.001, 0.001, 0.01, 0.1, 1, 10, 100]这些数字选择一个组合,作为最优的SVM参数,显然,这样的组合有6*6=36种。

简单网格化搜索

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state = 0)
print(“Size of training set: {
    }   size of test set: {
    }.format(
    X_train.shape[0], X_test.shape[0]))
best_score=0
for gamma in [0.001, 0.01, 0.1, 1, 10, 100]:
    for C in [0.001, 0.01, 0.1, 1, 10, 100]:
        # for each combination of parameters, train an SVC
        svm = SVC(gamma=gamma, C=C)
        svm.fit(X_train, y_train)
        # evaluate the SVC on the test set
        score = svm.score(X_test, y_test)
        # if we got better score, store it
        if score > best_score:
            best_score=score
            best_parameters = {
   ‘C’:C, ‘gamma’:gamma}
print(‘Best score: {
   :.2f}.format(best_score))
print("Best parameters: {}".format(best_parameters) )

得到:

Size of training set: 112   size of test set: 38
    Best score: 0.97
    Best parameters: {
   'C': 100, 'gamma': 0.001}

参数过拟合的风险

上面的例子给出了“最好”的参数组合下模型在测试集合上的得分为97%, 似乎模型优化已经完成了。

其实则不然。因为这里给出的正确率仅仅是在参数网格化搜索时候用到的测试数据上模型的表现,换句话说,我们是用了一个特定的测试集合衡量模型的表现,我们不能用同一组数据来测试模型的泛化能力。也就是说,97%的表现,在新的数据集合上模型的正确率是不是还是这样呢?答案是不一定。

这里一定要非常小心,因为一不小心就会中了循环论证的圈套。比较正确的方法应该是:将数据集合分为训练集、验证集和测试集三部分。

训练集合 — 训练模型
验证集合 — 模型参数选择
测试集合. — 测试模型性能

接着上面的例子,我们用最佳的参数组合作为模型参数,在原来的训练数据和验证数据上对模型重新训练,然后在测试集合上对模型的性能进行测试。

from sklearn.svm import SVC
 # split data into train+validation set and test set 
X_trainval, X_test, y_trainval, y_test = train_test_split( iris.data, iris.target, random_state=0)
 # split train+validation set into training and validation sets X_train, X_valid, y_train, y_valid = train_test_split( 
X_trainval, y_trainval, random_state=
;