- A ^ \mathbf{\hat{A}} A^ 是添加了自环的邻接矩阵,元素 a i j a_{ij} aij可用 1 和 0 表示是否有连边,或等于 e i j e_{ij} eij 连边权值。
- D ^ \mathbf{\hat{D}} D^ 是对角矩阵,对角元素 d i i d_{ii} dii 是当前节点的度。
-
D
^
−
1
/
2
A
^
D
^
−
1
/
2
\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}\mathbf{\hat{D}}^{-1/2}
D^−1/2A^D^−1/2 中的元素
t
i
j
t_{ij}
tij 是
e
j
,
i
d
^
i
d
^
j
\frac{e_{j,i}}{\sqrt{\hat{d}_i}\sqrt{\hat{d}_j}}
d^id^jej,i。(即:message中确定了权重——边权除以source和target的度的1/2次幂的积)
类定义和成员变量
class GCNConv(MessagePassing):
_cached_edge_index: Optional[OptPairTensor]
_cached_adj_t: Optional[SparseTensor]
- _cached_edge_index 和 _cached_adj_t 是用于缓存边索引和邻接矩阵的变量,防止在每次前向传播时重复计算。
构造函数 init
def __init__(
self,
in_channels: int,
out_channels: int,
improved: bool = False,
cached: bool = False,
add_self_loops: Optional[bool] = None,
normalize: bool = True,
bias: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
super().__init__(**kwargs)
if add_self_loops is None:
add_self_loops = normalize
if add_self_loops and not normalize:
raise ValueError(f"'{self.__class__.__name__}' does not support "
f"adding self-loops to the graph when no "
f"on-the-fly normalization is applied")
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self.normalize = normalize
self._cached_edge_index = None
self._cached_adj_t = None
self.lin = Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot')
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
- 检查是否在不归一化的情况下添加自环,如果是则抛出异常。
- weight_initializer=‘glorot’ 指定使用 Xavier 初始化方法来初始化权重。
重置参数 reset_parameters
def reset_parameters(self):
super().reset_parameters()
self.lin.reset_parameters()
zeros(self.bias)
self._cached_edge_index = None
self._cached_adj_t = None
前向传播 forward
def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
if isinstance(x, (tuple, list)):
raise ValueError(f"'{self.__class__.__name__}' received a tuple "
f"of node features as input while this layer "
f"does not support bipartite message passing. "
f"Please try other layers such as 'SAGEConv' or "
f"'GraphConv' instead")
if self.normalize:
if isinstance(edge_index, Tensor):
cache = self._cached_edge_index
if cache is None:
edge_index, edge_weight = gcn_norm(
edge_index, edge_weight, x.size(self.node_dim),
self.improved, self.add_self_loops, self.flow, x.dtype)
if self.cached:
self._cached_edge_index = (edge_index, edge_weight)
else:
edge_index, edge_weight = cache[0], cache[1]
elif isinstance(edge_index, SparseTensor):
cache = self._cached_adj_t
if cache is None:
edge_index = gcn_norm(
edge_index, edge_weight, x.size(self.node_dim),
self.improved, self.add_self_loops, self.flow, x.dtype)
if self.cached:
self._cached_adj_t = edge_index
else:
edge_index = cache
x = self.lin(x)
# propagate_type: (x: Tensor, edge_weight: OptTensor)
out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
if self.bias is not None:
out = out + self.bias
return out
forward 方法定义了前向传播过程。
- 如果输入特征是元组或列表,则抛出异常,因为该层不支持二分图消息传递。
- 如果需要归一化,并且 edge_index 是张量,检查是否有缓存。如果没有缓存,计算归一化的边索引和边权重,并缓存结果。
- 如果 edge_index 是稀疏张量,同样处理缓存逻辑。
- 对节点特征进行线性变换。
- 调用 propagate 方法进行消息传递。
- 如果存在偏置,加上偏置。
- 返回输出特征。
消息函数 message
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
- message 方法定义了如何从源节点到目标节点传递消息。
- 如果没有边权重,直接返回源节点特征;如果有边权重,返回加权后的源节点特征。
消息和聚合 message_and_aggregate
def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
return spmm(adj_t, x, reduce=self.aggr)
- message_and_aggregate 方法用于稀疏矩阵的消息传递和聚合。
- 使用稀疏矩阵乘法 spmm 进行消息传递和聚合。
GCNConv类实现了GCN的核心思想
通过消息传递和特征聚合来更新节点的表示。它通过以下几个步骤实现:
- 添加自环,使得节点可以聚合自身特征。
- 对节点特征进行线性变换。
- 计算归一化因子,以保证特征的尺度一致。
- 聚合邻居节点的特征,并使用归一化因子进行归一化。
- 返回聚合后的节点特征。
规范化方法 gcn_norm
重载 gcn_norm 函数的定义
@torch.jit._overload
def gcn_norm( # noqa: F811
edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,
dtype):
# type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor # noqa
pass
@torch.jit._overload
def gcn_norm( # noqa: F811
edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,
dtype):
# type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor # noqa
pass
- 通过 @torch.jit._overload 装饰器,定义了两个重载函数的签名,支持不同类型的 edge_index 输入
实际 gcn_norm 函数的实现
def gcn_norm( # noqa: F811
edge_index: Adj,
edge_weight: OptTensor = None,
num_nodes: Optional[int] = None,
improved: bool = False,
add_self_loops: bool = True,
flow: str = "source_to_target",
dtype: Optional[torch.dtype] = None,
):
fill_value = 2. if improved else 1.
SparseTensor 类型的 edge_index和torch.sparse 类型的 edge_index:
- 实现方式:
torch.sparse 类型是 PyTorch 的原生稀疏张量格式,适用于通用的稀疏矩阵操作。
torch_geometric 的 SparseTensor 是为图神经网络优化的稀疏矩阵格式,包含了许多图操作的优化。
- 使用场景:
torch.sparse 适用于需要进行稀疏矩阵乘法等通用操作的场景。
torch_geometric 的 SparseTensor 适用于图神经网络,特别是需要处理大规模图数据的场景。
- 功能支持:
torch.sparse 提供了基本的稀疏张量操作,如矩阵乘法、求和等。
torch_geometric 的 SparseTensor 提供了丰富的图操作支持,如自环添加、归一化、度数计算等。
处理 SparseTensor 类型的 edge_index
if isinstance(edge_index, SparseTensor):
assert edge_index.size(0) == edge_index.size(1)
adj_t = edge_index
if not adj_t.has_value():
adj_t = adj_t.fill_value(1., dtype=dtype)
if add_self_loops:
adj_t = torch_sparse.fill_diag(adj_t, fill_value)
deg = torch_sparse.sum(adj_t, dim=1)
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))
adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))
return adj_t
如果 edge_index 是 SparseTensor 类型,执行以下步骤:
- 确保 edge_index 是方阵(行数等于列数)。
- 如果 adj_t 没有值,则填充为 1。
- 如果需要添加自环,则填充对角线。
- 计算度矩阵并求平方根的倒数,处理无穷值。
- 对 adj_t 进行归一化处理。
处理 torch.sparse 类型的 edge_index
if is_torch_sparse_tensor(edge_index):
assert edge_index.size(0) == edge_index.size(1)
if edge_index.layout == torch.sparse_csc:
raise NotImplementedError("Sparse CSC matrices are not yet "
"supported in 'gcn_norm'")
adj_t = edge_index
if add_self_loops:
adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)
edge_index, value = to_edge_index(adj_t)
col, row = edge_index[0], edge_index[1]
deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum')
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col]
return set_sparse_value(adj_t, value), None
如果 edge_index 是 PyTorch 稀疏张量,执行以下步骤:
- 确保 edge_index 是方阵。
- 如果 edge_index 使用稀疏列压缩格式(CSC),则抛出未实现错误。
- 如果需要添加自环,则添加。
- 将稀疏张量转换为边索引格式。
- 计算度矩阵并进行归一化处理。
处理普通张量的 edge_index
assert flow in ['source_to_target', 'target_to_source']
num_nodes = maybe_num_nodes(edge_index, num_nodes)
if add_self_loops:
edge_index, edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
device=edge_index.device)
row, col = edge_index[0], edge_index[1]
idx = col if flow == 'source_to_target' else row
deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
return edge_index, edge_weight
- 确保 flow 参数在允许的范围内。
- 计算节点数。
maybe_num_nodes是 torch_geometric.utils.num_nodes 模块中的一个函数,用于推断图的节点数量。 - 如果需要添加自环,则添加。
- 如果 edge_weight 为空,则初始化为全 1 的张量。
- 根据 flow 参数选择索引进行度矩阵计算并归一化处理。
PS:当edge_weight均为1的时候,scatter得到每个节点的度赋值给deg,deg为从0到num_nodes的下标对应位置的节点度(但当edge_weight不为1时,得到的就是节点的强度)。
PS:deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
得到的是归一化后的权重。
scatter函数 —(用于将某个张量 src 的值分散到另一个张量 out 中的函数)
torch_scatter.scatter(src, index, dim=-1, out=None, dim_size=None, reduce='sum')
参数说明
- src:源张量,包含要散布的值。
- index:索引张量,指定 src 中的元素将散布到 out 张量中的哪些位置。
- dim:沿着哪个维度进行散布操作。
- out:目标张量,如果不提供,将创建一个新张量。
- dim_size:目标张量的大小,如果不提供,将根据 index 自动计算。
- reduce:聚合操作的类型,例如 ‘sum’(求和)、‘mean’(平均)、‘max’(最大值)等。