三种方法实现GNN鲁棒的中值聚合过程

三种方法实现GNN鲁棒的中值聚合过程

在前面的文章中,介绍了我们的工作:

文章提出了一个鲁棒的中值聚合函数来提高GNN的鲁棒性,那么如何实现这一自定义的聚合函数呢?本文提供了三种方法实现,包括直接利用PyTorch,以及使用PyG和DGL框架进行实现。

1初始化数据

首先,我们先初始化一个简单的图数据

import networkx as nx
num_nodes = 5
k = num_nodes // 2 + 1
G = nx.newman_watts_strogatz_graph(num_nodes, k, 0.1)

这是一个简单的带有5个节点的图,然后我们可以利用matplotlib画出这个图的形状

nx.draw(G, with_labels=True, font_weight='bold')

在GNN中,假设每个节点都有一个初始化的特征向量,这里我们使用随机初始化

num_feats = 2
feat = torch.randn(num_nodes, num_feats)

准备就绪,接下来开始实现基于中值聚合函数的GNN。

2直接利用PyTorch实现

通常我们使用的GNN聚合方式为:

h_{v}^{(k)}=MEAN^{(k)}\left(\left\{h_{u}^{(k-1)} \mid u \in \mathcal{N}(v)\right\}\right) \\

这里的MEAN也可以换成SUM。为简单起见,这里不考虑各个节点聚合的权重。

而本文中我们需要实现的基于中值聚合函数的GNN聚合过程可以表示为如下:

h_{v}^{(k)}=MEDIAN^{(k)}\left(\left\{h_{u}^{(k-1)} \mid u \in \mathcal{N}(v)\right\}\right) \\

虽然只是将聚合函数从MEAN换成了MEDIAN,但实现起来却复杂得多,因为MEDIAN涉及到了排序的操作。一个简单的思路是直接照着公式实现。

  1. 转化成稀疏矩阵形式并添加自环
import scipy.sparse as sp
adj_matrix = nx.to_scipy_sparse_matrix(G)
adj_matrix += sp.eye(adj_matrix.shape[0], format='csr') 

adj_matrix是一个稀疏表示的邻接矩阵,代表着节点间的连接关系,我们可以输出它的密集二维矩阵观察:

>>> adj_matrix.A
array([[1., 1., 0., 1., 1.],
       [1., 1., 1., 0., 0.],
       [0., 1., 1., 1., 0.],
       [1., 0., 1., 1., 1.],
       [1., 0., 0., 1., 1.]])
  1. 将邻接矩阵转化成一系列张量:
neighbors = [torch.as_tensor(row) for row in adj_matrix.tolil().rows]

neighbors是一个List,里面是代表着每个节点对应的邻居节点的下标:

>>> neighbors
[tensor([0, 1, 3, 4]),
 tensor([0, 1, 2]),
 tensor([1, 2, 3]),
 tensor([0, 2, 3, 4]),
 tensor([0, 3, 4])]
  1. 实现中值聚合过程
aggregation = []
for nbr in neighbors:
    message = torch.median(feat[nbr], dim=0).values
    aggregation.append(message)
h = torch.stack(aggregation)

这里的实现过程是比较简单直接的,就是遍历每个节点,然后采用中值函数去计算它的邻居消息,最终将所有节点的表征拼接起来得到最后的整个表示。这里的h是一个(num_nodes, num_feats)大小的矩阵,代表着最终的聚合结果:

>>> h
tensor([[-0.0803, -0.3828],
        [ 0.2315, -0.3140],
        [ 0.2315, -0.3140],
        [-0.0803, -0.3828],
        [-0.0803, -0.3828]])

3PyG实现

上述方法的实现虽然看起来简单明了,但是实现效率却偏低。这里我们采用PyG来实现这一中值聚合过程。在PyG中目前支持的聚合函数是mean, sun, max, min,并不支持中值聚合。因此,我们需要自定义这个实现,实现原理是: 提取度大小一样的节点,组成多个块,每个块的大小为[M, D, num_feat],其中D为度的大小,M为度为D的节点的数量,对于这个规则的块(或者矩阵),我们就可以直接运用torch.median函数了:

from torch_geometric.utils import to_dense_batch

edge_index = torch.as_tensor(adj_matrix.nonzero()).long()

row, col = edge_index

x_j = feat[col]
# 注意:to_dense_batch要求下标row按顺序排列否则会出错
dense_x, mask = to_dense_batch(x_j, row)
h = x_j.new_zeros(dense_x.size(0), dense_x.size(-1))
deg = mask.sum(dim=1)
for i in deg.unique():
    deg_mask = deg == i
    h[deg_mask] = dense_x[deg_mask, :i].median(dim=1).values

最后的结果也是与上述的实现一样,效率却更高:

>>> h
tensor([[-0.0803, -0.3828],
        [ 0.2315, -0.3140],
        [ 0.2315, -0.3140],
        [-0.0803, -0.3828],
        [-0.0803, -0.3828]])

4DGL实现

虽然上述PyG的实现比纯PyTorch实现快了不少,但是仍然不够优雅,尤其是将度一样的节点提取出来。幸运的是,DGL帮我们实现了这一步骤,我们只需要修改reduce过程即可。

首选转换成DGL图,并赋予节点特征:

g = dgl.from_scipy(adj_matrix)
g.ndata['h'] = feat
g

接着,只需要修改聚合函数为自定义的median_reduce即可:

import torch
import dgl.function as fn

def median_reduce(nodes):
    return {'h': torch.median(nodes.mailbox['m'], 
                              dim=1).values}

aggregate_fn = fn.copy_src('h', 'm')
reduce_fn = median_reduce
g.update_all(aggregate_fn, reduce_fn)
h = g.ndata['h']

最后的结果也是与上述的实现一样,效率却更高并且更加简洁:

>>> h
tensor([[-0.0803, -0.3828],
        [ 0.2315, -0.3140],
        [ 0.2315, -0.3140],
        [-0.0803, -0.3828],
        [-0.0803, -0.3828]])

5总结

上述提供了三种实现自定义中值聚合的方法,从最简单的PyTorch实现,到PyG和DGL实现,实现过程逐渐简化,效率逐渐提高。然而,必须说明的是,尽管效率提高了不少,但是运算效率还是远远比不上传统的mean, sun, max, min聚合方式。

上述实现为了简洁说明,并没有加入线性层转换,完整的实现可以参考GraphGallery中的实现:

发布于 2021-11-27 22:28