jraphx.nn.conv
This module contains graph convolution layers implementing various message passing algorithms.
Core Message Passing Framework
MessagePassing
- class MessagePassing(*args: Any, **kwargs: Any)[source]
Bases:
Module
Base class for creating message passing layers.
Message passing layers follow the form
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]where \(\bigoplus\) denotes a differentiable, permutation invariant function, e.g., sum, mean, min, max or mul, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.
- Parameters:
aggr (str, optional) – The aggregation scheme to use, e.g., "add", "mean", "min", "max". (default: "add")
flow (str, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")
node_dim (int, optional) – The axis along which to propagate. (default: -2)
Base class for all graph neural network layers implementing the message passing paradigm.
Message Passing Steps:
Message: Compute messages from neighboring nodes
Aggregate: Aggregate messages using sum, mean, max, or min
Update: Update node representations based on aggregated messages
Creating Custom Layers:
from jraphx.nn.conv import MessagePassing
import flax.nnx as nnx
import jax.numpy as jnp

class MyGNNLayer(MessagePassing):
    def __init__(self, in_features, out_features, rngs):
        super().__init__(aggr='mean')
        # update() concatenates [x, aggr_out], so the linear layer sees 2 * in_features
        self.lin = nnx.Linear(2 * in_features, out_features, rngs=rngs)

    def message(self, x_j, x_i=None, edge_attr=None):
        # x_j: Features of source nodes
        # x_i: Features of target nodes (optional)
        # edge_attr: Edge features (optional)
        return x_j

    def update(self, aggr_out, x):
        # aggr_out: Aggregated messages
        # x: Original node features
        return self.lin(jnp.concatenate([x, aggr_out], axis=-1))
- propagate(edge_index: Array, x: Union[Array, tuple[jax.Array, jax.Array]], edge_attr: jax.Array | None = None, size: tuple[int, int] | None = None) Array [source]
Main propagation step that orchestrates message passing.
This method uses optimized JAX operations for efficient indexing and gathering of node features.
- Parameters:
edge_index (Array) – Edge indices [2, num_edges]
x (Union[Array, tuple[Array, Array]]) – Node features [num_nodes, features] or tuple for bipartite graphs
edge_attr (Optional[Array], default: None) – Optional edge features [num_edges, edge_features]
size (Optional[tuple[int, int]], default: None) – Optional size (num_src_nodes, num_dst_nodes) for bipartite graphs
- Returns:
Array – Updated node features after message passing
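To see how propagate() fits into a layer's forward pass, here is a minimal sketch (the layer name and feature handling are illustrative, not part of the library) in which __call__ simply delegates to propagate() and projects the result:

from jraphx.nn.conv import MessagePassing
import flax.nnx as nnx

class MeanNeighborConv(MessagePassing):
    """Sketch: aggregate neighbor features with 'mean', then project them."""

    def __init__(self, in_features, out_features, rngs):
        super().__init__(aggr='mean')
        self.lin = nnx.Linear(in_features, out_features, rngs=rngs)

    def __call__(self, x, edge_index):
        # propagate() runs message() -> aggregate() -> update() internally
        aggr_out = self.propagate(edge_index, x=x)
        return self.lin(aggr_out)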
- message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) Array [source]
Construct messages from source nodes j to target nodes i.
- Parameters:
x_j (Array) – Source node features [num_edges, features]
x_i (Optional[Array], default: None) – Target node features [num_edges, features]
edge_attr (Optional[Array], default: None) – Optional edge features [num_edges, edge_features]
- Returns:
Array – Messages [num_edges, message_features]
- aggregate(messages: Array, index: Array, dim_size: int | None = None) Array [source]
Aggregate messages at target nodes using optimized scatter operations.
- Parameters:
messages (Array) – Messages to aggregate [num_edges, features]
index (Array) – Target node indices [num_edges]
dim_size (Optional[int], default: None) – Number of target nodes
- Returns:
Array – Aggregated messages [num_nodes, features]
- update(aggr_out: Array, x: jax.Array | None = None) Array [source]
Update node embeddings after aggregation.
- Parameters:
aggr_out (Array) – Aggregated messages [num_nodes, features]
x (Optional[Array], default: None) – Original node features [num_nodes, features]
- Returns:
Array – Updated node features [num_nodes, features]
- message_and_aggregate(x: Array, edge_index: Array, edge_attr: jax.Array | None = None, dim_size: int | None = None) Array [source]
Fused message and aggregation for efficiency.
This can be overridden for more efficient implementations when message computation and aggregation can be fused. For example, for simple aggregations like sum/mean with linear transformations, we can avoid materializing all messages.
- Parameters:
x (Array) – Node features
edge_index (Array) – Edge indices
edge_attr (Optional[Array], default: None) – Optional edge features
- Returns:
Array – Aggregated messages
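As a concrete illustration, here is a minimal sketch of such a fused override (the class name and use of nnx.Linear are illustrative, not part of the library): because a linear transformation commutes with summation, neighbor features can be summed per target node first and transformed once per node rather than once per edge.

from jraphx.nn.conv import MessagePassing
import flax.nnx as nnx
import jax

class FusedLinearSumConv(MessagePassing):
    def __init__(self, in_features, out_features, rngs):
        super().__init__(aggr='add')
        # No bias, so that sum-then-transform equals transform-then-sum exactly
        self.lin = nnx.Linear(in_features, out_features, use_bias=False, rngs=rngs)

    def message_and_aggregate(self, x, edge_index, edge_attr=None, dim_size=None):
        src, dst = edge_index[0], edge_index[1]
        num_nodes = dim_size if dim_size is not None else x.shape[0]
        # Sum raw neighbor features per target node, then apply the linear map
        # once per node instead of once per edge
        aggregated = jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)
        return self.lin(aggregated)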
Graph Convolution Layers
GCNConv
- class GCNConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper.
\[\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\]where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can include values other than 1 to represent edge weights via the optional edge_weight tensor.
Its node-wise formulation is given by:
\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j\]with \(\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}\), where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1.0).
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)
cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. By default, self-loops will be added when normalize is set to True. (default: True)
normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on-the-fly. (default: True)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
rngs (Optional[Rngs], default: None) – Random number generators for initialization.
static_num_nodes (int, optional) – Optional static number of nodes for better JIT performance.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)
output: node features \((|\mathcal{V}|, F_{out})\)
Graph Convolutional Network layer from Kipf & Welling (2017).
Mathematical Formulation:
\[X' = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} X W)\]where \(\tilde{A} = A + I\) is the adjacency matrix with self-loops and \(\tilde{D}\) is the degree matrix.
Example:
from jraphx.nn.conv import GCNConv
import flax.nnx as nnx

conv = GCNConv(
    in_features=16,
    out_features=32,
    add_self_loops=True,
    normalize=True,
    bias=True,
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)
- gcn_norm(edge_index: Array, edge_weight: jax.Array | None = None, num_nodes: int | None = None, improved: bool = False, add_self_loops: bool = True, dtype: numpy.dtype | None = None) tuple[jax.Array, jax.Array] [source]
Apply GCN normalization to edge weights with optimizations.
This method uses efficient degree computation and caching when possible.
- Parameters:
edge_index (Array) – Edge indices [2, num_edges]
edge_weight (Optional[Array], default: None) – Edge weights [num_edges]
num_nodes (Optional[int], default: None) – Number of nodes
improved (bool, default: False) – Use improved normalization
add_self_loops (bool, default: True) – Add self-loops
dtype (Optional[dtype], default: None) – Data type for edge weights
- Returns:
tuple[Array, Array] – Tuple of (edge_index, normalized edge_weight)
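A small sketch of invoking the normalization directly on a toy graph (the arrays here are made up for illustration; when normalize=True the layer applies this internally):

import jax.numpy as jnp
import flax.nnx as nnx
from jraphx.nn.conv import GCNConv

conv = GCNConv(in_features=16, out_features=32, rngs=nnx.Rngs(0))

# Toy graph: 4 nodes, 3 directed edges
x = jnp.ones((4, 16))
edge_index = jnp.array([[0, 1, 2],
                        [1, 2, 3]])

# Symmetric normalization with self-loops added
edge_index_norm, edge_weight = conv.gcn_norm(
    edge_index, num_nodes=x.shape[0], improved=False, add_self_loops=True
)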
GATConv
- class GATConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The graph attentional operator from the “Graph Attention Networks” paper.
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j},\]where the attention coefficients \(\alpha_{i,j}\) are computed as
\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k \right)\right)}.\]If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as
\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}.\]If the graph is not bipartite, \(\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}\).
- Parameters:
in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.
out_features (int) – Size of each output sample.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)
edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)
fill_value (float or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). (default: "mean")
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
residual (bool, optional) – If set to True, the layer will add a learnable skip-connection. (default: False)
rngs (Optional[Rngs], default: None) – Random number generators for initialization.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.
Graph Attention Network layer from Veličković et al. (2018).
Attention Mechanism:
\[ \begin{align}\begin{aligned}\alpha_{ij} = \text{softmax}_j(e_{ij})\\e_{ij} = \text{LeakyReLU}(a^T [W h_i || W h_j])\end{aligned}\end{align} \]Multi-head Attention:
Multiple attention heads compute independent attention weights
Outputs can be concatenated or averaged
Example:
from jraphx.nn.conv import GATConv
import flax.nnx as nnx

conv = GATConv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,  # Concatenate head outputs
    dropout=0.6,
    add_self_loops=True,
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)
# Output shape: [num_nodes, heads * out_features] if concat=True
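For the averaged variant mentioned above, the same layer with concat=False produces per-node outputs of size out_features rather than heads * out_features (a minimal sketch reusing x and edge_index from the example):

conv_avg = GATConv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=False,  # Average head outputs instead of concatenating
    rngs=nnx.Rngs(0)
)

out_avg = conv_avg(x, edge_index)
# Output shape: [num_nodes, out_features] when concat=False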
GATv2Conv
- class GATv2Conv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GAT layer. Since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node. In contrast, in GATv2, every node can attend to any other node.
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},\]where the attention coefficients \(\alpha_{i,j}\) are computed as
\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k \right)\right)}.\]If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as
\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j + \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k + \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}.\]
- Parameters:
in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.
out_features (int) – Size of each output sample.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)
edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)
fill_value (float, optional) – The way to generate edge features of self-loops (in case edge_dim != None). (default: 0.0)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
share_weights (bool, optional) – If set to True, the same matrix will be applied to the source and the target node of every edge. (default: False)
residual (bool, optional) – If set to True, the layer will add a learnable skip-connection. (default: False)
rngs (Optional[Rngs], default: None) – Random number generators for initialization.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.
Improved Graph Attention Network layer from Brody et al. (2022).
Key Improvements over GAT:
Dynamic attention: Attention weights depend on both query and key node features
More expressive: Can learn more complex attention patterns
Better performance: Often outperforms original GAT
Example:
from jraphx.nn.conv import GATv2Conv
import flax.nnx as nnx

conv = GATv2Conv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,
    dropout=0.6,
    edge_dim=8,  # Optional edge features
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index, edge_attr=edge_attr)
SAGEConv
- class SAGEConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper.
\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j\]If project = True, then \(\mathbf{x}_j\) will first get projected via
\[\mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b})\]as described in Eq. (3) of the paper.
- Parameters:
in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.
out_features (int) – Size of each output sample.
aggr (str, optional) – The aggregation scheme to use. Can be "mean", "max", "lstm", or "gcn". (default: "mean")
normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized, i.e., \(\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}\). (default: False)
root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
rngs (Optional[Rngs], default: None) – Random number generators for initialization.
- Shapes:
inputs: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)
outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite
GraphSAGE layer from Hamilton et al. (2017).
Aggregation Options:
mean: Average neighbor features
max: Element-wise maximum
lstm: LSTM aggregation over neighbors
Example:
from jraphx.nn.conv import SAGEConv
import flax.nnx as nnx

# Mean aggregation (most common)
conv = SAGEConv(
    in_features=16,
    out_features=32,
    aggr='mean',
    normalize=True,
    rngs=nnx.Rngs(0)
)

# LSTM aggregation
conv_lstm = SAGEConv(
    in_features=16,
    out_features=32,
    aggr='lstm',
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)
- message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) Array [source]
Construct messages from source nodes.
- Parameters:
x_j (Array) – Source node features [num_edges, out_features]
x_i (Optional[Array], default: None) – Target node features (not used)
edge_attr (Optional[Array], default: None) – Edge features (not used)
- Returns:
Array – Messages [num_edges, out_features]
- aggregate(messages: Array, index: Array, dim_size: int | None = None) Array [source]
Aggregate messages based on the specified method.
- Parameters:
messages (Array) – Messages to aggregate [num_edges, out_features]
index (Array) – Target node indices [num_edges]
dim_size (Optional[int], default: None) – Number of target nodes
- Returns:
Array – Aggregated messages [num_nodes, out_features]
GINConv
- class GINConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper.
\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]or
\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
- Parameters:
nn (Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_features] to shape [-1, out_features], e.g., defined by MLP.
eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)
train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)
rngs (Optional[Rngs], default: None) – Random number generators for initialization.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite
Graph Isomorphism Network layer from Xu et al. (2019).
Key Features:
Most expressive GNN under the WL-test framework
Uses MLPs for transformation
Learnable or fixed epsilon parameter
Example:
from jraphx.nn.conv import GINConv
from jraphx.nn.models import MLP
import flax.nnx as nnx

# Create MLP for GIN
mlp = MLP(
    channel_list=[16, 32, 32],
    norm="batch_norm",
    act="relu",
    rngs=nnx.Rngs(0)
)

conv = GINConv(
    nn=mlp,
    eps=0.0,
    train_eps=True  # Learn epsilon
)

out = conv(x, edge_index)
- message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) Array [source]
Construct messages from source nodes.
- Parameters:
x_j (Array) – Source node features [num_edges, in_features]
x_i (Optional[Array], default: None) – Target node features (not used)
edge_attr (Optional[Array], default: None) – Edge features (not used)
- Returns:
Array – Messages [num_edges, in_features]
EdgeConv
- class EdgeConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper.
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),\]where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
- Parameters:
nn (Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_features] to shape [-1, out_features], e.g., defined by MLP.
aggr (str, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite
Dynamic edge convolution from Wang et al. (2019).
Dynamic Graph Construction:
Can dynamically compute k-nearest neighbors
Suitable for point cloud processing
Edge features computed from node pairs
Example:
from jraphx.nn.conv import EdgeConv
from jraphx.nn.models import MLP
import flax.nnx as nnx

# MLP processes edge features [x_i || x_j - x_i]
mlp = MLP(
    channel_list=[32, 64, 64],
    rngs=nnx.Rngs(0)
)

conv = EdgeConv(nn=mlp, aggr='max')
out = conv(x, edge_index)
DynamicEdgeConv
- class DynamicEdgeConv(*args: Any, **kwargs: Any)[source]
Bases:
Module
Dynamic Edge Convolution layer with k-NN graph construction.
This is a simplified version of PyTorch Geometric’s DynamicEdgeConv that requires pre-computed k-NN indices. Unlike PyG’s version which automatically computes k-nearest neighbors using torch-cluster, this implementation expects the k-NN indices to be provided as input.
For true dynamic graph construction, you would need to:
1. Compute k-NN indices from node features using a JAX k-NN implementation
2. Pass these indices to this layer via the knn_indices parameter
PyG equivalent: Uses torch_cluster.knn() for automatic k-NN computation.
Dynamic edge convolution from Wang et al. (2019).
JraphX vs PyTorch Geometric:
PyG: Automatically computes k-NN using
torch_cluster.knn()
JraphX: Requires pre-computed k-NN indices (simplified version)
Limitations:
No automatic k-NN computation from node features
Requires external k-NN libraries (e.g., sklearn, faiss)
k-NN indices must be provided as input
Example:
from jraphx.nn.conv import DynamicEdgeConv
from jraphx.nn.models import MLP
import jax.numpy as jnp
import flax.nnx as nnx

# Create MLP for edge processing [x_i || x_j - x_i]
mlp = MLP(
    channel_list=[6, 64, 128],  # Input: 2*3=6 for 3D points
    rngs=nnx.Rngs(0)
)

conv = DynamicEdgeConv(nn=mlp, k=6, aggr='max')

# Pre-compute k-NN indices (6 nearest neighbors)
# In practice, use sklearn.neighbors.NearestNeighbors or similar
knn_indices = compute_knn_indices(x, k=6)

out = conv(x, knn_indices=knn_indices)
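The compute_knn_indices helper above is a placeholder. One possible pure-JAX sketch, assuming the layer expects neighbor indices of shape [num_nodes, k] with self-matches excluded, is shown below; it materializes the full pairwise distance matrix, so it only suits small point clouds.

import jax.numpy as jnp

def compute_knn_indices(x, k):
    # Pairwise squared Euclidean distances between all points [N, N]
    dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    # Exclude self-matches by pushing the diagonal to infinity
    dists = dists + jnp.diag(jnp.full(x.shape[0], jnp.inf))
    # Indices of the k nearest neighbors per node -> [num_nodes, k]
    return jnp.argsort(dists, axis=-1)[:, :k]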
TransformerConv
- class TransformerConv(*args: Any, **kwargs: Any)[source]
Bases:
MessagePassing
The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper.
\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},\]where the attention coefficients \(\alpha_{i,j}\) are computed via multi-head dot product attention:
\[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right)\]
- Parameters:
in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.
out_features (int) – Size of each output sample.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
beta (bool, optional) – If set, will combine aggregation and skip information via
\[\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}\]with \(\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [\mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i])\). (default: False)
dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
edge_dim (int, optional) – Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to final values after the same linear transformation. The model is:
\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right),\]where the attention coefficients \(\alpha_{i,j}\) are now computed via:
\[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right)\](default: None)
root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
rngs (Rngs, default: None) – Random number generators for initialization.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.
Graph Transformer layer from Shi et al. (2021).
Multi-head Attention:
Efficient QKV projection using single linear layer
Scaled dot-product attention
Optional edge feature incorporation
Beta gating mechanism for skip connections
Example:
from jraphx.nn.conv import TransformerConv
import flax.nnx as nnx

conv = TransformerConv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,
    dropout_rate=0.1,
    edge_dim=8,  # Optional edge features
    beta=True,  # Gating mechanism
    root_weight=True,  # Skip connection
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index, edge_attr=edge_attr)
- message(query_i: Array, key_j: Array, value_j: Array, edge_attr: jax.Array | None = None, index: Array = None, ptr: jax.Array | None = None, size_i: int | None = None) Array [source]
Compute messages with attention weights.
- Parameters:
query_i (Array) – Query features of target nodes [E, H*C]
key_j (Array) – Key features of source nodes [E, H*C]
value_j (Array) – Value features of source nodes [E, H*C]
edge_attr (Optional[Array], default: None) – Edge features [E, edge_dim]
index (Array, default: None) – Target node indices for edges [E]
ptr (Optional[Array], default: None) – Batch pointers (unused)
size_i (Optional[int], default: None) – Number of target nodes
key_dropout – Random key for dropout
- Returns:
Array – Weighted messages [E, H*C]
Layer Selection Guide
Choosing the Right Layer
| Layer | Complexity | Expressiveness | Memory Usage | Best For |
|---|---|---|---|---|
| GCNConv | Low | Medium | Low | Citation networks |
| GATConv | Medium | High | Medium | Heterophilic graphs |
| GATv2Conv | Medium | Higher | Medium | Complex attention patterns |
| SAGEConv | Low-Medium | Medium | Low-Medium | Large-scale graphs |
| GINConv | Medium | Highest | Medium | Graph classification |
| EdgeConv | High | High | High | Point clouds |
| DynamicEdgeConv | High | High | High | Point clouds (k-NN) |
| TransformerConv | High | Highest | High | Complex relationships |
Performance Tips
Batch Processing:
from jraphx.data import Batch
# Batch multiple graphs for efficiency
batch = Batch.from_data_list([graph1, graph2, graph3])
out = conv(batch.x, batch.edge_index)
JIT Compilation:
import jax
# JIT compile the forward pass
@jax.jit
def forward(x, edge_index):
return conv(x, edge_index)
out = forward(x, edge_index)
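Since the layers are Flax NNX modules, nnx.jit is an alternative sketch worth noting: the module is passed as an argument so its state is handled by the transform rather than captured as a compile-time constant (reusing conv, x, and edge_index from above):

import flax.nnx as nnx

# Alternative: nnx.jit handles module state explicitly
@nnx.jit
def forward_nnx(conv, x, edge_index):
    return conv(x, edge_index)

out = forward_nnx(conv, x, edge_index)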
Memory Efficiency:
Use concat=False in attention layers to reduce memory
Consider aggr='mean' over aggr='lstm' for large graphs
Use sparse operations when available
Edge Features
Many layers support edge features:
from jraphx.nn.conv import GATv2Conv, TransformerConv
import flax.nnx as nnx

# GATv2 with edge features
conv = GATv2Conv(16, 32, heads=8, edge_dim=4, rngs=nnx.Rngs(0))
out = conv(x, edge_index, edge_attr=edge_features)

# TransformerConv with edge features
conv = TransformerConv(16, 32, heads=8, edge_dim=4, rngs=nnx.Rngs(0))
out = conv(x, edge_index, edge_attr=edge_features)
Advanced Usage
Custom Aggregation
from jraphx.nn.conv import MessagePassing
import jax
import jax.numpy as jnp

class CustomConv(MessagePassing):
    def __init__(self, in_features, out_features):
        super().__init__(aggr='add')  # Base aggregation; overridden below

    def aggregate(self, messages, index, dim_size=None):
        # Override for custom aggregation: segment-wise mean built from segment sums
        summed = jax.ops.segment_sum(messages, index, num_segments=dim_size)
        counts = jax.ops.segment_sum(jnp.ones((index.shape[0], 1)), index, num_segments=dim_size)
        return summed / jnp.clip(counts, 1.0)
Heterogeneous Graphs
from jraphx.nn.conv import GCNConv, SAGEConv
import flax.nnx as nnx

# Different edge types
edge_index_1 = ...  # Type 1 edges
edge_index_2 = ...  # Type 2 edges

# Use a different convolution per edge type
conv1 = GCNConv(16, 32, rngs=nnx.Rngs(0))
conv2 = SAGEConv(16, 32, rngs=nnx.Rngs(1))

out1 = conv1(x, edge_index_1)
out2 = conv2(x, edge_index_2)
out = out1 + out2  # Combine