暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

高效的ShapleyValue计算 - TreeShap

RandomGenerator 2021-06-11
2304

Shap 运作方式

Shap 作为一个模型解释工具大家应该都已经很熟悉了。如果之前没有用过,可以快速看一下官方的使用示例:这里的 notebooks[1]。它可以实现通过 post-hoc 模型解释,展示预测条目中每个特征的具体贡献度是多少,相比简单的 feature importance,能做到 local 级别的量化解释。

Local 预测解释

其背后的原理来自于 博弈论中的 shapley value[2],对于如何应用于模型解释方面的详细介绍,可以参考 Christoph 的《Interpretable Machine Learning[3]》中的 5.9, 5.10 两章。这里简要说明一下,标准的 shap 运行逻辑大致如下:

我们有一个黑盒模型 m,有一条需要预测解释的条目 x,由 a, b, c 三个特征组成。我们如何获取到特征 a 在最终预测中的贡献度呢?首先引入一个概念,特征的缺失和引入。要了解特征贡献,一个直观的想法就是看某个特征在与不在对预测结果造成的影响。在 shap 里,特征“缺失”,一般的方法是在所有样本中采样,去看这个特征如果随机选择一个值,模型输出如何。特征的“引入”,则表示这个特征使用了预测条目中的值,模型的输出如何。

接下来一个概念是看特征在不同阶段引入的边际贡献。也就是说 a, b, c 三个特征,a 特征的贡献,要分别计算 a 作为第一个特征引入,作为第二个特征引入,作为第三个特征引入的边际贡献是多少。我们用 a 代表特征引入,a'代表特征缺失,具体计算如下:

  1. 第一个特征引入,使用 m(a, b', c') - m(a', b', c') 得到边际贡献。
  2. 第二个特征引入,使用 1/2 * (m(a, b, c') - m(a', b, c')) + 1/2 * (m(a, b', c) - m(a', b', c)),得到边际贡献。可以看到作为第二个特征引入时有两种情况,所以这里分别乘以 1/2。
  3. 第三个特征引入,使用 m(a, b, c) - m(a', b, c) 得到边际贡献。
  4. 最后再把三种情况分别乘以 1/3 加和,就是特征 a 的总体贡献。

要计算其它特征的贡献度,对 b 和 c 做同样的计算就可以。这里 m(a', b', c') 就是传说中的** base value**,加上三个特征的贡献度,正好得到最终的预测值。这也是 shap 的一个非常优良的特性,可加性。Shap 在理论上还有一些其它优点,感兴趣的同学可以参考 shap 的论文[4]

TreeShap 的效率提升

背景资料

从上面的计算过程可以看到,shap 可以作用于任何黑盒模型,但是计算复杂度非常高。首先特征在不同位置引入这个遍历,会达到指数级的复杂度。另外在估计特征缺失时,进行多次采样也需要多次触发模型预测过程。所以就有一些利用模型本身的特性来加速计算 shap 的方法。今天我们主要来看 TreeShap 的效率提升方法。相关参考资料如下:

TreeShap 加速的论文也来自 shap 库的作者,具体可以参考这里:https://arxiv.org/pdf/1706.06060.pdf

Shap 库中调用的部分,注意解释运行时也有不同的模式,作者上述文章中的实现对应的是 tree path dependent 模式:https://github.com/slundberg/shap/blob/master/shap/explainers/_tree.py

我们以 lgb 为例,可以看到库中只是简单的提交了一个 original_model.predict(..., pred_contrib=True),所以主要代码还是在 lgb 库中:https://github.com/microsoft/LightGBM/blob/master/src/io/tree.cpp

主要看 TreeShap 这个方法。注意这个方法的实现,其实也是这位 shap 库的作者,人家还是很硬核的,shap 库的可视化做的非常漂亮,写 cpp 一样也是信手拈来,还发了几篇很有影响力的论文。一般学术成果要推广,能实现**高性能,高质量且易用的 library **还是非常关键的一点。不过嘛,shap 库里的 Python 代码质量还是有不少提升空间的……

利用树结构评估特征贡献

从前面 shap 的运行方式来看,我们在计算过程中,分别用模型预测了 8 种不同的的特征组合采样,来看特征边际贡献值。但树模型是如何进行判断和预测的方式,我们是非常明确的。在每个节点,树模型都会选择一个特征进行分裂,如果这个节点选择的特征,是“缺失”状态,那么对于这个特征的贡献度,我们可以理解为是左右两棵子树贡献度之和。如果这个节点选择的特征是“引入”状态,那么我们可以看这个特征值下我们会走到左节点还是右节点,贡献度完全由那个节点产生。这样我们可以从特征组合的黑盒角度转变为从树的结构的角度去计算各个特征的贡献度了,跟计算 feature importance 的想法类似。

这就是论文中提到的 Algorithm 1 的内容:

红框中的符号写错了

这里的逻辑有些复杂,为了便于理解,我们把问题做简化,假设只有一个特征 a,一共有两条数据,分别是 a=0, label=0 和 a=1, label=1。然后我们构建一棵决策树,大致长这样:

单棵决策树模型

图中也顺带列举了一下公式中的符号代表的含义:

  • d 代表当前节点选择分裂的特征的 index,这里只有一个特征,所以肯定选中 index 0
  • t 代表分裂的阈值,小于等于则分到左节点,反之右节点
  • r 代表节点的数据量,根节点有 2 个,叶子节点则都只有 1 个
  • a 和 b 是数据 index 的集合,这里比较简单,左边节点就是 [0],右边节点就是 [1],本质是希望能区分出来这个节点包含了哪些数据点
  • v 是数据点的 label 值,加上下标 j 就是取一下在当前节点所有数据点的 label 均值,这里左节点均值为 0,右节点均值为 1

从公式看整体的流程:在算特征贡献度的时候,如果是叶子节点,则用的是 w * vj,这个 vj 就是叶子节点的预测值(即上面的 label 均值),w 是通过上面的 r 算出来的数据量占比,代码中对应:

const double w = data_count(node);
// 先可以不用理会这个 hot, cold fraction
const double hot_zero_fraction = data_count(hot_index) / w;
const double cold_zero_fraction = data_count(cold_index) / w;

如果是中间节点,先判断这个 dj(split 的特征索引)是不是属于 S(特征是否引入的集合):

  • 如果属于 S,就按照这个节点的 threshold 去分裂,分到左边,则采用左边节点 (aj) 的贡献度,分到右边,则采用右边节点 (bj) 的贡献度。
  • 如果不属于,则更新左右两边的 w,递归算完后加起来。这里公式写错了,右边应该是跟上面一样,换成 bj。

这么看可能还是有些云里雾里,我们直接拿这棵树来算一下。因为只有一个特征,所以这个 S 只有两种可能,即:

  • 没有任何一个特征引入:{}
  • 引入特征 a:{a}

当没有一个特征引入时,我们会去分别计算左右两个子节点的贡献度,分别是 1 * 1/2 ,和 0 * 1/2,加起来就是 0.5,得到了我们的 base value。

当引入特征 a 时,对于第一个数据点 (0, 0),会计算左边路径的贡献度,跟上面的区别是因为“激活”了这个特征,所以这个数据量权重 w 是不做更新的,最终结果为 0 * 1 为 0。所以对于第一个数据点,特征 a 的最终贡献度算出来为 0 - 0.5 为-0.5。同理对于第二个数据点 (1, 1),同样的算法得到特征 a 的贡献度为 1 * 1 - 0.5 为 0.5。是不是很 make sense!

利用 algorithm 1,我们去掉了原版计算 shap 过程中对缺失特征做的“采样”处理,提升了一部分效率。

利用遍历树进一步加速

有了 algorithm 1 之后,我们可以把不同特征组合方式输入进去来求得各个特征的贡献度,但特征组合仍然是 2^M 的指数级复杂度。这里一个很自然的想法就是我们不需要遍历特征组合情况,而只要遍历树的路径的所有可能即可,在遍历过程中把计算的 path 信息记录下来,然后在叶子节点就能计算出这条 path 的特征贡献信息。这就是作者在文中提出的 Algorithm 2 的大致思路。

如果想要一个直观理解,还是以上面两个数据点的例子来看,假设我们的特征不是 1 个,而是有 10 个:(f0, f1, ..., f9) -> 0。但是树的结构仍然是上面那样,只使用特征 0 做了一次分裂就结束了。如果以标准的计算方法,对于每一条样本,我们都要计算特征 0 在各个位置引入时的贡献度,特征 1 在各个位置引入时的贡献度,以此类推,起码要计算 2^10 种特征是否引入的可能性,才能最终推算出每个特征的贡献度。

但是如果我们以树的结构来看,其实特征 1-9 是否引入,在这棵树中都不会有任何决策的变化,所以我们其实根本不需要计算那些情况。这就是所谓的“按照树结构”来进行遍历,可以大大降低计算复杂度。

如果只想了解 intuition,这段看到这里就可以结束了。下面是结合代码公式做的一些具体流程分析,略有些复杂。

对于某个节点 x,使用了特征 a 来做分裂条件,那么在经过这个节点时,不管特征组合如何变化,只可能有特征 a 的存在缺失两种情况,这也是对应到代码中的 one pathzero path 的 fraction。

phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];

叶子节点的这个计算,就对应上了原版 shap 中的红框部分:

特征贡献度公式

作者代码里还有 hot, cold index,看着有点 confuse,其实就是根据特征决策走的哪个节点,true 就是 hot,否则就是 cold,这会跟 Algorithm 1 里的 aj, bj 部分对应上。

const int hot_index = Decision(feature_values[split_feature_[node]], node);
const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);

论文中 Algorithm 2 的伪代码的流程还挺复杂的,核心是这一块:

Algorithm 2 中核心的递归逻辑

其中 extend 就是去增长 path,更新 zero, one fraction 等信息,还比较直观。Unwind 这个是判断 path 中是否已经对这个 feature 做过 split 了,如果是的话就 undo 掉,在这个节点再做 split,我个人理解是对 path 中保留的 feature 信息做了合并,也就是路径中可能对 a 特征有多次的使用,但最后在看贡献度时,我们会把多次操作的 zero, one fraction 信息合并在一起。这样复杂度会从节点数的平方(所有可能的 path) 降低到树的深度的平方 (extend/unwind,recurse 的复杂度都为 O(d))。这块的详细代码比较复杂,我们这边就不做深入展开,只需要了解大致的思路即可。有兴趣的同学可以结合代码和 这篇文章[5] 深入分析。

其它模型实现

Shap 库中还有很多其它模型的优化实现,比如对于深度神经网络,有利用 gradients 特性的 DeepExplainer,对于黑盒模型,也有优化了运算速度的 KernelExplainer 等,有兴趣的同学可以进一步深入学习。

另外 Shapley 的应用也不局限于模型解释,例如在滴滴拼车这类 cost sharing 问题上,也能见到它的身影。博弈论相关的思想,对于解决很多实际商业问题都很有借鉴意义,值得深入了解。

问题

这个方法看上去很美好,但实际看现在的 shap 库的运行时的默认参数,用的是 feature_perturbation='interventional'
模式,这种模式下必须传入 backgroupd data 才行。具体的原因可以看这个 github issue[6],做了很好的总结。

简单来说,我们在文中分析的方法,计算的是通过条件概率 E[f(x)|x_s] 的形式去获取到特征的贡献度,从而可能导致一些模型并没有 explicitly 使用的特征被赋予一些贡献度(参考 这篇文章[7] 中的例子)。而一些批评者认为从因果模型角度看,我们更应该使用 E[f(x)|do(X = x_s)] 的形式来构建 shap 解释的过程。所以目前的默认模式更符合这种因果的观点,从 background data 去做特征数值的采样达到 do(X = x_s) 的效果。

另外一个观点是 true to the data 和 true to the model 的 trade-off,如作者所说:

I view it as a fundamental trade-off between being "true to the data" and never providing inputs that are off-manifold, and being "true to the model" and never letting credit bleed between correlated features.

这篇论文给了一些例子,讲解我们如何在这两种模式中做出选择:https://arxiv.org/pdf/2006.16234.pdf

因果模型也是个非常有意思的话题,之后我们有时间再做进一步的分析讲解 :)

参考资料

[1]

notebooks: https://github.com/slundberg/shap/blob/master/notebooks/tabular_examples/tree_based_models

[2]

博弈论中的 shapley value: https://en.wikipedia.org/wiki/Shapley_value

[3]

Interpretable Machine Learning: https://christophm.github.io/interpretable-ml-book/shapley.html

[4]

shap 的论文: https://dl.acm.org/doi/pdf/10.5555/3295222.3295230

[5]

这篇文章: https://arxiv.org/pdf/1905.04610.pdf

[6]

github issue: https://github.com/slundberg/shap/issues/1098

[7]

这篇文章: https://arxiv.org/pdf/1910.13413.pdf


文章转载自RandomGenerator,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论