8

【ICLR 2018图神经网络论文解读】Graph Attention Networks (GAT) 图注意力模型

 2 years ago
source link: https://weisenhui.top/posts/61610.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client
【ICLR 2018图神经网络论文解读】Graph Attention Networks (GAT) 图注意力模型

问题:我们能不能让图自己去学习A节点与A的邻居节点之间聚合信息的权重呢?
本文提出的模型GAT就是答案

Graph Attention Network为了避免与GAN弄混,因此缩写为GAT。

与GCN类似,GAT同样是一种局部网络。因此,(相比于GNN或GGNN等网络)训练GAT模型无需了解整个图结构,只需知道每个节点的邻节点即可。

GCN结合邻近节点特征的方式和图的结构依依相关,这也给GCN带来了几个问题:无法完成inductive任务,即处理动态图问题。

而GAT则能够很好地处理上述问题。

GAT模型简介

(1)输入:$\mathbf{h}=\left{\vec{h}{1}, \vec{h}{2}, \ldots, \vec{h}{N}\right}, \vec{h}{i} \in \mathbb{R}^{F},N表示节点个数,F表示每个节点的特征数量。(2)经过GAT(3)输出:,N表示节点个数,F表示每个节点的特征数量。(2)经过GAT(3)输出:\mathbf{h}^{\prime}=\left{\vec{h}{1}^{\prime}, \vec{h}{2}^{\prime}, \ldots, \vec{h}{N}^{\prime}\right}, \vec{h}{i}^{\prime} \in \mathbb{R}^{F^{\prime}}$

GAT这个模型的结构比较简单,基本看一下他的公式和论文里的图,就能懂了

为了提升模型的表达能力和训练稳定性,论文中还借鉴了Transformer模型结构中的多头注意力层,也就是用kk个WkWk得到不同的注意力,再将其拼接在一起或者求平均。

这篇GAT论文的模型不是很难,相比模型,我们可以来思考下几个容易忽视的问题

1. 为什么要研究图神经网络模型?

一句话:图神经网络是研究怎么通过一个图结构抽取出更好的节点特征

当我们有一堆节点的特征x和每个节点对应的标签y的时候,我们是可以直接当做一个结构化的数据,把他放到神经网络去训练模型的,就像MNIST手写识别任务一样。

但是我们现在的数据之间是有关系的,他们的关系可以用图来表示,那么我是不是可以利用这个关系图来帮助我训练呢?显然是可以的,这也是图神经网络要做的。

举个例子,现在我要做一个用户身份识别的任务,最简单的做法自然是直接使用用户特征和用户的标签去训练一个分类模型就行了。但是你拿的用户特征是非常稀疏的,很难准确地表示用户的信息。如果现在我还告诉你用户的社交关系(一个关系图),那你就可以利用这个社交关系去构建更好的用户特征,比如用户A的朋友都是有钱人,那用户A大概率也是有钱人,然后再放到模型中去训练,那效果就会好很多。

简单来说,图神经网络模型,比如GCN,GAT就是把节点特征x放到图模型里面去,然后得到一个更好的节点特征xbetterxbetter,此时我们再把(xbetter,y)(xbetter,y),放到一个神经网络中去进行分类任务,效果就会比(x,y)(x,y)要好

2. Transduction Learning和Inductive Learning的区别

引用:归纳式和直推式学习(Inductive vs. Transductive Learning)

归纳式(Inductive)

  • 归纳式学习是我们传统理解的监督学习(supervised learning),我们基于已经打标的训练数据,训练一个机器学习模型。然后我们用这个模型去预测我们没有从未见过的测试数据集上的标签。

直推式(Transduction)

  • 和归纳式不同,直推式学习首先观察全部数据,包括了训练和测试数据,我们从观测到的训练数据集上进行学习,然后在测试集上做预测。即便如此,我们不知道测试数据集上的标签,我们可以在训练时利用数据集上的模式和额外信息

两者区别:直推式学习不会建立一个预测模型,如果一个新数据节点加入测试数据集,那么我们需要重新训练模型,然后再预测标签。然而,归纳式学习会建立一个预测模型,遇到新数据节点不需要重新运行算法训练。


我的理解:

  • Transductive learning(意为直推学习)就是在训练阶段已经看到过测试集节点的一些信息,比如在GCN中就是节点的结构信息已经被看到了,然后在测试阶段你只能处理这些测试集,不能处理新的节点,因为这些节点的结构信息你训练阶段没见过,模型就懵了。换句话说,测试集只是没有用标签的信息来帮助训练共享的参数W(训练集的节点和标签会帮助我训练W),但是图的结构必须包含测试集的节点
  • Inductive learning(意为归纳学习),它在训练阶段完全不需要用到测试集的信息,也就是说你测试集不管是训练的时候见过的(利用过某些信息的),或者没见过的,我都无所谓。

3. GAT为什么可以是Inductive Learning

GAT的论文中提到GAT这个model也可以解决Inductive learning的问题

GAT中注意力系数的计算公式:

多头GAT的Average聚合公式:

GAT在训练阶段学习参数a和W,然后在预测阶段,对于一个新的图结构,我虽然没在训练阶段见过这个图结构,但是我只要知道节点i的表征hihi和它邻居的表征hjhj,我就能通过学习好的参数a和W计算出聚合后节点i的表征h′ihi′。所以说GAT不依赖与完整的图结构,只依赖于边,因此可以用于inductive任务

而GCN是transductive learning,不是inductive learning,因为他训练模型后,如果直接应用于未知的图结构,那拉普拉斯矩阵就会发生改变,以前训练好的基于原图的模型也就失效了。

GCN需要利用图的结构来提取节点ii(原特征为hihi)更好的表征h′ihi′,你现在如果来一个新的节点,我之前训练的时候没见过它的图结构,那也就无法用GCN提取它的表征h′ihi′,进而无法去做后续的节点标签预测了

两层GCN模型:

其中N表示结点个数,C表示每个节点的特征向量维度,F表示经过两层GCN后,每个结点的新的特征向量的维度,W(0),W(1)W(0),W(1)是需要学习的参数

我们可以看一下GCN是怎么划分训练集和测试集的:

dataset = CoraData().data
node_feature = dataset.x / dataset.x.sum(1, keepdims=True)
tensor_x = tensor_from_numpy(node_feature, DEVICE)
tensor_y = tensor_from_numpy(dataset.y, DEVICE)
tensor_train_mask = tensor_from_numpy(dataset.train_mask, DEVICE)
tensor_val_mask = tensor_from_numpy(dataset.val_mask, DEVICE)
tensor_test_mask = tensor_from_numpy(dataset.test_mask, DEVICE)
normalize_adjacency = CoraData.normalization(dataset.adjacency)   # 规范化邻接矩阵

num_nodes, input_dim = node_feature.shape #(N,D)

indices = torch.from_numpy(np.asarray([normalize_adjacency.row, 
                                       normalize_adjacency.col]).astype('int64')).long()
values = torch.from_numpy(normalize_adjacency.data.astype(np.float32))
tensor_adjacency = torch.sparse.FloatTensor(indices, values, 
                                            (num_nodes, num_nodes)).to(DEVICE)

可以看到它是把所有数据放在一起,然后得到一个邻接矩阵adjacency,然后再通过mask的方式获得train, val, test datasets,这就是我上面说的,GCN需要利用到测试集上的一些额外信息(构建邻接矩阵)


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK