暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

ICML 2021 | 针对图神经网络的通用因果解释方法

AISeer 2021-09-22
1164

论文标题 | Generative Causal Explanations for Graph Neural Networks

论文来源 | ICML 2021

论文链接 | https://arxiv.org/abs/2104.06643

源码链接 | https://github.com/wanyu-lin/ICML2021-Gem

TL;DR

论文中提出了一种通用的因果解释方法 Gem,可以为不同图学习任务中的 GNNs 提供通用的可解释性。主要想法是将 GNNs 模型的推理决策可解释性问题转为一个因果学习任务,然后基于格兰杰因果 (Granger causality) 的目标函数来训练一个因果解释模型。Gem 不依赖于 GNNs 的内部结构和图学习任务相关的先验知识,因此泛化性较强而且可以从图结构数据的因果角度来解释 GNNs ~ 保持好奇 🧐。实验部分在生成数据集和真实数据中验证了 Gem 相对其它解释模型而言,可以提升 30% 的解释准确性而且解释速度可达 110×。

Problem Definition

由于 GNNs 模型的可解释性研究是一个新兴的研究领域,当前主流的研究工作如下 「Mark」 🐶

  • 「GNNExplainer」: [NeurIPS 2019] Generating Explanations for Graph Neural Networks

  • 「PGExplainer」[NeurIPS 2020] Parameterized Explainer for Graph Neural Network

  • 「PGM-Explainer」: [NeurIPS 2020] Probabilistic Graphical Model Explanations for Graph Neural Networks

  • 「XGNN」[KDD 2020] Towards Model-Level Explanations of Graph Neural Networks

以上方法都是基于图结构或者加性特征归因 (addittive feature attribution) 的方法来解释 GNNs 的推理结果,泛化性较差而且没有从因果层面来考虑。

而这篇文章从因果模型来解释 GNNs,希望这一篇文章可以解决我多年前对 GNNs 因果推理的疑惑 🤔

首先好奇会以什么形式来解释 GNNs ⁉️

一般的图学习任务定义如下:

给定图集合 ,每个图表示为 ,其节点集合为 ,对应的每个节点特征维度为 。论文中考虑在 Graph-level 和 Node-level 级别的「分类任务」来解释 GNNs。

  • 对于 Graph-level 分类任务的数据集为 ,每个图 其对应的标签为 ,其中 表示类别数量。

  • 对于 Node-level 分类任务的数据集为 ,每个图 中的节点 对应的标签为

论文中使用示例 表示一个「实例」,在 Graph-level 中对应 ,在 Node-level 中对应 。所以下文谈到的实例可以为 Graph or Node。

GNNs 模型可以形式化化地表示如下:

  • Graph-level:

  • Node-level:

对应的目标函数如下, ,其中 表示真实分类, 表示预测输出,标量 表示对应的损失。

模型解释的任务是给定一个预训练模型 ,得到模型 对预训练模型进行快速精确的解释,预训练模型被称为 target GNN。

解释下 GNNs 模型的解释形式:

Intrinsically, an explanation is a subgraph that is the most relevant for a prediction —— the outcome of the target GNN.

个人理解:对于预测模型的预测输出,需要找到原图中的对应的子图来支持这一分类结果,找到子图这一过程由解释模型完成。这个子图作用在于,带着疑问看?

Algorithm/Model

论文中提出的模型如下图所示

Gem 整体架构

主要包括两个模块:

  • Distillation process:对于图中的边进行因果贡献分数计算。
  • Graph generator:根据 distillation 得到的子图监督训练图生成模型。

因果贡献

GNNs 对于一个实例的预测结果主要在于图结构 computation graph ,其中 表示邻接矩阵。

GNN 的目标是学习到一个分类条件分布

一个计算图示例如下,对于节点 的 2-hop。

图示例

给定预训练的 GNNs 模型和对应的实例 ,其对应的分类结果为 ,所以解释模型的任务是找到预测结果 对应的子图

考虑到格兰杰因果的主要思想,可以量化图中边 对 预训练 GNN 模型预测误差 的因果贡献。

这样就可以计算删除边 对模型的因果贡献,其预测误差即为损失函数计算得到的值。计算方式如下

利用删除边的策略来计算因果贡献这种方法论文中将之称为「ground-truth distillation process.」

通过计算图中所有边的因果贡献,可以直接根据贡献分数排序来选择 top-K 最相关的边作为预测解释。但是图数据中的边贡献分布并不是独立的,因此作者使用 「graph rules」 来提高 distillation 过程。这一步就比较玄学了,完全取决于数据特征,但是大部分应该考虑了 top-K distillation edges 应该是连通的。

以上蒸馏过程中得到的子图表示为 可以根据其它模型的输出来进行有监督的,以此来解释其它模型

因果解释模型

原则上任意的图生成模型都可以用作图因果解释模型,论文中用到的是 Graph auto-encoder。

其中 表示 computation graph 的邻接矩阵, 表示每条边对 预测其子图的贡献。

对于目标节点的模型解释输出为 computation graph 的一个 compact subgraph,以此解释一个节点分类为什么得到当前标签。

解释模型直接使用上一步根据因果贡献 top-K 边过滤用到的子图进行训练,损失函数为 RMSE。还使用了 「node labeling technique」 技术来区分不同的节点,这部分不再细述,了解其整体过程就可以了。

Experiments

实验部分使用了人工和真实数据集,解释准确率如下所示

实验结果

其运行时间对比结果如下

实验结果

以一个例子说明输出子图的效果

实验结果

Thoughts

  • 论文中提出的 Gem 模型整体思路清楚而且模型简单易懂,大佬的文章果然写得就是不一样 👍
  • 整体而言是考虑了因果关系,但是 Explainer 非常依赖因果分数计算而且超参数 top-K...
  • Explainer 用了两层 GCN 是不是只能学到 2-hop 内的因果关联,如果超出这个 hop 重构的邻接矩阵会不是不太靠谱呢
  • 回到模型因果可解释性问题,没想到 GNN 可以用 compact subgraph 来解释模型结果。


推荐阅读


  点击「阅读原文」留言评论哦❣️

文章转载自AISeer,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论