图神经网络模型集合GraphGallery,TensorFLow&PyTorch一并实现

2020 年 10 月 5 日 专知

GraphGallery

【导读】图神经网络(Graph Neural Networks,GNN)是近几年兴起的新的研究热点,其借鉴了传统卷积神经网络等模型的思想,在图结构数据上定义了一种新的神经网络架构。如果作为初入该领域的科研人员,想要快速学习并验证自己的idea,需要花费一定的时间搜集数据集,定义模型的训练测试过程,寻找现有的模型进行比较测试,这无疑是繁琐且不必要的。GraphGallery 为科研人员提供了一个简单方便的框架,用于在一些常用的数据集上快速建立和测试自己的模型,并且与现有的 benchmark 模型进行比较。其支持目前主流的两大机器学习框架:TensorFlow 和 PyTorch,为科研人员提供了一些简易操作的API。

安装

  • 直接从源码安装(可以体验最新版本)

git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
python setup.py install
  • 从 Pypi 安装(可以使用稳定版本)


# -U 表示升级使用最新版本
pip install -U graphgallery

快速上手

1. Dataset

数据集包含两种,一种是领域内划分好的数据集 Planetoid,以及扩展性更强的以 npz格式存储的数据集。

数据集详细信息请见 https://github.com/EdisonLeeeee/GraphData

  • Planetoid


from graphgallery.data import Planetoid
# set `verbose=False` to avoid additional outputs
data = Planetoid('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split() # 使用固定的划分,即 每个类别20个结点作为训练集,剩余结点中选取500个作为验证集,1000个作为测试集
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 2708), labels(2708,))

目前包含 3 种数据集


>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')
  • NPZDataset


from graphgallery.data import NPZDataset;
data = NPZDataset('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split(random_state=42) # 采用 10%,10%,80%的划分
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 2708), labels(2708,))

目前包含 13 种数据集


>>> data.supported_datasets
('citeseer', 'citeseer_full', 'cora', 'cora_ml', 'cora_full',
'amazon_cs', 'amazon_photo', 'coauthor_cs', 'coauthor_phy',
'polblogs', 'pubmed', 'flickr', 'blogcatalog')

定义自己的 npz 数据集

from graphgallery.data import Graph

# Load the adjacency matrix A, attribute matrix X and labels vector y
# A - scipy.sparse.csr_matrix of shape [n_nodes, n_nodes]
# X - scipy.sparse.csr_matrix or np.ndarray of shape [n_nodes, n_atts]
# y - np.ndarray of shape [n_nodes]
...

mydataset = Graph(adj_matrix=A, attr_matrix=X, labels=y)
# save dataset
mydataset.to_npz('path/to/mydataset.npz')
# load dataset
mydataset = Graph.from_npz('path/to/mydataset.npz')

2. Config

GraphGallery 支持 TensorFlow 和 PyTorch 两个后端(默认TensorFlow 后端),通过切换后端可以调用不同的API和模型


>>> from graphgallery import backend, set_backend
>>> backend()
TensorFlow 2.1.2 Backend

>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend

>>> set_backend('tf') # tensorflow or tf
TensorFlow 2.1.2 Backend

同时,支持定义运算过程中的张量 浮点数和整数类型


>>> from graphgallery import intx, floatx, set_intx, set_floatx
>>> intx() # TensorFlow 后端整数默认 int32, PyTorch后端默认 int64
>>> floatx() # 对于两个后端浮点数默认皆为 float32

# 修改默认数据类型
>>> set_intx('int64')
>>> set_floatx('float64')


3. Tensor

GraphGallery 支持将任意输入转换为合适后端的张量(并给予合适的数据类型)

  • 普通张量


>>> backend()
TensorFlow 2.1.2 Backend

>>> from graphgallery import transforms as T
>>> arr = [1, 2, 3]
>>> T.astensor(arr)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
  • 稀疏张量


>>> import scipy.sparse as sp
>>> sp_matrix = sp.eye(3) # 创建一个 3X3 的单位矩阵
>>> T.astensor(sp_matrix)
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f1bbc205dd8>

类似的,只需要切换后端,亦可将输入转换为 PyTorch 张量


>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend

>>> T.astensor(arr)
tensor([1, 2, 3])

>>> T.astensor(sp_matrix)
tensor(indices=tensor([[0, 1, 2],
                      [0, 1, 2]]),
      values=tensor([1., 1., 1.]),
      size=(3, 3), nnz=3, layout=torch.sparse_coo)

astensor 函数接收三个参数,

  • x : 需要转化的Python对象

  • dtype: 转化的类型,若不指定则根据后端的 intx(), floatx() 函数推断

  • devicie: 参数所在的设备 (可以指定"CPU", "GPU", "cuda", "GPU:0" 等等),若不指定则为 "CPU:0"

  • kind: 转化成何种张量,"T" 表示 TensorFlow 张量,"P" 表示 PyTorch 张量,若不指定则模型转为当前后端适合的张量

4. Transforms

GraphGallery  的 transforms 模块包含各种对输入数据的变换操作,例如针对(稀疏)邻接矩阵的变换,(密集)特征矩阵的变换,以及包含上节所述的张量转换。

例如对稀疏邻接矩阵(adjacency matrix)做 GCN 常见的归一化操作


>>> from graphgallery import transforms as T
>>> T.normalize_adj(adj_matrix)

其默认实现了

以及对结点特征矩阵(Attribute matrix)做行归一化


>>> from graphgallery import transforms as T
>>> T.normalize_attr(attr_matrix)

其默认实现了

5. Models

顾名思义,GraphGallery 是一个GNN模型的 Gallery。

GraphGallery 实现了一系列的半监督结点分类模型,具体可见项目主页:https://github.com/EdisonLeeeee/GraphGallery

以最常见的GCN模型为例


from graphgallery.nn.models import GCN
model = GCN(graph, adj_transform='normalize_adj', attr_transform='normalize_attr', device="GPU", seed=123)
model.build()
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
loss, accuracy = model.test(idx_test, verbose=1)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
  • graph 是输入的图,adj_transform 是对邻接矩阵的变换,attr_transform 是对结点特征矩阵的变换,并且可以指定运行设备 device 和用于重现结果的随机种子 seed

  • 模型调用 build 快速搭建一个 GCN 模型,build 可以指定包含隐藏层单元个数(层数),激活函数,学习率等参数


# 一层隐藏层 (32单元),激活函数 RELU
>>> model.build(hiddens=32, activations='relu')

# 两层隐藏层(32和64单元),两层的激活函数都是 RELU
>>> model.build(hiddens=[32, 64], activations='relu')

# 两层隐藏层 (32和64单元),激活函数分别是 RELU 和 ELU
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
  • 模型调用 train 方法进行训练。idx_train 是训练集结点,同理 idx_val是验证集结点(也可以不指定),verbose 可以指定 0, 1, 2, 3, 4 五种训练过程输出,返回的 his 是 一个记录训练历史情况的类,可以通过调用 his.history 查看训练过程的输出。

  • 模型调用 test 方法进行测试,idx_test 是测试集结点,verbose 可指定 0 和1两种,最终返回 测试集的损失和准确率

Planetoid Cora 数据集上的结果


Training...
100/100 [==============================] - 1s 14ms/step - loss: 1.0161 - acc: 0.9500 - val_loss: 1.4101 - val_acc: 0.7740 - time: 1.4180
Testing...
1/1 [==============================] - 0s 62ms/step - test_loss: 1.4123 - test_acc: 0.8120 - time: 0.0620
Test loss 1.4123, Test accuracy 81.20%


至此,只需要几行代码即可完成对一个模型的调用和训练测试,并且当你切换不同的后端,调用的是不同后端实现的模型(甚至不需要更改上述调用代码)。

后续工作

  • 实现更多的 GNN 模型(两种后端)

  • 支持更多的任务(目前主要支持半监督的结点分类任务),未来会加入链路预测,图分类等任务

  • 支持更多样的图数据结构(目前只支持单一无向同构图),未来会考虑异构图,多图

  • 为项目提供更好的项目文档和注释(完善中...)

