图神经网络GNN:让模型学会理解关系数据
前面学习线性回归、逻辑回归、决策树、GBDT 这类模型时,我们面对的数据通常是表格、向量或图片。每个样本大多可以单独看待。
但现实中还有很多数据天然带有“关系”:
- 社交网络中,用户之间有好友关系。
- 推荐系统中,用户、商品、点击、购买构成交互图。
- 分子结构中,原子是节点,化学键是边。
- 交通网络中,路口是节点,道路是边。
- 知识图谱中,实体和实体之间有语义关系。
这些数据的重点不只在于“单个对象有什么特征”,还在于“对象之间如何连接”。图神经网络(Graph Neural Network,GNN)就是专门用来处理这类图结构数据的神经网络。
一句话概括:
GNN 让每个节点不断接收邻居的信息,更新自己的表示,从而把节点特征和图结构一起编码进向量里。
1. 为什么普通神经网络不够用
假设我们要判断一个社交网络用户是否可能对某个话题感兴趣。
如果只看用户自己的特征,例如年龄、地区、历史点击,当然能得到一些信息。
但在社交网络里,用户的邻居也很重要:
- 他的朋友是否都关注这个话题?
- 他是否处在某个兴趣社区中?
- 他和哪些关键用户有连接?
- 他在网络中的位置是否特殊?
普通全连接神经网络通常假设输入是固定长度向量,不能直接表达“这个节点连着哪些节点”。
CNN 适合图片,因为图片像素排列在规则网格上;RNN/Transformer 适合序列,因为序列有明确顺序。
图数据不一样。它没有固定网格,也没有天然顺序。每个节点的邻居数量可能不同,图的大小也可能不同。
所以我们需要一种模型,既能利用节点本身的特征,又能利用节点之间的连接关系。
这就是 GNN 出现的原因。
2. 图数据如何表示
一个图通常写成:
$$
G=(V,E)
$$
其中:
- $V$ 是节点集合。
- $E$ 是边集合。
如果图中有 $n$ 个节点,每个节点有 $d$ 维特征,那么节点特征矩阵可以写成:
$$
X \in \mathbb{R}^{n \times d}
$$
图的连接关系常用邻接矩阵表示:
$$
A \in \mathbb{R}^{n \times n}
$$
如果节点 $i$ 和节点 $j$ 之间有边,那么:
$$
A_{ij}=1
$$
否则:
$$
A_{ij}=0
$$
在工程实现中,尤其是稀疏图里,通常不会真的存一个巨大的邻接矩阵,而是用边列表:
edge_index = [
[0, 1, 1, 2],
[1, 0, 2, 1],
]
它表示边:
0 -> 1
1 -> 0
1 -> 2
2 -> 1
这比存完整矩阵更节省空间。
3. GNN 的核心思想:消息传递
大多数 GNN 都可以用消息传递(Message Passing)来理解。
对于某个节点,它会做三件事:
- 从邻居节点接收消息。
- 把邻居消息聚合起来。
- 用聚合后的信息更新自己的节点表示。
可以写成一个通用形式:
$$
h_v^{(k)} =
\text{UPDATE}^{(k)}
\left(
h_v^{(k-1)},
\text{AGGREGATE}^{(k)}
\left(
{h_u^{(k-1)}:u\in \mathcal{N}(v)}
\right)
\right)
$$
其中:
- $h_v^{(k)}$ 表示节点 $v$ 在第 $k$ 层的表示。
- $\mathcal{N}(v)$ 表示节点 $v$ 的邻居集合。
- AGGREGATE 表示邻居聚合函数。
- UPDATE 表示节点更新函数。
直观理解:
第 1 层:节点看到一跳邻居的信息。
第 2 层:节点看到两跳邻居的信息。
第 3 层:节点看到三跳邻居的信息。
层数越深,节点能接收到越远的图结构信息。但层数也不能无限加深,因为可能出现过平滑问题,后面会讲。
4. 一个小例子:判断节点类别
假设有一个论文引用网络:
- 节点表示论文。
- 边表示引用关系。
- 节点特征是论文标题和摘要提取出的向量。
- 标签是论文所属领域,例如机器学习、数据库、计算机视觉。
如果只看单篇论文的文字特征,模型可以做分类。
但引用关系也很有用:
一篇论文引用了很多图神经网络论文,
它自己也很可能和图学习相关。
GNN 会让每篇论文聚合相邻论文的信息。这样,一个节点的最终表示不仅包含自己的摘要信息,也包含它引用了谁、被谁引用、处在哪个研究社区中。
最后用节点表示做分类:
$$
\hat{y}_v = \text{softmax}(W h_v)
$$
这就是节点分类任务。
5. GNN 能解决哪些任务
GNN 的预测目标可以放在不同层级。
5.1 节点级任务
节点级任务是给每个节点预测标签。
例如:
- 社交网络中预测用户兴趣。
- 引用网络中预测论文类别。
- 欺诈检测中预测账户是否异常。
5.2 边级任务
边级任务是预测两个节点之间的关系。
例如:
- 推荐系统中预测用户是否会点击商品。
- 社交网络中预测两个人是否可能成为好友。
- 知识图谱中预测实体之间是否存在某种关系。
这类问题也叫链接预测(Link Prediction)。
5.3 图级任务
图级任务是给整个图预测标签。
例如:
- 判断一个分子是否有毒。
- 判断一个代码函数是否存在漏洞。
- 判断一个蛋白质结构属于哪一类。
图级任务通常需要先得到每个节点的表示,再通过 readout 或 pooling 操作得到整张图的表示。
6. GCN:最经典的图卷积网络
GCN(Graph Convolutional Network)是最经典、最常见的 GNN 模型之一。
GCN 的直觉非常简单:
一个节点的新表示,来自它自己和邻居节点表示的加权平均,再经过线性变换和非线性激活。
如果直接把邻居特征相加,会有一个问题:度数大的节点会收到更多信息,数值规模可能变大。
所以 GCN 会做归一化。常见形式是:
$$
H^{(l+1)}
=
\sigma
\left(
\tilde{D}^{-\frac{1}{2}}
\tilde{A}
\tilde{D}^{-\frac{1}{2}}
H^{(l)}
W^{(l)}
\right)
$$
其中:
- $\tilde{A}=A+I$,表示给图加自环,让节点也保留自己的信息。
- $\tilde{D}$ 是 $\tilde{A}$ 的度矩阵。
- $H^{(l)}$ 是第 $l$ 层节点表示。
- $W^{(l)}$ 是可学习参数。
- $\sigma$ 是激活函数,例如 ReLU。
这个公式看起来很像矩阵运算版的消息传递:
邻接矩阵决定从哪些邻居接收消息
度矩阵负责归一化
权重矩阵负责特征变换
激活函数引入非线性
7. 手工理解一次 GCN 聚合
假设节点 1 有两个邻居:节点 0 和节点 2。为了简单,先不考虑复杂归一化,只做平均聚合。
三个节点的特征为:
| 节点 | 特征 |
|---|---|
| 0 | $[1, 0]$ |
| 1 | $[0, 1]$ |
| 2 | $[1, 1]$ |
如果给节点加自环,那么节点 1 会聚合:
节点 0、节点 1、节点 2
平均后得到:
$$
h_1’=
\frac{[1,0]+[0,1]+[1,1]}{3}
=
\left[
\frac{2}{3},
\frac{2}{3}
\right]
$$
这个新表示已经不只是节点 1 自己的特征,而是混入了邻居节点的信息。
如果再经过一层线性变换:
$$
h_1’’ = h_1’W
$$
模型就可以学习如何把聚合后的邻居信息转成更适合任务的表示。
这就是 GCN 的基本味道。
8. GraphSAGE:适合大图的采样聚合
GCN 在整张图上做聚合,如果图很大,训练会比较困难。
例如一个社交网络可能有上亿节点。每次都使用所有邻居进行聚合,计算和内存成本都会很高。
GraphSAGE 的思路是:
不必每次看所有邻居,可以采样固定数量的邻居,再做聚合。
例如对每个节点:
第 1 层采样 10 个邻居
第 2 层对每个邻居再采样 5 个邻居
这样每个 batch 的计算规模就能控制住。
GraphSAGE 还强调归纳能力(inductive ability)。也就是说,训练好以后,模型可以对训练时没见过的新节点生成表示,只要这个新节点有特征和邻居信息。
这对真实系统很重要,因为社交网络、推荐系统里的新用户和新商品会不断出现。
9. GAT:让模型自己决定邻居重要性
GCN 通常按照图结构和度数归一化来聚合邻居,但它默认邻居的重要性主要由结构决定。
可是在很多图里,不同邻居的重要性并不一样。
例如推荐系统中,一个用户连接了很多商品,但某些最近购买的商品可能比很久以前点过的商品更重要。
GAT(Graph Attention Network)引入注意力机制,让模型学习每个邻居的权重:
$$
h_i’=
\sigma
\left(
\sum_{j\in\mathcal{N}(i)}
\alpha_{ij}W h_j
\right)
$$
其中 $\alpha_{ij}$ 表示节点 $i$ 聚合邻居 $j$ 时的注意力权重。
如果某个邻居更有帮助,模型可以给它更高权重;如果某个邻居噪声较大,权重可以更低。
这让 GAT 在一些异质性更强、邻居质量差异更大的图上更灵活。
10. GNN 的训练流程
GNN 的训练过程和普通神经网络很像,但前向传播中会显式使用图结构。
一般流程如下:
- 准备节点特征矩阵 $X$。
- 准备边列表
edge_index或邻接矩阵 $A$。 - 通过多层 GNN 做消息传递,得到节点表示。
- 根据任务类型做预测。
- 计算损失函数。
- 反向传播更新参数。
对于节点分类,常用交叉熵损失:
$$
\mathcal{L}
=
-\sum_{v\in \mathcal{V}_{train}}
y_v \log \hat{y}_v
$$
其中只在训练节点集合上计算损失。
对于图分类,则通常先把节点表示池化成图表示:
$$
h_G = \text{READOUT}({h_v:v\in G})
$$
再用 $h_G$ 做分类或回归。
11. PyTorch Geometric 风格代码
实际项目中,经常使用 PyTorch Geometric 或 DGL 来实现 GNN。
下面用 PyTorch Geometric 风格写一个两层 GCN 节点分类模型。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, num_classes):
super().__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, num_classes)
def forward(self, x, edge_index):
# 第一层图卷积:节点接收一跳邻居信息
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
# 第二层图卷积:继续聚合邻居信息,并输出类别 logits
x = self.conv2(x, edge_index)
return x
model = GCN(
input_dim=dataset.num_node_features,
hidden_dim=64,
num_classes=dataset.num_classes,
)
optimizer = torch.optim.Adam(
model.parameters(),
lr=0.01,
weight_decay=5e-4,
)
data = dataset[0]
for epoch in range(200):
model.train()
optimizer.zero_grad()
logits = model(data.x, data.edge_index)
# 只在训练节点上计算分类损失
loss = F.cross_entropy(
logits[data.train_mask],
data.y[data.train_mask],
)
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f"epoch={epoch}, loss={loss.item():.4f}")
这段代码里最关键的是:
data.x:节点特征矩阵。data.edge_index:图的边列表。GCNConv:图卷积层。train_mask:哪些节点用于训练。
和普通神经网络相比,GNN 的模型输入不仅有特征 x,还必须有图结构 edge_index。
12. GNN 和 CNN、Transformer 的关系
可以用一个角度理解它们:
| 模型 | 主要数据结构 | 信息如何流动 |
|---|---|---|
| CNN | 规则网格 | 局部卷积核聚合邻域像素 |
| RNN | 序列 | 按时间步传递隐藏状态 |
| Transformer | 序列或集合 | 注意力在 token 间传递信息 |
| GNN | 图结构 | 节点沿边传递消息 |
CNN 的邻居关系由图像网格天然决定。
GNN 的邻居关系由图的边决定。
Transformer 可以看成一种更自由的全连接信息交互,而 GNN 通常受到图结构约束,只在有边的节点之间传递信息。
所以 GNN 的优势在于:它把“谁和谁有关”这个先验结构直接放进模型里。
13. 过平滑问题
GNN 并不是层数越深越好。
当 GNN 层数很多时,一个节点会不断混合越来越远的邻居信息。最后不同节点的表示可能变得越来越相似,这叫过平滑(Over-smoothing)。
直观地说:
每一层都在和邻居平均,
平均太多次以后,
大家的表示就越来越像。
这会导致节点分类能力下降,因为模型难以区分不同节点。
常见缓解方法包括:
- 使用残差连接。
- 使用跳层连接。
- 限制 GNN 层数。
- 使用归一化和正则化。
- 设计更强的消息传递结构。
很多实际 GCN 模型只有 2 到 3 层,就是因为层数太深容易带来问题。
14. 图同质性和异质性
很多经典 GNN 假设图具有同质性(homophily):
相连的节点往往属于相似类别。
例如论文引用网络中,同一领域的论文更可能互相引用;社交网络中,兴趣相似的人更可能连接。
在这种情况下,邻居聚合通常有效。
但有些图是异质性的(heterophily):
相连的节点不一定相似,甚至可能经常不同类。
例如交易网络中,正常账户可能和异常账户发生交易;知识图谱中,不同类型实体之间才会连接。
这时简单邻居平均可能反而引入噪声,需要更适合异质图的模型,例如关系图神经网络、异构图神经网络或更复杂的注意力机制。
15. GNN 的常见应用
GNN 的应用非常广:
- 推荐系统:用户-商品图、点击图、购买图。
- 风控反欺诈:账户交易网络、设备关联网络。
- 分子性质预测:原子和化学键组成分子图。
- 交通预测:道路网络上的车流量预测。
- 知识图谱:实体关系建模和链接预测。
- 社交网络:社区发现、用户分类、影响力预测。
- 代码分析:函数调用图、抽象语法树、依赖图。
这些任务的共同点是:关系本身就是信息。
如果把图结构丢掉,只把节点当成独立样本,往往会损失大量上下文。
16. GNN 的优点
GNN 的优点主要有:
- 能直接利用图结构。
- 可以把节点特征和邻居信息融合起来。
- 适合节点、边、图三个层级的任务。
- 对不规则结构数据很自然。
- 在推荐、分子、知识图谱等场景中表现强。
尤其是当关系结构很重要时,GNN 往往比只看单点特征的模型更有优势。
17. GNN 的局限
GNN 也有不少挑战。
第一,大图训练成本高。
真实图可能有千万甚至上亿节点,完整图训练很难,需要采样、分布式训练或 mini-batch 技术。
第二,邻居噪声会影响结果。
如果图里有很多错误边、弱关系或恶意连接,消息传递可能把噪声扩散到更多节点。
第三,层数加深容易过平滑。
节点表示变得过于相似后,分类或预测能力会下降。
第四,动态图更难处理。
现实中的图经常随时间变化,例如新用户、新交易、新道路状态。静态 GNN 不一定能直接适应。
第五,可解释性仍然困难。
虽然注意力权重和子图解释能提供一些线索,但 GNN 的整体决策过程仍然不如决策树那样直观。
18. 实践建议
使用 GNN 时,可以按下面顺序思考:
- 先确认图结构是否真的有用,不要为了用 GNN 而强行构图。
- 明确任务是节点级、边级还是图级。
- 检查节点特征质量,只有边没有特征时效果可能受限。
- 从简单 GCN 或 GraphSAGE 做 baseline。
- 大图优先考虑邻居采样和 mini-batch 训练。
- 如果邻居重要性差异明显,可以尝试 GAT。
- 如果图是多关系或多类型节点,考虑 R-GCN 或异构图模型。
- 监控过平滑,不要盲目堆很多层。
- 做好训练、验证、测试划分,避免图数据泄漏。
- 和非 GNN baseline 对比,确认图结构确实带来收益。
19. 总结
GNN 的核心思想可以概括为:
节点从邻居接收消息,
聚合邻居信息,
更新自己的表示,
多层堆叠后获得更大范围的结构上下文。
如果用公式表示,就是:
$$
h_v^{(k)} =
\text{UPDATE}^{(k)}
\left(
h_v^{(k-1)},
\text{AGGREGATE}^{(k)}
\left(
{h_u^{(k-1)}:u\in \mathcal{N}(v)}
\right)
\right)
$$
理解 GNN 时,可以抓住三个关键词:
- 图结构:边决定哪些节点可以交换信息。
- 消息传递:节点通过邻居聚合更新表示。
- 表示学习:最终把节点、边或整张图编码成可用于预测的向量。
当数据中的“关系”本身很重要时,GNN 就非常值得考虑。它不是简单地把图变成表格,而是让模型在图上直接学习。
参考文献
本文主要参考了以下资料:
- Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks, ICLR 2017. https://arxiv.org/abs/1609.02907
- William L. Hamilton, Rex Ying, Jure Leskovec, Inductive Representation Learning on Large Graphs, NeurIPS 2017. https://arxiv.org/abs/1706.02216
- Petar Veličković et al., Graph Attention Networks, ICLR 2018. https://arxiv.org/abs/1710.10903
- Justin Gilmer et al., Neural Message Passing for Quantum Chemistry, ICML 2017. https://arxiv.org/abs/1704.01212
- PyTorch Geometric Documentation: https://pytorch-geometric.readthedocs.io/
