49

可解释性论文阅读笔记1-Tree Regularization

 4 years ago
source link: http://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650412195&%3Bidx=2&%3Bsn=5363c291223c99f537fc3ea8a951829f
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

作者: HardenHuang

学校: 清华大学

研究方向: 自然语言处理

知乎专栏: 模型可解释性论文专栏

AAAI2018的一篇关于深度学习模型可解释性的文章,文章的主要亮点是引入tree regularization在训练深度学习网络的同时给出一个具有可解释性能力的决策树

Beyond Sparsity: Tree Regularization of Deep Models for Interpretability arxiv.org

link: https://arxiv.org/pdf/1711.06178.pdf

1. 模型可解释性的重要性

(1)深度模型的进一步应用需要模型可解释性,例如金融、医疗等领域

(2)仿生性模型更容易被人理解和应用,例如决策树模型

2. Model Interpretion Introduction

做模型可解释性有两套思路:一是为已经训练好的模型寻找可解释性;二是训练出更具有解释性的模型

(1)为已经训练好的模型寻找可解释性: 寻找神经网络模型的决策树表示 [1] ,对输入和输入的梯度进行敏感性分析 [2] ,寻找模型的编程化表示 [3] ,寻找模型的规则集合表示 [4]

(2) 训练出更具解释性的模型:惩罚更不相关的特征获得稀疏特征 [5] ,从文本输入中找到高亮部分 [6]

3. Related work

利用了各类正则化:L1正则化方法 [7] ,binary network [8] ,Edge and node regularization [9]

4. 方法的具体描述

key idea :训练深度学习模型同时,获取准确率高与复杂度小的决策树,决策树的复杂度作为正则项;

决策树的复杂度 :利用APL (average path length),训练集中样例做出决策平均经过的节点数,APL即为Tree Regularization。

问题 :L1正则化 L2正则化等正则化项可以由一个函数 计算出来, 为深度模型 的权重,Tree Regularization却不能简单的由 表示

解决方法 :引入一个替代网络 (多层感知机)去近似表示决策树 , 的输入为 训练过程中的权重 , 输出APL的预测值  YjIrYfy.png!web , 为模型的参数,标签为的APL输出 ,  YjIrYfy.png!web 与 的权重产生关联,因此可以用梯度下降方法进行优化

具体训练过程

step1:训练深度学习模型,输入特征 及真实标签 ,输出标签预测值 ,损失函数如下,其中 YjIrYfy.png!web 来自于的输出

step2: 训练决策树模型,输入为特征 和标签预测值 ,输出为 ,训练过程与一般的决策树分裂过程相同

step3:训练替代网络, 输入为各步训练过程中得到的权重,输出为APL的预测值  YjIrYfy.png!web ,标签来自于决策树模型的,损失函数如下,

FRneIfA.png!web

如此重复训练,可以同时得到深度学习模型与具有可解释性的模型 ,如下是一个决策树的示例

j26riuj.jpg!web

5. 实验及结果

语音识别任务

ae6Frqv.jpg!web

脓毒症重症监护(Sepsis Critical Care)

UFbMFvq.jpg!web

艾滋病治疗结果(HIV Therapy Outcome)

FRzQJjn.jpg!web

6. 代码解读

tree-regularization-public github.com

link: https://github.com/dtak/tree-regularization-public

训练过程:可以看到是 R3iUfuz.png!web 的依次训练过程

def train(self, X_train, F_train, y_train, iters_retrain=25, num_iters=1000,
              batch_size=32, lr=1e-3, param_scale=0.01, log_every=10):
        npr.seed(42)
        num_retrains = num_iters // iters_retrain
        for i in xrange(num_retrains):
            self.gru.objective = self.objective
            # carry over weights from last training
            init_weights = self.gru.weights if i > 0 else None
            print('training deep net... [%d/%d], learning rate: %.4f' % (i + 1, num_retrains, lr))
            self.gru.train(X_train, F_train, y_train, num_iters=iters_retrain,
                           batch_size=batch_size, lr=lr, param_scale=param_scale,
                           log_every=log_every, init_weights=init_weights)
            print('building surrogate dataset...')
            W_train = deepcopy(self.gru.saved_weights.T)
            APL_train = self.average_path_length_batch(W_train, X_train, F_train, y_train)
            print('training surrogate net... [%d/%d]' % (i + 1, num_retrains))
            self.mlp.train(W_train[:self.gru.num_weights, :], APL_train, num_iters=3000,
                           lr=1e-3, param_scale=0.1, log_every=250)

        self.pred_fun = self.gru.pred_fun
        self.weights = self.gru.weights
        # save final decision tree
        self.tree = self.gru.fit_tree(self.weights, X_train, F_train, y_train)

        return self.weights

参考

  1. ^train decision trees for pretrained neural network (Craven and Shavlik (1996))

  2. ^Adler, P.; Falk, C.; Friedler, S. A.; Rybeck, G.; Scheidegger, C.; Smith, B.; and Venkatasubramanian, S. 2016. Auditing black-box models for indirect influence. In ICDM

  3. ^Ribeiro, M. T.; Singh, S.; and Guestrin, C. 2016. Why should I trust you?: Explaining the predictions of any classifier. In KDD

  4. ^Lakkaraju, H.; Bach, S. H.; and Leskovec, J. 2016. Interpretable decision sets: A joint framework for description and prediction. In KDD.

  5. ^Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI

  6. ^Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI

  7. ^Zhang, Y.; Lee, J. D.; and Jordan, M. I. 2016. l1-regularized neural networks are improperly learnable in polynomial time. In ICML

  8. ^Tang, W.; Hua, G.; and Wang, L. 2017. How to train a compact binary neural network with high accuracy? In AAAI.

  9. ^Ochiai, T.; Matsuda, S.; Watanabe, H.; and Katagiri, S. 2017. Automatic node selection for deep neural networks using group lasso regularization. In ICASSP

本文由作者授权AINLP原创发布于公众号平台,点击'阅读原文'直达原文链接,欢迎投稿,AI、NLP均可。

原文链接:

https://zhuanlan.zhihu.com/p/99384386

推荐阅读

AINLP年度阅读收藏清单

关系提取简述

知识图谱存储与查询:自然语言记忆模块(NLM)

AINLP-DBC GPU 使用体验指南

BERT论文笔记

XLNet 论文笔记

征稿启示| 让更多的NLPer看到你的文章

AINLP-DBC GPU 云服务器租用平台建立,价格足够便宜

我们建了一个免费的知识星球:AINLP芝麻街,欢迎来玩,期待一个高质量的NLP问答社区

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLP君微信(id:AINLP2),备注工作/研究方向+加群目的。

qIR3Abr.jpg!web


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK