Bootstrap

R与机器学习系列|15.可解释的机器学习算法(Interpretable Machine Learning)(中)

在上次推文中我们介绍了几种可解释机器学习算法的常见方法,包括置换特征重要性、偏依赖图和个体条件期望及其实现。本次我们将继续介绍其他的用来解释机器学习算法的方法。

1.特征交互(Feature interactions)

1.1介绍

在机器学习中,Feature Interactions(特征交互)是指不同特征之间的相互作用或联合效应。特征交互可以帮助我们更好地理解数据,发现特征之间的复杂关系,以及提高机器学习模型的性能。当预测模型中的特征之间存在交互作用时,特征对预测结果的影响不是简单的加和,而是更为复杂。在现实生活中,大多数特征与某些响应变量之间的关系都是复杂的,包括交互作用。这也是为什么更复杂的算法(尤其是基于树的算法)往往表现得非常好的原因——它们的复杂性通常能够自然地捕捉复杂的交互作用。然而,识别和理解这些交互作用较为困难的。

估计交互作用强度的一种方法是衡量预测结果的变化有多少取决于特征之间的交互作用。这种衡量称为H统计量,由Friedman、Popescu等人在2008年提出。

特征交互通常在以下两个方面进行考虑:

  1. 特征之间的组合效应:某些特征组合在一起可能具有比它们单独使用更强的预测能力。例如,对于预测房价的问题,房屋的面积和地理位置可能单独对房价的预测有一定的影响,但将它们组合在一起可能会得到更准确的预测结果。

  2. 特征之间的相互作用:某些特征之间可能存在非线性相互作用,即它们的组合效果不是简单的加和关系。例如,对于预测用户购买行为的问题,用户的年龄和购买频率可能存在相互作用,即不同年龄段的用户在购买频率上表现不同。

1.2实施

目前,H统计量的计算仅可以通过iml包实现。我们可以使用Interaction$new()来计算单向交互作用,以评估两个特定特征在模型中如何相互作用,并且强度如何。

不幸的是,由于算法的复杂性,H统计量需要进行2n^2次运行,因此计算非常耗时。例如,计算单向交互作用的H统计量需要花费两个小时的时间!在这种情况下,我们可以通过在iml包中减少grid.size或使用parallel = TRUE进行并行计算来加速计算。
这里为了示例,小编仅使用示例数据中的几个特征演示一下这一部分(运行时间真的太长了,所以就只选几个变量)

# 加载依赖包
library(dplyr)      #数据操纵
library(ggplot2)    # 可视化

# Modeling packages
library(h2o)       # H2O
library(recipes)   # 机器学习蓝图
library(rsample)   # 数据分割
library(xgboost)   # 拟合GBM模型

# 模型可解释性包
library(pdp)       # 偏依赖图及ICE曲线绘制
library(vip)       # 变量重要性VIP图
library(iml)       # 普遍IML相关函数
library(DALEX)     # 普遍IML相关函数
# devtools::install_github('thomasp85/lime')
library(lime)      # 局部可解释模型无关解释
load("inputdata.Rda")#加载示例数据
inputdata<-inputdata[,-1]
inputdata$Event<-factor(inputdata$Event,levels = c(0,1),labels = c("Alive","Death"))#结局变量因子化
set.seed(123)  # 设置随机种子保证可重复性
split <- initial_split(inputdata, strata = "Event")#数据分割
data_train <- training(split)
data_test <- testing(split)

我们这里按照Interaction {iml}示例,使用示例数据拟合一个分类任务的CART算法

library("rpart")
set.seed(42)
data<-data_train[,1:10]
rf<-rpart(Event~.,data=data)
mod <- Predictor$new(rf, data =data, type = "prob")

# For some models we have to specify additional arguments for the
# predict function
ia <- Interaction$new(mod)

接下来我们对交互作用降序排列

ia$results %>% 
  arrange(desc(.interaction)) %>% 
  head()
# .feature .class .interaction
# 1:  AADACP1  Death    0.6121357
# 2:  AADACP1  Alive    0.6121357
# 3:     AAAS  Alive    0.6038541
# 4:     AAAS  Death    0.6038541
# 5:    AAGAB  Alive    0.5892189
# 6:    AAGAB  Death    0.5892189

可以进一步使用plot函数将结果可视化出来

plot(ia)

在确定了具有最强交互作用的变量后,我们可以计算h统计量,以确定它主要与哪些特征存在交互作用。我们可以看到AAGAB和AADAT之间有较强的交互作用。

interact_2way <- Interaction$new(mod, feature = "AAGAB")
interact_2way$results %>% 
  arrange(desc(.interaction)) %>% 
  top_n(10)
Selecting by .interaction
# .feature .class .interaction
# 1:   AADAT:AAGAB  Alive    0.4992814
# 2:   AADAT:AAGAB  Death    0.4992814
# 3:    AACS:AAGAB  Death    0.2733772
# 4:    AACS:AAGAB  Alive    0.2733772
# 5: AADACP1:AAGAB  Alive    0.2707731
# 6: AADACP1:AAGAB  Death    0.2707731
# 7:    AAAS:AAGAB  Death    0.2669400
# 8:    AAAS:AAGAB  Alive    0.2669400
# 9:     A2M:AAGAB  Alive    0.1190083
# 10:     A2M:AAGAB  Death    0.1190083

识别这些交互作用可以帮助我们了解它们与响应变量的关系。我们可以使用PDPs或ICE曲线来观察交互作用对预测结果的影响。我们可以看下上面步骤发现的两个交互作用比较强的变量如何影响结局事件的预测。

# Two-way PDP using iml
interaction_pdp <- Partial$new(
  mod, 
  c("AAGAB", "AADAT"), 
  ice = FALSE, 
  grid.size = 20
) 
plot(interaction_pdp)

1.3其他可供选择方法

显然,计算时间是确定潜在交互效应的主要限制因素,因为这个过程会花费很长的计算时间。尽管 H 统计量是检测交互作用的方法中最具统计学意义的方法,但还有其他选择。在Brandon M Greenwell, Boehmke和McCarthy (2018)中讨论的基于PDP的变量重要性测量也可以用于量化潜在交互效应的强度,可以通过vip::vint()实现。

2.局部可解释模型无关解释(Local interpretable model-agnostic explanations)

2.1介绍

机器学习中的局部可解释模型无关解释(Local Interpretable Model-Agnostic Explanations,LIME)是一种用于解释机器学习模型预测的方法。LIME的目标是在特定样本附近构建一个局部线性模型来近似原始模型的预测结果,并解释该局部模型的系数以得到对预测的解释。

LIME的基本思想是通过生成一组“虚拟样本”,这些虚拟样本是原始样本在特征空间中的近似,然后利用这些虚拟样本来训练一个局部线性模型。在构建局部模型时,LIME使用一种称为“稀疏线性模型”的方法,这是一种可以解释性较好的线性模型。通过解释稀疏线性模型的系数,我们可以得到对预测的解释,即哪些特征对于模型的预测起到了关键作用。

LIME的优点是可以应用于任何类型的机器学习模型,而不仅限于特定类型的模型。它还可以在不需要访问原始模型的内部结构的情况下解释模型的预测结果,因此被称为“模型无关”的解释方法。

然而,LIME也有一些局限性。首先,由于LIME是在局部构建模型,所以解释的可信度可能会受到局部数据分布的影响。其次,LIME的解释是基于稀疏线性模型的,可能会丢失一些复杂模型的细节。因此,在使用LIME时,需要仔细考虑其解释的适用范围和可信度。

LIME所应用的一般算法如下:

  1. 对训练数据进行置换以创建复制的特征数据,这些数据的值可能有轻微修改。
  2. 计算感兴趣观测值与每个置换观测值之间的接近度度量(例如,1 - 距离)。
  3. 使用选定的机器学习模型预测置换数据的结局。
  4. 选择 m 个特征来最好地描述预测结果。
  5. 对置换数据拟合一个简单模型,用 m 个特征来解释复杂模型的结局,并根据其与原始观测值的相似性进行加权。
  6. 使用得到的特征权重来解释局部行为。

