jraphx.nn.conv

This module contains graph convolution layers implementing various message passing algorithms.

Core Message Passing Framework

MessagePassing

class MessagePassing(*args: Any, **kwargs: Any)

Bases: Module

Base class for creating message passing layers.

Message passing layers follow the form

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

where \(\bigoplus\) denotes a differentiable, permutation invariant function, e.g., sum, mean, min, max or mul, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.

Parameters:
  • aggr (str, optional) – The aggregation scheme to use, e.g., "add", "mean", "min", "max". (default: "add")

  • flow (str, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • node_dim (int, optional) – The axis along which to propagate. (default: -2)

Base class for all graph neural network layers implementing the message passing paradigm.

Message Passing Steps:

  1. Message: Compute messages from neighboring nodes

  2. Aggregate: Aggregate messages using sum, mean, max, or min

  3. Update: Update node representations based on aggregated messages
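The three steps can be shown outside jraphx. The plain-JAX sketch below is not the MessagePassing implementation. It makes the following assumptions:

  • sum aggregation
  • the default "source_to_target" flow, so row 0 of edge_index holds the source nodes j and row 1 the target nodes i
  • the trivial message \(\phi(\mathbf{x}_i, \mathbf{x}_j) = \mathbf{x}_j\)
  • a hypothetical update that concatenates \(\mathbf{x}_i\) with the aggregate and applies a weight matrix

```python
import jax
import jax.numpy as jnp


def message_passing_step(x, edge_index, weight, num_nodes):
    """One round of message -> aggregate -> update with sum aggregation (sketch)."""
    # source_to_target flow: edge_index[0] holds sources j, edge_index[1] targets i
    src, dst = edge_index[0], edge_index[1]
    # 1. Message: phi(x_i, x_j) = x_j, gathered once per edge -> [num_edges, F]
    messages = x[src]
    # 2. Aggregate: permutation-invariant sum over the incoming edges of each target
    aggr_out = jax.ops.segment_sum(messages, dst, num_segments=num_nodes)
    # 3. Update: gamma(x_i, aggr_i) = [x_i || aggr_i] @ weight
    return jnp.concatenate([x, aggr_out], axis=-1) @ weight


x = jnp.array([[1.0], [2.0], [3.0]])            # 3 nodes, 1 feature each
edge_index = jnp.array([[0, 1, 2], [1, 2, 1]])  # edges 0->1, 1->2, 2->1
weight = jnp.ones((2, 1))                       # maps 2 concatenated features to 1
out = message_passing_step(x, edge_index, weight, num_nodes=3)  # out[:, 0] == [1., 6., 5.]
```

jax.ops.segment_sum adds the messages regardless of edge order. The result therefore does not depend on how the edges are listed, which is the permutation invariance required of \(\bigoplus\).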

Creating Custom Layers:

from jraphx.nn.conv import MessagePassing
import flax.nnx as nnx
import jax.numpy as jnp

class MyGNNLayer(MessagePassing):
    def __init__(self, in_features, out_features, rngs):
        super().__init__(aggr='mean')
        # update() concatenates x and aggr_out, so the input width is 2 * in_features
        self.lin = nnx.Linear(2 * in_features, out_features, rngs=rngs)

    def message(self, x_j, x_i=None, edge_attr=None):
        # x_j: Features of source nodes [num_edges, in_features]
        # x_i: Features of target nodes (optional)
        # edge_attr: Edge features (optional)
        return x_j

    def update(self, aggr_out, x):
        # aggr_out: Aggregated messages [num_nodes, in_features]
        # x: Original node features [num_nodes, in_features]
        return self.lin(jnp.concatenate([x, aggr_out], axis=-1))

propagate(edge_index: Array, x: Union[Array, tuple[jax.Array, jax.Array]], edge_attr: jax.Array | None = None, size: tuple[int, int] | None = None) → Array

Main propagation step that orchestrates message passing.

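For each edge it gathers the source (and, if needed, target) node features from x via edge_index, then runs the message, aggregate, and update steps in turn using JAX indexing and scatter operations.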

Parameters:
  • edge_index (Array) – Edge indices [2, num_edges]

  • x (Union[Array, tuple[Array, Array]]) – Node features [num_nodes, features] or tuple for bipartite graphs

  • edge_attr (Optional[Array], default: None) – Optional edge features [num_edges, edge_features]

  • size (Optional[tuple[int, int]], default: None) – Optional size (num_src_nodes, num_dst_nodes) for bipartite graphs

Returns:

Array – Updated node features after message passing

message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) → Array

Construct messages from source nodes j to target nodes i.

Parameters:
  • x_j (Array) – Source node features [num_edges, features]

  • x_i (Optional[Array], default: None) – Target node features [num_edges, features]

  • edge_attr (Optional[Array], default: None) – Optional edge features [num_edges, edge_features]

Returns:

Array – Messages [num_edges, message_features]

aggregate(messages: Array, index: Array, dim_size: int | None = None) → Array

Aggregate messages at target nodes using optimized scatter operations.

Parameters:
  • messages (Array) – Messages to aggregate [num_edges, features]

  • index (Array) – Target node indices [num_edges]

  • dim_size (Optional[int], default: None) – Number of target nodes

Returns:

Array – Aggregated messages [num_nodes, features]

update(aggr_out: Array, x: jax.Array | None = None) → Array

Update node embeddings after aggregation.

Parameters:
  • aggr_out (Array) – Aggregated messages [num_nodes, features]

  • x (Optional[Array], default: None) – Original node features [num_nodes, features]

Returns:

Array – Updated node features [num_nodes, features]

message_and_aggregate(x: Array, edge_index: Array, edge_attr: jax.Array | None = None, dim_size: int | None = None) → Array

Fused message and aggregation for efficiency.

This can be overridden for more efficient implementations when message computation and aggregation can be fused. For example, for simple aggregations like sum/mean with linear transformations, we can avoid materializing all messages.

Parameters:
  • x (Array) – Node features

  • edge_index (Array) – Edge indices

  • edge_attr (Optional[Array], default: None) – Optional edge features

  • dim_size (Optional[int], default: None) – Number of target nodes

Returns:

Array – Aggregated messages
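Here is one fusion of this kind. Take sum aggregation of linear messages \(\mathbf{x}_j \mathbf{W}\). Because the sum commutes with the linear map, the weight multiplication can run once per node after aggregation instead of once per edge. The plain-JAX sketch below checks that both orderings agree. The function names are hypothetical, not jraphx APIs.

```python
import jax
import jax.numpy as jnp


def unfused(x, edge_index, w, num_nodes):
    # Materializes one transformed message per edge: [num_edges, F_out]
    src, dst = edge_index[0], edge_index[1]
    messages = x[src] @ w
    return jax.ops.segment_sum(messages, dst, num_segments=num_nodes)


def fused(x, edge_index, w, num_nodes):
    # Sum aggregation is linear, so sum_j (x_j W) == (sum_j x_j) W:
    # aggregate the raw features first, then multiply once per node
    src, dst = edge_index[0], edge_index[1]
    aggr = jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)
    return aggr @ w


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (5, 4))
w = jax.random.normal(key_w, (4, 3))
edge_index = jnp.array([[0, 1, 2, 3, 4, 0], [1, 2, 3, 4, 0, 2]])
out_unfused = unfused(x, edge_index, w, 5)
out_fused = fused(x, edge_index, w, 5)
```

The fused variant never builds the [num_edges, F_out] tensor of transformed messages. Its matrix multiplication also runs on num_nodes rows rather than num_edges rows, which pays off when edges outnumber nodes.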

Graph Convolution Layers

GCNConv

class GCNConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The graph convolutional operator from the “Semi-supervised Classification with Graph Convolutional Networks” paper.

\[\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},\]

where \(\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}\) denotes the adjacency matrix with inserted self-loops and \(\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}\) its diagonal degree matrix. The adjacency matrix can contain values other than 1, representing edge weights, via the optional edge_weight tensor.

Its node-wise formulation is given by:

\[\mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j\]

with \(\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}\), where \(e_{j,i}\) denotes the edge weight from source node j to target node i (default: 1.0).
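To see that the matrix and node-wise formulations agree, the sketch below evaluates both on a three-node path graph in plain JAX. It uses unit edge weights, does not use the improved option, sets \(\mathbf{\Theta}\) to the identity, and does not call GCNConv.

```python
import jax
import jax.numpy as jnp

# Toy graph: undirected path 0 - 1 - 2, stored with both edge directions
edge_index = jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]])
num_nodes = 3
x = jnp.array([[1.0], [2.0], [3.0]])
theta = jnp.array([[1.0]])  # identity weight keeps the arithmetic visible

# Matrix form: A_hat = A + I (A[i, j] = 1 for an edge j -> i), D_hat = row sums of A_hat
adj = jnp.zeros((num_nodes, num_nodes)).at[edge_index[1], edge_index[0]].set(1.0)
a_hat = adj + jnp.eye(num_nodes)
d_inv_sqrt = 1.0 / jnp.sqrt(a_hat.sum(axis=1))
x_dense = (d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]) @ x @ theta

# Node-wise form: add self-loops, weight each edge by 1 / sqrt(d_j * d_i), sum at the target
loops = jnp.arange(num_nodes)
src = jnp.concatenate([edge_index[0], loops])
dst = jnp.concatenate([edge_index[1], loops])
deg = jax.ops.segment_sum(jnp.ones_like(dst, dtype=x.dtype), dst, num_segments=num_nodes)
norm = 1.0 / jnp.sqrt(deg[src] * deg[dst])
x_nodewise = jax.ops.segment_sum(norm[:, None] * (x[src] @ theta), dst, num_segments=num_nodes)
```

For node 0, for example, both forms give \(\tfrac{1}{2} \cdot 1 + \tfrac{1}{\sqrt{6}} \cdot 2\): a self-loop with \(\hat{d}_0 = 2\) plus the edge from node 1 with \(\hat{d}_1 = 3\).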

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • improved (bool, optional) – If set to True, the layer computes \(\mathbf{\hat{A}}\) as \(\mathbf{A} + 2\mathbf{I}\). (default: False)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. By default, self-loops will be added when normalize is set to True. (default: True)

  • normalize (bool, optional) – Whether to add self-loops and compute symmetric normalization coefficients on-the-fly. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

  • static_num_nodes (int, optional) – Optional static number of nodes for better JIT performance.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)

Graph Convolutional Network layer from Kipf & Welling (2017).

Mathematical Formulation:

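\[X' = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} X W\]

where \(\tilde{A} = A + I\) is the adjacency matrix with self-loops and \(\tilde{D}\) is its diagonal degree matrix. The layer itself applies no nonlinearity; in a full GCN model an activation \(\sigma\) is applied between layers.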

Example:

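from jraphx.nn.conv import GCNConv
import flax.nnx as nnx
import jax.numpy as jnp

# Toy graph: 4 nodes with 16 features each, edges forming a directed cycle
x = jnp.ones((4, 16))
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])

conv = GCNConv(
    in_features=16,
    out_features=32,
    add_self_loops=True,
    normalize=True,
    bias=True,
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)  # [4, 32]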

gcn_norm(edge_index: Array, edge_weight: jax.Array | None = None, num_nodes: int | None = None, improved: bool = False, add_self_loops: bool = True, dtype: numpy.dtype | None = None) → tuple[jax.Array, jax.Array]

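Apply GCN normalization to edge weights.

Optionally adds self-loops (with weight 2 when improved=True), then scales every edge weight \(e_{j,i}\) to \(e_{j,i} / \sqrt{\hat{d}_j \hat{d}_i}\). When the layer is built with cached=True, the normalized weights are reused across calls; see reset_cache().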

Parameters:
  • edge_index (Array) – Edge indices [2, num_edges]

  • edge_weight (Optional[Array], default: None) – Edge weights [num_edges]

  • num_nodes (Optional[int], default: None) – Number of nodes

  • improved (bool, default: False) – Use improved normalization

  • add_self_loops (bool, default: True) – Add self-loops

  • dtype (Optional[dtype], default: None) – Data type for edge weights

Returns:

tuple[Array, Array] – Tuple of (edge_index, normalized edge_weight)
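Below is a minimal re-implementation of the normalization. It assumes the input graph has no self-loops of its own, and that improved=True amounts to self-loops of weight 2 (matching \(\mathbf{A} + 2\mathbf{I}\)). gcn_norm_sketch is a hypothetical name; the actual method may handle existing self-loops, dtypes, and caching differently.

```python
import jax
import jax.numpy as jnp


def gcn_norm_sketch(edge_index, num_nodes, edge_weight=None, improved=False,
                    add_self_loops=True):
    """Return (edge_index, normalized edge_weight); hypothetical re-implementation."""
    if edge_weight is None:
        edge_weight = jnp.ones(edge_index.shape[1], dtype=jnp.float32)
    if add_self_loops:
        # improved=True corresponds to A + 2I, i.e. self-loops of weight 2
        fill = 2.0 if improved else 1.0
        loops = jnp.arange(num_nodes)
        edge_index = jnp.concatenate([edge_index, jnp.stack([loops, loops])], axis=1)
        edge_weight = jnp.concatenate(
            [edge_weight, jnp.full((num_nodes,), fill, edge_weight.dtype)])
    src, dst = edge_index[0], edge_index[1]
    # Weighted in-degree of every target node: d_i = sum_j e_{j,i}
    deg = jax.ops.segment_sum(edge_weight, dst, num_segments=num_nodes)
    deg_inv_sqrt = jnp.where(deg > 0, 1.0 / jnp.sqrt(deg), 0.0)
    return edge_index, deg_inv_sqrt[src] * edge_weight * deg_inv_sqrt[dst]


edge_index = jnp.array([[0, 1], [1, 0]])  # a single undirected edge 0 - 1
_, w = gcn_norm_sketch(edge_index, num_nodes=2)                             # [0.5, 0.5, 0.5, 0.5]
_, w_improved = gcn_norm_sketch(edge_index, num_nodes=2, improved=True)    # [1/3, 1/3, 2/3, 2/3]
```

With improved=True each node's degree rises from 2 to 3, and the self-loop carries twice the weight of the neighbor edge.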

reset_cache()

Reset the cached edge weights.

Call this when the graph structure changes.

GATConv

class GATConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The graph attentional operator from the “Graph Attention Networks” paper.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}_t\mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k \right)\right)}.\]

If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}.\]

If the graph is not bipartite, \(\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}\).
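The coefficients \(\alpha_{i,j}\) above are a softmax over the incoming edges of each target node, self-loop included. The plain-JAX sketch below computes them for a single head on a non-bipartite graph without edge features. segment_softmax and gat_attention are hypothetical helpers, and the random weights stand in for learned parameters.

```python
import jax
import jax.numpy as jnp


def segment_softmax(scores, segment_ids, num_segments):
    # Softmax over all entries that share a segment id (here: the same target node)
    seg_max = jax.ops.segment_max(scores, segment_ids, num_segments=num_segments)
    exp = jnp.exp(scores - seg_max[segment_ids])  # subtract the max for stability
    denom = jax.ops.segment_sum(exp, segment_ids, num_segments=num_segments)
    return exp / denom[segment_ids]


def gat_attention(x, edge_index, theta, a_s, a_t, negative_slope=0.2):
    """Single-head GAT on a non-bipartite graph, so Theta_s = Theta_t = theta."""
    num_nodes = x.shape[0]
    loops = jnp.arange(num_nodes)
    src = jnp.concatenate([edge_index[0], loops])  # j, including i itself
    dst = jnp.concatenate([edge_index[1], loops])  # i
    h = x @ theta                                  # Theta x for every node
    # a_s^T Theta x_i + a_t^T Theta x_j, then LeakyReLU
    score = jax.nn.leaky_relu(h[dst] @ a_s + h[src] @ a_t, negative_slope)
    alpha = segment_softmax(score, dst, num_nodes)
    out = jax.ops.segment_sum(alpha[:, None] * h[src], dst, num_segments=num_nodes)
    return out, alpha, dst


k_x, k_theta, k_s, k_t = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (4, 3))
theta = jax.random.normal(k_theta, (3, 2))
a_s, a_t = jax.random.normal(k_s, (2,)), jax.random.normal(k_t, (2,))
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])
out, alpha, dst = gat_attention(x, edge_index, theta, a_s, a_t)
```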

Parameters:
  • in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_features (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • fill_value (float or str, optional) – The way to generate edge features of self-loops (in case edge_dim != None). (default: "mean")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • residual (bool, optional) – If set to True, the layer will add a learnable skip-connection. (default: False)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.

Graph Attention Network layer from Veličković et al. (2018).

Attention Mechanism:

\[ \begin{align}\begin{aligned}\alpha_{ij} = \text{softmax}_j(e_{ij})\\e_{ij} = \text{LeakyReLU}(a^T [W h_i || W h_j])\end{aligned}\end{align} \]

Multi-head Attention:

  • Multiple attention heads compute independent attention weights

  • Outputs can be concatenated or averaged

Example:

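from jraphx.nn.conv import GATConv
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle

conv = GATConv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,  # Concatenate head outputs
    dropout=0.6,
    add_self_loops=True,
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)
# Output shape: [num_nodes, heads * out_features] if concat=True, here [4, 256]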

GATv2Conv

class GATv2Conv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GAT layer. Since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node. In contrast, in GATv2, every node can attend to any other node.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k \right)\right)}.\]

If the graph has multi-dimensional edge features \(\mathbf{e}_{i,j}\), the attention coefficients \(\alpha_{i,j}\) are computed as

\[\alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j + \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k + \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}.\]
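The difference from GAT lies in where \(\mathbf{a}\) sits relative to the nonlinearity. The sketch below scores every query against the same set of candidate neighbors with both formulas. It is plain JAX with a single head and no edge features, and the helper names are hypothetical. Under GAT the best-scoring candidate comes out identical for every query, which is the static attention problem. The GATv2 score has no such constraint.

```python
import jax
import jax.numpy as jnp


def gat_score(x_i, x_j, theta, a_s, a_t, negative_slope=0.2):
    # GAT: a^T is applied before LeakyReLU, so each term sees only one node
    return jax.nn.leaky_relu((x_i @ theta) @ a_s + (x_j @ theta) @ a_t, negative_slope)


def gatv2_score(x_i, x_j, theta_s, theta_t, a, negative_slope=0.2):
    # GATv2: LeakyReLU acts on the sum of both projections, then a^T is applied
    return jax.nn.leaky_relu(x_i @ theta_s + x_j @ theta_t, negative_slope) @ a


keys = jax.random.split(jax.random.PRNGKey(0), 7)
queries = jax.random.normal(keys[0], (5, 3))   # 5 target nodes x_i
cands = jax.random.normal(keys[1], (7, 3))     # 7 candidate neighbors x_j
theta, theta_t = jax.random.normal(keys[2], (3, 4)), jax.random.normal(keys[3], (3, 4))
a_s, a_t, a = (jax.random.normal(k, (4,)) for k in keys[4:])

# Broadcast to a [5, 7] matrix of scores for every (query, candidate) pair
gat = gat_score(queries[:, None, :], cands[None, :, :], theta, a_s, a_t)
gatv2 = gatv2_score(queries[:, None, :], cands[None, :, :], theta, theta_t, a)

# LeakyReLU is monotone and the x_i term is constant over j, so under GAT
# every query ranks the candidates in the same order
print(jnp.argmax(gat, axis=1))
print(jnp.argmax(gatv2, axis=1))
```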
Parameters:
  • in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_features (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)

  • fill_value (float, optional) – The way to generate edge features of self-loops (in case edge_dim != None). (default: 0.0)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • share_weights (bool, optional) – If set to True, the same matrix will be applied to the source and the target node of every edge. (default: False)

  • residual (bool, optional) – If set to True, the layer will add a learnable skip-connection. (default: False)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.

Improved Graph Attention Network layer from Brody et al. (2022).

Key Improvements over GAT:

  • Dynamic attention: Attention weights depend on both query and key node features

  • More expressive: Can learn more complex attention patterns

  • Better performance: Often outperforms original GAT

Example:

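from jraphx.nn.conv import GATv2Conv
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle
edge_attr = jnp.ones((4, 8))                           # one 8-dim feature per edge

conv = GATv2Conv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,
    dropout=0.6,
    edge_dim=8,  # Optional edge features
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index, edge_attr=edge_attr)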

SAGEConv

class SAGEConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper.

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j\]

If project = True, then \(\mathbf{x}_j\) will first get projected via

\[\mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b})\]

as described in Eq. (3) of the paper.
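Below is a plain-JAX sketch of the mean aggregator with the root term \(\mathbf{W}_1 \mathbf{x}_i\), written in row-vector form (x @ W). It leaves out project=True and the other aggregators, and it assumes that nodes without neighbors aggregate to zero. sage_mean is a hypothetical helper, not SAGEConv itself.

```python
import jax
import jax.numpy as jnp


def sage_mean(x, edge_index, w1, w2, bias, normalize=False):
    """x'_i = W_1 x_i + W_2 mean_{j in N(i)} x_j + b, in row-vector form (x @ W)."""
    num_nodes = x.shape[0]
    src, dst = edge_index[0], edge_index[1]
    total = jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)
    count = jax.ops.segment_sum(jnp.ones_like(dst, dtype=x.dtype), dst, num_segments=num_nodes)
    mean = total / jnp.maximum(count, 1.0)[:, None]  # nodes without neighbors get 0
    out = x @ w1 + mean @ w2 + bias
    if normalize:
        # Optional l2 normalization of every output row (cf. the normalize flag)
        out = out / jnp.maximum(jnp.linalg.norm(out, axis=-1, keepdims=True), 1e-12)
    return out


x = jnp.array([[1.0], [2.0], [3.0]])
edge_index = jnp.array([[0, 1], [2, 2]])  # nodes 0 and 1 both send to node 2
w1, w2, b = jnp.array([[1.0]]), jnp.array([[10.0]]), jnp.zeros(1)
out = sage_mean(x, edge_index, w1, w2, b)                          # [[1.], [2.], [18.]]
out_normalized = sage_mean(x, edge_index, w1, w2, b, normalize=True)
```

Node 2 averages its neighbors to 1.5, so its output is 3 + 10 · 1.5 = 18. Nodes 0 and 1 keep only their root term.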

Parameters:
  • in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_features (int) – Size of each output sample.

  • aggr (str, optional) – The aggregation scheme to use. Can be "mean", "max", "lstm", or "gcn". (default: "mean")

  • normalize (bool, optional) – If set to True, output features will be \(\ell_2\)-normalized, i.e., \(\frac{\mathbf{x}^{\prime}_i} {\| \mathbf{x}^{\prime}_i \|_2}\). (default: False)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Shapes:
  • inputs: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • outputs: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V_t}|, F_{out})\) if bipartite

GraphSAGE layer from Hamilton et al. (2017).

Aggregation Options:

  • mean: Average neighbor features

  • max: Element-wise maximum

  • lstm: LSTM aggregation over neighbors

Example:

from jraphx.nn.conv import SAGEConv
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle

# Mean aggregation (most common)
conv = SAGEConv(
    in_features=16,
    out_features=32,
    aggr='mean',
    normalize=True,
    rngs=nnx.Rngs(0)
)

# LSTM aggregation
conv_lstm = SAGEConv(
    in_features=16,
    out_features=32,
    aggr='lstm',
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index)

message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) → Array

Construct messages from source nodes.

Parameters:
  • x_j (Array) – Source node features [num_edges, out_features]

  • x_i (Optional[Array], default: None) – Target node features (not used)

  • edge_attr (Optional[Array], default: None) – Edge features (not used)

Returns:

Array – Messages [num_edges, out_features]

aggregate(messages: Array, index: Array, dim_size: int | None = None) → Array

Aggregate messages based on the specified method.

Parameters:
  • messages (Array) – Messages to aggregate [num_edges, out_features]

  • index (Array) – Target node indices [num_edges]

  • dim_size (Optional[int], default: None) – Number of target nodes

Returns:

Array – Aggregated messages [num_nodes, out_features]

GINConv

class GINConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper.

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

or

\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
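The two GIN formulas compute the same pre-MLP quantity. The plain-JAX sketch below evaluates the node-wise sum and the matrix form \((\mathbf{A} + (1 + \epsilon)\mathbf{I})\mathbf{X}\) on a small graph and checks that they match. gin_aggregate is hypothetical and leaves out \(h_{\mathbf{\Theta}}\); in the layer, its result would be passed to the MLP.

```python
import jax
import jax.numpy as jnp


def gin_aggregate(x, edge_index, eps):
    # (1 + eps) * x_i + sum_{j in N(i)} x_j, before the MLP h_Theta is applied
    src, dst = edge_index[0], edge_index[1]
    neighbor_sum = jax.ops.segment_sum(x[src], dst, num_segments=x.shape[0])
    return (1.0 + eps) * x + neighbor_sum


x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = jnp.array([[0, 1, 2, 2], [2, 2, 0, 1]])
eps = 0.5
sparse_form = gin_aggregate(x, edge_index, eps)

# Matrix form: (A + (1 + eps) I) X, with A[i, j] = 1 for an edge j -> i
adj = jnp.zeros((3, 3)).at[edge_index[1], edge_index[0]].set(1.0)
dense_form = (adj + (1.0 + eps) * jnp.eye(3)) @ x
```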

Parameters:
  • nn (Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_features] to shape [-1, out_features], e.g., defined by MLP.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

Graph Isomorphism Network layer from Xu et al. (2019).

Key Features:

  • As expressive as the 1-WL graph isomorphism test, the upper bound for message-passing GNNs

  • Uses MLPs for transformation

  • Learnable or fixed epsilon parameter

Example:

from jraphx.nn.conv import GINConv
from jraphx.nn.models import MLP
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle

# Create MLP for GIN
mlp = MLP(
    channel_list=[16, 32, 32],
    norm="batch_norm",
    act="relu",
    rngs=nnx.Rngs(0)
)

conv = GINConv(
    nn=mlp,
    eps=0.0,
    train_eps=True  # Learn epsilon
)

out = conv(x, edge_index)

message(x_j: Array, x_i: jax.Array | None = None, edge_attr: jax.Array | None = None) → Array

Construct messages from source nodes.

Parameters:
  • x_j (Array) – Source node features [num_edges, in_features]

  • x_i (Optional[Array], default: None) – Target node features (not used)

  • edge_attr (Optional[Array], default: None) – Edge features (not used)

Returns:

Array – Messages [num_edges, in_features]

EdgeConv

class EdgeConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP. The sum stands for the configured aggregation, which is "max" by default (see aggr).
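Below is a plain-JAX sketch of the EdgeConv message with max aggregation. h stands for any callable that maps 2F to F_out features; the usage below passes the identity so that the edge features stay visible. edge_conv_max is a hypothetical helper, and giving zeros to nodes without incoming edges is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def edge_conv_max(x, edge_index, h):
    """x'_i = max_{j in N(i)} h([x_i || x_j - x_i]) with element-wise max aggregation."""
    num_nodes = x.shape[0]
    src, dst = edge_index[0], edge_index[1]
    x_i, x_j = x[dst], x[src]
    edge_feat = jnp.concatenate([x_i, x_j - x_i], axis=-1)  # [num_edges, 2F]
    messages = h(edge_feat)                                  # [num_edges, F_out]
    out = jax.ops.segment_max(messages, dst, num_segments=num_nodes)
    # Nodes without incoming edges have an empty max; set their output to 0
    has_edge = jax.ops.segment_sum(jnp.ones_like(dst), dst, num_segments=num_nodes) > 0
    return jnp.where(has_edge[:, None], out, 0.0)


x = jnp.array([[0.0], [1.0], [3.0]])
edge_index = jnp.array([[0, 1, 2], [2, 2, 0]])  # 0->2, 1->2, 2->0
out = edge_conv_max(x, edge_index, h=lambda e: e)  # [[0., 3.], [0., 0.], [3., -2.]]
```

Node 2 receives [3, -3] from node 0 and [3, -2] from node 1, and the element-wise max keeps [3, -2].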

Parameters:
  • nn (Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps pair-wise concatenated node features x of shape [-1, 2 * in_features] to shape [-1, out_features], e.g., defined by MLP.

  • aggr (str, optional) – The aggregation scheme to use ("add", "mean", "max"). (default: "max")

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

Dynamic edge convolution from Wang et al. (2019).

Dynamic Graph Construction:

  • Works on any given edge_index, e.g., a k-nearest-neighbor graph recomputed from the current features (see DynamicEdgeConv)

  • Suitable for point cloud processing

  • Edge features computed from node pairs

Example:

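from jraphx.nn.conv import EdgeConv
from jraphx.nn.models import MLP
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle

# MLP processes edge features [x_i || x_j - x_i]: 2 * 16 = 32 input channels
mlp = MLP(
    channel_list=[32, 64, 64],
    rngs=nnx.Rngs(0)
)

conv = EdgeConv(nn=mlp, aggr='max')
out = conv(x, edge_index)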

message(x_j: Array, x_i: Array, edge_attr: jax.Array | None = None) → Array

Compute messages \(h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, \mathbf{x}_j - \mathbf{x}_i)\) from the concatenated edge features.

Return type:

Array

DynamicEdgeConv

class DynamicEdgeConv(*args: Any, **kwargs: Any)

Bases: Module

Dynamic Edge Convolution layer with k-NN graph construction.

This is a simplified version of PyTorch Geometric’s DynamicEdgeConv that requires pre-computed k-NN indices. Unlike PyG’s version, which automatically computes k-nearest neighbors using torch-cluster, this implementation expects the k-NN indices to be provided as input.

For true dynamic graph construction, you would need to:

  1. Compute k-NN indices from node features using a JAX k-NN implementation

  2. Pass these indices to this layer via the knn_indices parameter

PyG equivalent: Uses torch_cluster.knn() for automatic k-NN computation.
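For moderately sized point clouds, step 1 can be done directly in JAX. The sketch below builds the full pairwise distance matrix, which takes O(N^2) memory, and picks the k closest points per row with jax.lax.top_k. Several details are assumptions: the [num_nodes, k] index layout, excluding each point from its own neighbor list, and the edge_index conversion. Check which format DynamicEdgeConv's knn_indices expects before using the result. Both helpers are hypothetical.

```python
import jax
import jax.numpy as jnp


def knn_indices_sketch(x, k, include_self=False):
    """Indices of the k nearest neighbors of every point, shape [num_points, k]."""
    n = x.shape[0]
    sq = jnp.sum(x * x, axis=-1)
    # Pairwise squared Euclidean distances: |a|^2 - 2 a.b + |b|^2
    dist = sq[:, None] - 2.0 * (x @ x.T) + sq[None, :]
    if not include_self:
        diag = jnp.arange(n)
        dist = dist.at[diag, diag].set(jnp.inf)  # never pick the point itself
    # top_k selects the largest values, so negate to get the smallest distances
    _, idx = jax.lax.top_k(-dist, k)
    return idx


def knn_to_edge_index(idx):
    """Edges neighbor -> center: row 0 = source j, row 1 = target i."""
    n, k = idx.shape
    return jnp.stack([idx.reshape(-1), jnp.repeat(jnp.arange(n), k)])


x = jnp.array([[0.0], [1.0], [3.0], [10.0]])  # four points on a line
idx = knn_indices_sketch(x, k=2)              # [[1, 2], [0, 2], [1, 0], [2, 1]]
edge_index = knn_to_edge_index(idx)
```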

Parameters:
  • nn (Module) – Neural network for edge features

  • k (int) – Number of nearest neighbors

  • aggr (str, default: "max") – Aggregation method ("add", "mean", "max").

Dynamic edge convolution from Wang et al. (2019).

JraphX vs PyTorch Geometric:

  • PyG: Automatically computes k-NN using torch_cluster.knn()

  • JraphX: Requires pre-computed k-NN indices (simplified version)

Limitations:

  • No automatic k-NN computation from node features

  • Requires an external k-NN routine (e.g., scikit-learn, FAISS, or a custom JAX implementation)

  • k-NN indices must be provided as input

Example:

from jraphx.nn.conv import DynamicEdgeConv
from jraphx.nn.models import MLP
import jax
import jax.numpy as jnp
import flax.nnx as nnx

# Point cloud: 100 points in 3-D
x = jax.random.normal(jax.random.PRNGKey(0), (100, 3))

# Create MLP for edge processing [x_i || x_j - x_i]
mlp = MLP(
    channel_list=[6, 64, 128],  # Input: 2*3=6 for 3D points
    rngs=nnx.Rngs(0)
)

conv = DynamicEdgeConv(nn=mlp, k=6, aggr='max')

# Pre-compute k-NN indices (6 nearest neighbors).
# compute_knn_indices is a placeholder for your own k-NN routine (not part of jraphx),
# e.g. built on sklearn.neighbors.NearestNeighbors or similar
knn_indices = compute_knn_indices(x, k=6)

out = conv(x, knn_indices=knn_indices)

TransformerConv

class TransformerConv(*args: Any, **kwargs: Any)

Bases: MessagePassing

The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper.

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed via multi-head dot product attention:

\[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right)\]
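The plain-JAX sketch below implements the formula above for a single head. Queries, keys, and values are linear projections, scores are scaled by \(\sqrt{d}\), and the softmax runs over the incoming edges of each target node. It ignores edge features, dropout, beta, and multiple heads. transformer_conv_single_head is a hypothetical helper, and the random weights stand in for learned parameters.

```python
import jax
import jax.numpy as jnp


def transformer_conv_single_head(x, edge_index, w1, w2, w3, w4):
    """x'_i = W1 x_i + sum_j alpha_ij W2 x_j, with scaled dot-product alpha (one head)."""
    num_nodes = x.shape[0]
    src, dst = edge_index[0], edge_index[1]
    query = (x @ w3)[dst]   # W3 x_i, one row per edge
    key = (x @ w4)[src]     # W4 x_j
    value = (x @ w2)[src]   # W2 x_j
    score = jnp.sum(query * key, axis=-1) / jnp.sqrt(float(query.shape[-1]))
    # Numerically stable softmax over the incoming edges of each target node
    score_max = jax.ops.segment_max(score, dst, num_segments=num_nodes)
    exp = jnp.exp(score - score_max[dst])
    alpha = exp / jax.ops.segment_sum(exp, dst, num_segments=num_nodes)[dst]
    aggr = jax.ops.segment_sum(alpha[:, None] * value, dst, num_segments=num_nodes)
    return x @ w1 + aggr, alpha


keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (4, 3))
w1, w2, w3, w4 = (jax.random.normal(k, (3, 2)) for k in keys[1:])
edge_index = jnp.array([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
out, alpha = transformer_conv_single_head(x, edge_index, w1, w2, w3, w4)
```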
Parameters:
  • in_features (int or tuple) – Size of each input sample, or tuple for bipartite graphs. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_features (int) – Size of each output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • beta (bool, optional) –

    If set, will combine aggregation and skip information via

    \[\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}\]

    with \(\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [\mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i])\). (default: False)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • edge_dim (int, optional) –

    Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to final values after the same linear transformation. The model is:

    \[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right),\]

    where the attention coefficients \(\alpha_{i,j}\) are now computed via:

    \[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right)\]

    (default: None)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • rngs (Rngs, default: None) – Random number generators for initialization.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H * F_{out})\) where \(H\) is the number of heads.

Graph Transformer layer from Shi et al. (2021).

Multi-head Attention:

  • Efficient QKV projection using single linear layer

  • Scaled dot-product attention

  • Optional edge feature incorporation

  • Beta gating mechanism for skip connections
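The beta gate from the beta parameter takes only a few lines. The plain-JAX sketch below assumes that the root term \(\mathbf{W}_1 \mathbf{x}_i\) and the aggregate \(\mathbf{m}_i\) are already computed and that \(\mathbf{w}_5\) is a (3C, 1) weight. beta_gate is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def beta_gate(root, m, w5):
    """beta_i = sigmoid(w5^T [r_i, m_i, r_i - m_i]);  x'_i = beta_i r_i + (1 - beta_i) m_i."""
    beta = jax.nn.sigmoid(jnp.concatenate([root, m, root - m], axis=-1) @ w5)  # [N, 1]
    return beta * root + (1.0 - beta) * m


root = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # W1 x_i
m = jnp.array([[5.0, 6.0], [7.0, 8.0]])     # aggregated messages m_i
w5 = jnp.zeros((6, 1))                      # zero weights give beta = 0.5 everywhere
out = beta_gate(root, m, w5)                # the element-wise average of root and m
```

A learned \(\mathbf{w}_5\) lets each node choose how much it relies on its own features versus its neighborhood.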

Example:

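from jraphx.nn.conv import TransformerConv
import flax.nnx as nnx
import jax.numpy as jnp

x = jnp.ones((4, 16))                                  # 4 nodes, 16 features
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed 4-cycle
edge_attr = jnp.ones((4, 8))                           # one 8-dim feature per edge

conv = TransformerConv(
    in_features=16,
    out_features=32,
    heads=8,
    concat=True,
    dropout=0.1,
    edge_dim=8,  # Optional edge features
    beta=True,  # Gating mechanism
    root_weight=True,  # Skip connection
    rngs=nnx.Rngs(0)
)

out = conv(x, edge_index, edge_attr=edge_attr)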

message(query_i: Array, key_j: Array, value_j: Array, edge_attr: jax.Array | None = None, index: Array = None, ptr: jax.Array | None = None, size_i: int | None = None) → Array

Compute messages with attention weights.

Parameters:
  • query_i (Array) – Query features of target nodes [E, H*C]

  • key_j (Array) – Key features of source nodes [E, H*C]

  • value_j (Array) – Value features of source nodes [E, H*C]

  • edge_attr (Optional[Array], default: None) – Edge features [E, edge_dim]

  • index (Array, default: None) – Target node indices for edges [E]

  • ptr (Optional[Array], default: None) – Batch pointers (unused)

  • size_i (Optional[int], default: None) – Number of target nodes

  • key_dropout – Random key for dropout

Returns:

Array – Weighted messages [E, H*C]

Layer Selection Guide

Choosing the Right Layer

Layer Comparison

Layer           | Complexity | Expressiveness | Memory Usage | Best For
GCNConv         | Low        | Medium         | Low          | Citation networks
GATConv         | Medium     | High           | Medium       | Heterophilic graphs
GATv2Conv       | Medium     | Higher         | Medium       | Complex attention patterns
SAGEConv        | Low-Medium | Medium         | Low-Medium   | Large-scale graphs
GINConv         | Medium     | Highest        | Medium       | Graph classification
EdgeConv        | High       | High           | High         | Point clouds
DynamicEdgeConv | High       | High           | High         | Point clouds (k-NN)
TransformerConv | High       | Highest        | High         | Complex relationships

Performance Tips

Batch Processing:

from jraphx.data import Batch

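# Batch multiple graphs (graph1, graph2, graph3 are individual graph objects)
# into one disconnected graph so that a single call processes all of them
batch = Batch.from_data_list([graph1, graph2, graph3])
out = conv(batch.x, batch.edge_index)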
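Batching of this kind usually stores the graphs as one large disconnected graph. The sketch below shows that bookkeeping in plain JAX. It assumes Batch.from_data_list follows the PyTorch Geometric convention:

  • node features are concatenated
  • each graph's node ids are shifted by the number of nodes that precede it
  • a batch vector maps every node to its graph

batch_graphs is a hypothetical helper, not the jraphx implementation.

```python
import jax.numpy as jnp


def batch_graphs(xs, edge_indices):
    """Merge graphs into one disconnected graph (hypothetical helper)."""
    offsets, total = [], 0
    for x_g in xs:
        offsets.append(total)
        total += x_g.shape[0]
    # Node features are stacked; each graph's node ids are shifted by its offset
    x = jnp.concatenate(xs, axis=0)
    edge_index = jnp.concatenate([ei + off for ei, off in zip(edge_indices, offsets)], axis=1)
    # batch[n] is the index of the graph that node n came from
    batch = jnp.concatenate([jnp.full(x_g.shape[0], g) for g, x_g in enumerate(xs)])
    return x, edge_index, batch


x, edge_index, batch = batch_graphs(
    [jnp.ones((2, 3)), jnp.ones((3, 3))],
    [jnp.array([[0], [1]]), jnp.array([[0, 1], [1, 2]])],
)
```

Because no edges cross between graphs, one convolution over the merged graph gives the same per-node results as running it on each graph separately.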

JIT Compilation:

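import flax.nnx as nnx

# JIT compile the forward pass. Passing the module as an argument lets nnx.jit
# track its parameters and state instead of baking them in as constants.
# A new graph size (num_nodes or num_edges) triggers recompilation.
@nnx.jit
def forward(model, x, edge_index):
    return model(x, edge_index)

out = forward(conv, x, edge_index)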

Memory Efficiency:

  • Use concat=False in multi-head attention layers: averaging the heads shrinks each node's output from H * F_out to F_out features

  • Consider aggr='mean' over aggr='lstm' for large graphs

  • Override message_and_aggregate() with a fused implementation where the layer allows it, so that per-edge messages are not materialized

Edge Features

GATConv, GATv2Conv, and TransformerConv accept edge features when constructed with edge_dim:

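from jraphx.nn.conv import GATv2Conv, TransformerConv
import flax.nnx as nnx

# GATv2 with edge features (edge_features: [num_edges, 4])
conv = GATv2Conv(16, 32, heads=8, edge_dim=4, rngs=nnx.Rngs(0))
out = conv(x, edge_index, edge_attr=edge_features)

# TransformerConv with edge features
conv = TransformerConv(16, 32, heads=8, edge_dim=4, rngs=nnx.Rngs(0))
out = conv(x, edge_index, edge_attr=edge_features)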

Advanced Usage

Custom Aggregation

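import jax
import jax.numpy as jnp
from jraphx.nn.conv import MessagePassing

class CustomConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def aggregate(self, messages, index, dim_size=None):
        # Override: mean of the incoming messages, built from two segment sums.
        # Pass dim_size under jax.jit; index.max() is not static there.
        num_nodes = dim_size if dim_size is not None else int(index.max()) + 1
        total = jax.ops.segment_sum(messages, index, num_segments=num_nodes)
        count = jax.ops.segment_sum(jnp.ones_like(index, dtype=messages.dtype), index, num_segments=num_nodes)
        return total / jnp.maximum(count, 1.0)[:, None]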

Heterogeneous Graphs

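import flax.nnx as nnx
from jraphx.nn.conv import GCNConv, SAGEConv

# Different edge types
edge_index_1 = ...  # Type 1 edges
edge_index_2 = ...  # Type 2 edges

# Use a separate convolution per edge type
conv1 = GCNConv(16, 32, rngs=nnx.Rngs(0))
conv2 = SAGEConv(16, 32, rngs=nnx.Rngs(1))

out1 = conv1(x, edge_index_1)
out2 = conv2(x, edge_index_2)
out = out1 + out2  # Combine (both are [num_nodes, 32])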