Creating Message Passing Networks with JAX
Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme. With \(\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F\) denoting node features of node \(i\) in layer \((k-1)\) and \(\mathbf{e}_{j,i} \in \mathbb{R}^D\) denoting (optional) edge features from node \(j\) to node \(i\), message passing graph neural networks can be described as

\[
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right) \right),
\]
where \(\bigoplus\) denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and \(\gamma\) and \(\phi\) denote differentiable functions such as MLPs (Multi Layer Perceptrons).
In JraphX, we implement this using JAX and Flax NNX for neural network components.
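To make the role of \(\bigoplus\) concrete, here is a small plain-JAX illustration (independent of JraphX) of permutation-invariant aggregation: per-edge messages are reduced per target node with segment operations, so the result does not depend on the order in which edges are listed.

import jax.numpy as jnp
from jax.ops import segment_sum, segment_max

# A toy graph with 3 nodes and edges 0->1, 2->1 and 1->0.
# Each edge carries a 2-dimensional "message" that we aggregate per target node.
messages = jnp.array([[1.0, 2.0],   # message along edge 0 -> 1
                      [3.0, 4.0],   # message along edge 2 -> 1
                      [5.0, 6.0]])  # message along edge 1 -> 0
targets = jnp.array([1, 1, 0])      # target node of each edge

summed = segment_sum(messages, targets, num_segments=3)
maxed = segment_max(messages, targets, num_segments=3)
print(summed)  # node 0: [5., 6.], node 1: [4., 6.], node 2: [0., 0.] (no incoming edges)
print(maxed)   # node 1 holds the element-wise max of its two incoming messages: [3., 4.]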
The “MessagePassing” Base Class
JraphX provides the MessagePassing base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. The user only has to define the functions \(\phi\) (i.e. message()) and \(\gamma\) (i.e. update()), as well as the aggregation scheme to use (i.e. aggr="add", aggr="mean" or aggr="max").
This is done with the help of the following methods:

MessagePassing(aggr="add", flow="source_to_target"): Defines the aggregation scheme to use ("add", "mean" or "max") and the flow direction of message passing (either "source_to_target" or "target_to_source").

MessagePassing.propagate(edge_index, size=None, **kwargs): The initial call to start propagating messages. Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings. Note that propagate() is not limited to exchanging messages in square adjacency matrices of shape [N, N] only, but can also exchange messages in general sparse assignment matrices, e.g., bipartite graphs, of shape [N, M] by passing size=(N, M) as an additional argument. If set to None, the assignment matrix is assumed to be a square matrix. For bipartite graphs with two independent sets of nodes and indices, where each set holds its own information, this split can be marked by passing the information as a tuple, e.g. x=(x_N, x_M).

MessagePassing.message(...): Constructs messages to node \(i\) in analogy to \(\phi\) for each edge \((j,i) \in \mathcal{E}\) if flow="source_to_target" and \((i,j) \in \mathcal{E}\) if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, JAX arrays passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j. Note that we generally refer to \(i\) as the central node that aggregates information, and to \(j\) as a neighboring node, since this is the most common notation.

MessagePassing.update(aggr_out, ...): Updates node embeddings in analogy to \(\gamma\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate(). A minimal sketch showing how these pieces fit together follows below.
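To see how these hooks fit together, here is a minimal sketch of a custom layer built only from the pieces described above. The layer (MyMeanConv) and its behavior are purely illustrative, and its hook signatures are assumed to follow the pattern used by the built-in layers.

from flax import nnx
from jraphx.nn.conv.message_passing import MessagePassing

class MyMeanConv(MessagePassing):
    """Toy layer: mean-aggregates linearly transformed neighbor features."""

    def __init__(self, in_features, out_features, *, rngs: nnx.Rngs):
        super().__init__(aggr="mean")           # aggregation scheme ("mean")
        self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

    def __call__(self, x, edge_index):
        # x: [N, in_features], edge_index: [2, E]
        return self.propagate(edge_index, x=x)  # message -> aggregate -> update

    def message(self, x_j, x_i):
        # phi: x_j holds the source-node features of each edge, shape [E, in_features].
        return self.linear(x_j)

    def update(self, aggr_out):
        # gamma: aggr_out holds the aggregated messages, shape [N, out_features].
        return aggr_out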
Let us verify this by re-implementing two popular GNN variants, the GCN layer from Kipf and Welling and the EdgeConv layer from Wang et al.
Implementing the GCN Layer
The GCN layer is mathematically defined as

\[
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},
\]
where neighboring node features are first transformed by a weight matrix \(\mathbf{W}\), normalized by their degree, and finally summed up. Lastly, we apply the bias vector \(\mathbf{b}\) to the aggregated output. This formula can be divided into the following steps:
1. Add self-loops to the adjacency matrix.
2. Linearly transform the node feature matrix.
3. Compute normalization coefficients.
4. Normalize node features in \(\phi\).
5. Sum up neighboring node features ("add" aggregation).
6. Apply a final bias vector.
Steps 1-3 are typically computed before message passing takes place.
Steps 4-5 can be easily processed using the MessagePassing base class, although the implementation below performs the aggregation directly with JAX segment operations; the bias of step 6 is folded into the linear transform of step 2.
The full layer implementation is shown below:
import jax.numpy as jnp
from flax import nnx
from jax.ops import segment_sum
from jraphx.nn.conv.message_passing import MessagePassing
from jraphx.utils import add_self_loops, degree
class GCNConv(MessagePassing):
    def __init__(self, in_features, out_features, *, rngs: nnx.Rngs):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.linear = nnx.Linear(in_features, out_features, use_bias=True, rngs=rngs)

    def __call__(self, x, edge_index):
        # x has shape [N, in_features]
        # edge_index has shape [2, E]

        # Step 2: Linearly transform node feature matrix first (more efficient);
        # the bias of Step 6 is included here via use_bias=True.
        x = self.linear(x)

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])

        # Step 3: Compute normalization.
        row, col = edge_index[0], edge_index[1]
        deg = degree(col, x.shape[0], dtype=x.dtype)
        deg_inv_sqrt = jnp.power(deg, -0.5)
        deg_inv_sqrt = jnp.where(jnp.isinf(deg_inv_sqrt), 0.0, deg_inv_sqrt)

        # Create edge weights from normalization.
        edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: Efficient message passing with normalization.
        messages = jnp.take(x, row, axis=0) * edge_weight.reshape(-1, 1)
        out = segment_sum(messages, col, num_segments=x.shape[0])

        return out
GCNConv inherits from MessagePassing with "add" aggregation.
All the logic of the layer takes place in its __call__() method.
Here, we first linearly transform the node features using nnx.Linear (step 2, done first so the transform is applied to N node vectors rather than to E gathered per-edge copies), then add self-loops to our edge indices using add_self_loops() (step 1).
The normalization coefficients are derived from the node degree \(\deg(i)\) of each node \(i\), which is transformed to \(1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})\) for each edge \((j,i) \in \mathcal{E}\).
The result is saved in the array edge_weight of shape [num_edges] (step 3).
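As a quick sanity check of step 3 (plain jnp here, independent of the jraphx.utils helpers), the symmetric normalization for a tiny directed graph whose self-loops have already been added works out as follows:

import jax.numpy as jnp

# Edges after adding self-loops to a 3-node graph with edges 0->1 and 2->1:
# (0->1), (2->1), (0->0), (1->1), (2->2)
row = jnp.array([0, 2, 0, 1, 2])  # source nodes j
col = jnp.array([1, 1, 0, 1, 2])  # target nodes i

# Degree of each target node (counting self-loops), here via bincount.
deg = jnp.bincount(col, length=3).astype(jnp.float32)  # [1., 3., 1.]
deg_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)

# Per-edge weights 1 / (sqrt(deg(j)) * sqrt(deg(i))) for each edge (j, i).
edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]
print(edge_weight)  # approximately [0.577, 0.577, 1.0, 0.333, 1.0]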
For efficient computation, we use JAX’s optimized operations:
- jnp.take() for fast indexing to gather source node features
- Element-wise multiplication to apply edge weights
- jax.ops.segment_sum() for efficient aggregation by target nodes
This approach can be more efficient than the traditional propagate() path because it expresses the whole gather-scale-scatter pattern directly with JAX’s optimized array operations.
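To convince yourself that this gather-multiply-scatter pipeline matches the dense formulation, here is a small plain-JAX check comparing it against multiplication with a weighted adjacency matrix (the graph and weights are arbitrary):

import jax
import jax.numpy as jnp
from jax.ops import segment_sum

x = jax.random.normal(jax.random.PRNGKey(0), (3, 4))  # 3 nodes, 4 features
row = jnp.array([0, 1, 0, 1, 2])                      # source nodes (incl. self-loops)
col = jnp.array([1, 0, 0, 1, 2])                      # target nodes
edge_weight = jnp.array([0.5, 0.5, 0.5, 0.5, 1.0])    # arbitrary per-edge weights

# Sparse path: gather sources, scale by edge weight, scatter-add into targets.
out_sparse = segment_sum(jnp.take(x, row, axis=0) * edge_weight[:, None],
                         col, num_segments=3)

# Dense path: build the weighted adjacency A[i, j] and multiply.
adj = jnp.zeros((3, 3)).at[col, row].add(edge_weight)
out_dense = adj @ x

print(jnp.allclose(out_sparse, out_dense))  # True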
That is all that it takes to create a simple message passing layer with JAX! You can use this layer as a building block for deep architectures. Initializing and calling it is straightforward:
conv = GCNConv(16, 32, rngs=nnx.Rngs(42))
output = conv(x, edge_index)
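For a fully self-contained run, you can pair this with a toy graph (the node features and edge list below are invented for illustration):

import jax.numpy as jnp
from flax import nnx

# A 3-node chain graph with edges 0<->1 and 1<->2 and 16 features per node.
x = jnp.ones((3, 16))
edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]], dtype=jnp.int32)

conv = GCNConv(16, 32, rngs=nnx.Rngs(42))
out = conv(x, edge_index)
print(out.shape)  # (3, 32)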
Implementing the Edge Convolution
The edge convolutional layer processes graphs or point clouds and is mathematically defined as

\[
\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),
\]
where \(h_{\mathbf{\Theta}}\) denotes an MLP.
In analogy to the GCN layer, we can use the MessagePassing
class to implement this layer, this time using the "max"
aggregation:
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.conv.message_passing import MessagePassing
class EdgeConv(MessagePassing):
    def __init__(self, in_features, out_features, *, rngs: nnx.Rngs):
        super().__init__(aggr='max')  # "Max" aggregation.
        self.mlp = nnx.Sequential(
            nnx.Linear(2 * in_features, out_features, rngs=rngs),
            nnx.relu,
            nnx.Linear(out_features, out_features, rngs=rngs),
        )

    def __call__(self, x, edge_index):
        # x has shape [N, in_features]
        # edge_index has shape [2, E]
        return self.propagate(edge_index, x=x)

    def message(self, x_j, x_i, edge_attr=None):
        # x_i has shape [E, in_features]
        # x_j has shape [E, in_features]
        tmp = jnp.concatenate([x_i, x_j - x_i], axis=1)  # tmp has shape [E, 2 * in_features]
        return self.mlp(tmp)
Inside the message()
function, we use self.mlp
to transform both the target node features x_i
and the relative source node features x_j - x_i
for each edge \((j,i) \in \mathcal{E}\).
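Putting the layer to use on a toy point cloud might look like this (the coordinates and edge list are invented for illustration, and we assume propagate() returns one aggregated vector per node):

import jax
import jax.numpy as jnp
from flax import nnx

# 5 points with 3-dimensional coordinates and a hand-built edge list.
x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
edge_index = jnp.array([[0, 1, 2, 3, 4, 0],
                        [1, 2, 3, 4, 0, 2]], dtype=jnp.int32)

conv = EdgeConv(in_features=3, out_features=64, rngs=nnx.Rngs(0))
out = conv(x, edge_index)
print(out.shape)  # expected: (5, 64)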
The edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space.
JraphX provides a DynamicEdgeConv
implementation that handles this automatically:
from jraphx.nn.conv import DynamicEdgeConv
from jraphx.nn.models import MLP
# Create neural network for edge feature processing
nn = MLP(feature_list=[6, 128], rngs=nnx.Rngs(42)) # Input: 2*3=6, Output: 128
# Create dynamic edge convolution layer
conv = DynamicEdgeConv(
    nn=nn,
    k=6,  # Number of nearest neighbors
)
# Use with point cloud data (x contains spatial coordinates)
# Note: k-NN indices must be pre-computed from spatial coordinates
output = conv(x, knn_indices=knn_indices)
Note that unlike PyTorch Geometric’s version, JraphX’s DynamicEdgeConv does not automatically compute k-NN graphs from node features. You must provide the k-NN indices separately, typically computed using external libraries or custom JAX implementations for spatial/feature-space nearest neighbors.
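If you do not have a k-NN routine at hand, a brute-force version can be written in a few lines of plain JAX. This is only a sketch: the [N, k] layout of the returned indices is an assumption about what knn_indices expects and may differ from JraphX's actual convention.

import jax
import jax.numpy as jnp

def knn_graph(pos, k):
    """Brute-force k nearest neighbors in coordinate/feature space.

    Returns an index array of shape [N, k] where row i holds the k nearest
    neighbors of point i (excluding the point itself).
    """
    # Pairwise squared Euclidean distances, shape [N, N].
    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = jnp.sum(diff * diff, axis=-1)
    # Mask the diagonal so a point is never its own neighbor.
    dist2 = jnp.where(jnp.eye(pos.shape[0], dtype=bool), jnp.inf, dist2)
    # Indices of the k smallest distances per row.
    _, idx = jax.lax.top_k(-dist2, k)
    return idx

pos = jax.random.normal(jax.random.PRNGKey(0), (100, 3))  # 100 points in 3D
knn_indices = knn_graph(pos, k=6)                          # shape (100, 6)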
Exercises
Imagine we are given the following Data
object:
import jax.numpy as jnp
from jraphx.data import Data
edge_index = jnp.array([[0, 1],
                        [1, 0],
                        [1, 2],
                        [2, 1]], dtype=jnp.int32)
x = jnp.array([[-1.0], [0.0], [1.0]], dtype=jnp.float32)
data = Data(x=x, edge_index=edge_index.T)
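If you want to experiment while working through the questions below, you can feed this object through the GCNConv defined earlier (attribute access via data.x and data.edge_index is assumed here):

from flax import nnx

conv = GCNConv(in_features=1, out_features=2, rngs=nnx.Rngs(0))
out = conv(data.x, data.edge_index)  # data.edge_index has shape [2, 4]
print(out.shape)  # (3, 2)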
Try to answer the following questions related to GCNConv
:
1. What information do row and col hold in the context of JAX arrays?
2. What does degree() do and how is it different from PyTorch’s version?
3. Why do we use degree(col, ...) rather than degree(row, ...)?
4. What do deg_inv_sqrt[col] and deg_inv_sqrt[row] do in terms of JAX indexing?
5. How does jnp.take() work in the JraphX implementation compared to PyTorch’s automatic lifting?
6. Add an update() function to the custom GCNConv that adds transformed central node features to the aggregated output.
7. What are the benefits of using jax.ops.segment_sum() over the traditional message passing approach?
Try to answer the following questions related to EdgeConv
:
1. What are x_i and x_j - x_i in the context of JAX arrays?
2. What does jnp.concatenate([x_i, x_j - x_i], axis=1) do? Why axis=1?
3. Implement a vectorized version of EdgeConv that processes multiple graphs using nnx.vmap.