2.2实施

以上过程可以通过lime包实现,主要涉及两个过程:lime::lime()和lime::explain()。lime::lime()函数用于创建一个"explainer"对象,它是一个包含已拟合的机器学习模型和训练数据特征分布的列表。其中包含的特征分布包括每个分类变量水平和每个连续变量分为n个箱子(当前默认为四个箱子)的分布统计。这些特征属性将用于对数据进行置换。
首先我们根据前面的堆叠算法先生成一个堆叠模型

#训练一个堆叠模型(见机器学习系列堆叠算法)
ensemble_tree <- h2o.stackedEnsemble(
  x = X, y = Y, training_frame = train_h2o, model_id = "my_tree_ensemble",
  base_models = list(best_glm, best_rf, best_gbm,best_nb,best_nn),
  metalearner_algorithm = "drf"
)

接着提取特征

features<-data_train%>%select(-Event)

使用lime函数基于堆叠算法模型和特征创建一个解释器

# Create explainer object
components_lime <- lime(
  x = features,
  model = ensemble_tree, 
  n_bins = 10
)

class(components_lime)
## [1] "data_frame_explainer" "explainer"            "list"
# Length Class            Mode     
# model                  1    H2OBinomialModel S4       
# preprocess             1    -none-           function 
# bin_continuous         1    -none-           logical  
# n_bins                 1    -none-           numeric  
# quantile_bins          1    -none-           logical  
# use_density            1    -none-           logical  
# feature_type         101    -none-           character
# bin_cuts             101    -none-           list     
# feature_distribution 101    -none-           list     

然后我们在验证集数据中选择两个对象,基于他们的特征来解释模型

lime_explanation <- lime::explain(
  x =data_new, 
  explainer = components_lime, 
  n_permutations = 5000,
  dist_fun = "gower",
  kernel_width = 0.25,
  n_labels = 2,
  n_features = 10, 
  feature_select = "highest_weights"
)

lime::explain()函数的主要参数及意义如下:

x:要为其创建局部解释的观察值。
explainer:采用由lime::lime()创建的解释器对象,将用于创建置换数据。置换是从lime::lime()解释器对象创建的变量分布中采样得到的。
n_permutations:为x中的每个观察值创建的置换数(默认为5000)。
dist_fun:要使用的距离函数。默认为Gower距离,但也可以使用Euclidean、Manhattan或dist()函数允许的任何其他距离函数。为了计算相似性,分类特征将根据它们是否等于实际观察值进行重新编码。如果连续特征被分箱(默认值),则这些特征将根据它们是否与要解释的观察值在同一个箱中进行重新编码。然后,使用重新编码的数据计算到原始观察值的距离。
kernel_width:为了将距离度量转换为相似度得分,使用用户定义的宽度的指数核(默认为特征数的0.75倍的平方根)。
n_features:最能描述预测结果的特征数。
feature_select:lime::lime()可以使用前向选择、岭回归、LASSO或决策树来选择“最佳”的n_features特征。
对于分类模型,我们需要指定一些额外的参数:
labels:要解释的特定标签(类)(例如,0/1,“是”/“否”)?
n_labels:要解释的标签数。

如果原始的机器学习模型是回归模型,局部模型将直接预测复杂模型的输出结果。如果是分类器,局部模型将预测所选择类别的概率。
通过lime::explain()函数得到的输出是一个包含各种信息的数据框,用于描述局部模型的预测结果。其中最重要的信息是对于每个提供的观测值,它包含了拟合的解释模型 (model_r2) 和每个重要特征 (feature_desc) 的加权重要性 (feature_weight),用于最佳描述局部关系。

glimpse(lime_explanation)
# Rows: 40
# Columns: 13
# $ model_type       <chr> "classification", "classification", "classification", "classification", "classification~
# $ case             <chr> "TCGA-D5-6530", "TCGA-D5-6530", "TCGA-D5-6530", "TCGA-D5-6530", "TCGA-D5-6530", "TCGA-D~
#   $ label            <chr> "Alive", "Alive", "Alive", "Alive", "Alive", "Alive", "Alive", "Alive", "Alive", "Alive~
# $ label_prob       <dbl> 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.01, 0.01, 0.01, 0.01, 0.0~
# $ model_r2         <dbl> 0.01444325, 0.01444325, 0.01444325, 0.01444325, 0.01444325, 0.01444325, 0.01444325, 0.0~
# $ model_intercept  <dbl> 0.7045399, 0.7045399, 0.7045399, 0.7045399, 0.7045399, 0.7045399, 0.7045399, 0.7045399,~
# $ model_prediction <dbl> 0.7756461, 0.7756461, 0.7756461, 0.7756461, 0.7756461, 0.7756461, 0.7756461, 0.7756461,~
# $ feature          <chr> "ABCA1", "ABCC4", "ABHD3", "AAMP", "ABHD2", "ABCG1", "ABCG2", "ABCD3", "AC000111.2", "A~
#   $ feature_value    <dbl> 1.4922773, 0.8637574, 3.8144376, 6.1803265, 4.3477301, 2.2083119, 0.2595689, 3.2786676,~
#   $ feature_weight   <dbl> 0.02450017, -0.02491953, -0.02346769, 0.02196452, 0.02330070, 0.01894933, -0.01702930, ~
#   $ feature_desc     <chr> "1.475 < ABCA1 <= 1.642", "ABCC4 <= 1.17", "3.66 < ABHD3 <= 3.92", "6.17 < AAMP", "4.29~
# $ data             <list> [1.041423, 4.595776, 1.484305, 3.626439, 2.718439, 0.09743433, 0.594286, 1.95217, 3.91~
# $ prediction       <list> [0.99, 0.01], [0.99, 0.01], [0.99, 0.01], [0.99, 0.01], [0.99, 0.01], [0.99, 0.01], [0~

我们看看可视化的结果。然而,需要注意模型的低R²(“解释适配度”)。局部模型的拟合效果似乎相当差,因此我们不应过于依赖这些解释。

plot_features(lime_explanation, ncol =2)

2.3参数调整

在执行LIME(局部解释性模型)时,我们可以调整几个参数,将它们视为调参参数,这样可以尝试调整局部模型。这有助于最大程度地增加局部解释性模型的可信性。

#LIME算法调参
lime_explanation2 <- explain(
  x =data_new, 
  explainer = components_lime, 
  n_permutations = 5000,
  dist_fun = "euclidean",
  kernel_width = 0.75,
  n_labels = 2,
  n_features = 10, 
  feature_select = "lasso_path"
)

#可视化结果
plot_features(lime_explanation2, ncol = 2)

在上面的调参过程中,我们将距离函数更改为欧几里得距离,增加了核宽度以创建更大的局部区域,并将特征选择方法改为基于LARS的LASSO模型。

2.4其他可供选择的方法

上面的示例我们主要围绕在表格数据集中使用LIME进行解释性模型的构建。然而,LIME也可以应用于非传统数据集,例如文本和图像。对于文本数据,LIME会创建一个包含扰动文本的新的文档-词矩阵(例如,它会基于现有文本生成新的短语和句子)。然后,LIME会按照类似的步骤对生成的文本与原始文本的相似性进行加权。局部模型然后帮助确定在扰动文本中哪些词语产生了最强的信号。

对于图像数据,LIME会通过用一个常量颜色(例如灰色)替换某些像素组合来创建图像的变体。然后,LIME会评估给定未扰动像素组的预测标签。

3.SHAP值

3.1背景

SHAP(SHapley Additive exPlanations)是一种解释机器学习模型预测的方法,它基于合作博弈理论中的Shapley值概念。SHAP通过计算每个特征对于模型预测输出的贡献,帮助我们理解模型预测的原因和解释。

在机器学习中,模型预测的输出往往由多个特征共同决定。SHAP通过考虑每个特征值与其他特征值之间的交互作用,将模型预测的总体变化分配给每个特征。这样,我们可以了解每个特征对于模型输出的相对重要性,以及特征之间的相互作用对预测的影响。

SHAP值具有以下特点:

  1. 公平性:SHAP值确保在所有可能的特征子集中,特征的贡献是公平的,不受其他特征的影响。
  2. 一致性:SHAP值满足Shapley值的一致性属性,即如果两个模型预测相同,但是特征值不同,那么它们的SHAP值应该相同。
  3. 局部解释性:SHAP值提供了对于单个样本的局部解释,即了解某个特定样本的预测结果是由哪些特征贡献决定的。

在之前的推文中我们介绍了R和Python中SHAP值的可视化过程,当然也包括生存数据的SHAP值可视化。在这次的示例中我们将一起学习如何通过iml包实现SHAP值的可视化。

#SHAP
#提取特征
features <- as.data.frame(train_h2o) %>% select(-Event)

#提取响应变量
response <- as.data.frame(train_h2o) %>% pull(Event)

#自定义函数
pred <- function(object, newdata)  {
  results <- as.vector(h2o.predict(object, as.h2o(newdata)))
  return(results)
}

#创建一个iml模型无关对象
components_iml <- Predictor$new(
  model = ensemble_tree, 
  data = features, 
  y = response, 
  predict.fun = pred
)
#计算SHAP值
(shapley <- Shapley$new(components_iml, x.interest =data_new, sample.size =1000))
# |========================================================================================================| 100%
# |========================================================================================================| 100%
# |========================================================================================================| 100%
# |========================================================================================================| 100%
# |========================================================================================================| 100%
# |========================================================================================================| 100%
# Interpretation method:  Shapley 
# Predicted value: 0.980000, Average prediction: 0.703570 (diff = 0.276430) Predicted value: 0.020000, Average prediction: 0.296430 (diff = -0.276430)
# 
# Analysed predictor: 
#   Prediction task: unknown 
# 
# 
# Analysed data:
#   Sampling from data.frame with 338 rows and 101 columns.
# 
# 
# Head of results:
#   feature class           phi     phi.var            feature.value
# 1    A1CF Alive  0.0017600000 0.006776857    A1CF=1.04142313260927
# 2     A2M Alive -0.0016233333 0.001454914     A2M=4.59577575137959
# 3  A4GALT Alive -0.0028233333 0.002947425  A4GALT=1.48430469549071
# 4    AAAS Alive -0.0008516667 0.005436699    AAAS=3.62643913669732
# 5    AACS Alive  0.0077550000 0.007041842    AACS=2.71843944462677
# 6   AADAC Alive  0.0004400000 0.002993781 AADAC=0.0974343320963354
#可视化结果
plot(shapley)

上述过程的计算时间主要取决于预测变量的数量和样本大小。默认情况下,Shapley$new()函数只使用100个样本,但我们可以通过控制参数来减少计算时间或增加估计值的可信性。

由于iml使用R6,我们可以重复使用Shapley对象来解释特征对预测结局的影响。

#继续使用shapley值解释感兴趣的观测值
shapley$explain(x.interest =data_new)

#可视化结果
shapley$results %>%
  top_n(20, wt = abs(phi)) %>%
  ggplot(aes(phi, reorder(feature.value, phi), color = phi > 0)) +
  geom_point(size=3)+
  scale_color_brewer(palette = "Set2")+
  theme_bw()

今天的分享就到这里了。下次我们将分享XGBoost算法中SHAP值的可视化,也是比较经典的可视化,另外我们将分享最后一个iml方法—Localized step-wise procedure。下次分享后我们的R与机器学习系列推文也就基本到尾声了,后面的无监督机器学习算法暂时不做重点分享。后续的分享重点将围绕Python与机器学习系列展开。因为在分享过程中发现R中进行机器学习算法运算时还是比较耗时间、耗内存,尤其是在样本量很大的情况下,就不得不借助Python了。另外,Python才是机器学习的主流。后续我们将从Python基础知识逐渐过渡到Python与机器学习系列。欢迎一起学习!

最后编辑于:2023-08-08 12:23


喜欢的朋友记得点赞、收藏、关注哦!!!

;