GraphGallery 项目主页:https://github.com/EdisonLeeeee/GraphGallery

GraphData 项目主页:https://github.com/EdisonLeeeee/GraphData


专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎注册登录专知www.zhuanzhi.ai,获取5000+AI主题干货知识资料!
欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程资料和与专家交流咨询
点击“ 阅读原文 ”,了解使用 专知 ,查看获取5000+AI主题知识资源
登录查看更多
19

相关内容

数据集,又称为资料集、数据集合或资料集合,是一种由数据所组成的集合。
Data set(或dataset)是一个数据的集合,通常以表格形式出现。每一列代表一个特定变量。每一行都对应于某一成员的数据集的问题。它列出的价值观为每一个变量,如身高和体重的一个物体或价值的随机数。每个数值被称为数据资料。对应于行数,该数据集的数据可能包括一个或多个成员。
【KDD2020】 解决基于图神经网络的会话推荐中的信息损失
专知会员服务
31+阅读 · 2020年10月29日
【KDD2020】图深度学习:基础、进展与应用,182页ppt
专知会员服务
133+阅读 · 2020年8月30日
专知会员服务
132+阅读 · 2020年8月24日
一份简单《图神经网络》教程,28页ppt
专知会员服务
120+阅读 · 2020年8月2日
专知会员服务
118+阅读 · 2020年7月22日
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
126+阅读 · 2020年3月15日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
68+阅读 · 2020年1月17日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
下载 | 最全中文文本分类模型库,上手即用
机器学习算法与Python学习
30+阅读 · 2019年10月17日
用 TensorFlow hub 在 Keras 中做 ELMo 嵌入
AI研习社
5+阅读 · 2019年5月12日
Github热门图深度学习(GraphDL)源码与框架
新智元
21+阅读 · 2019年3月19日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
Github 项目推荐 | 用 TensorFlow 实现的模型集合
AI研习社
5+阅读 · 2018年2月14日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
GitHub上大热的Deep Photo终于有TensorFlow版了!
量子位
4+阅读 · 2017年8月14日
Arxiv
3+阅读 · 2020年11月28日
Arxiv
0+阅读 · 2020年11月26日
Arxiv
27+阅读 · 2020年6月19日
Arxiv
7+阅读 · 2018年6月1日
Arxiv
3+阅读 · 2018年6月1日
Arxiv
9+阅读 · 2018年2月4日
VIP会员
相关VIP内容
【KDD2020】 解决基于图神经网络的会话推荐中的信息损失
专知会员服务
31+阅读 · 2020年10月29日
【KDD2020】图深度学习:基础、进展与应用,182页ppt
专知会员服务
133+阅读 · 2020年8月30日
专知会员服务
132+阅读 · 2020年8月24日
一份简单《图神经网络》教程,28页ppt
专知会员服务
120+阅读 · 2020年8月2日
专知会员服务
118+阅读 · 2020年7月22日
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
126+阅读 · 2020年3月15日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
68+阅读 · 2020年1月17日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
相关资讯
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
下载 | 最全中文文本分类模型库,上手即用
机器学习算法与Python学习
30+阅读 · 2019年10月17日
用 TensorFlow hub 在 Keras 中做 ELMo 嵌入
AI研习社
5+阅读 · 2019年5月12日
Github热门图深度学习(GraphDL)源码与框架
新智元
21+阅读 · 2019年3月19日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
Github 项目推荐 | 用 TensorFlow 实现的模型集合
AI研习社
5+阅读 · 2018年2月14日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
GitHub上大热的Deep Photo终于有TensorFlow版了!
量子位
4+阅读 · 2017年8月14日
相关论文
Arxiv
3+阅读 · 2020年11月28日
Arxiv
0+阅读 · 2020年11月26日
Arxiv
27+阅读 · 2020年6月19日
Arxiv
7+阅读 · 2018年6月1日
Arxiv
3+阅读 · 2018年6月1日
Arxiv
9+阅读 · 2018年2月4日
Top
微信扫码咨询专知VIP会员