jraphx.nn.models

This module contains pre-built GNN model architectures that can be used out-of-the-box for common graph learning tasks.

Pre-built GNN Models

These models provide complete architectures with multiple layers, normalization, dropout, and optional features like JumpingKnowledge connections.

GCN

class GCN(*args: Any, **kwargs: Any)[source]

Bases: BasicGNN

Graph Convolutional Network.

From “Semi-supervised Classification with Graph Convolutional Networks” https://arxiv.org/abs/1609.02907

Uses GCNConv layers for message passing.

Parameters:
  • in_features (int) – Size of input features

  • hidden_features (int) – Size of hidden layers

  • num_layers (int) – Number of GCN layers

  • out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)

  • dropout_rate (float, default: 0.0) – Dropout probability

  • act (Optional[Callable], default: None) – Activation function

  • act_first (bool, default: False) – If True, apply activation before normalization

  • norm (Optional[str], default: None) – Normalization type (‘batch_norm’, ‘layer_norm’, ‘graph_norm’, None)

  • jk (Optional[str], default: None) – Jumping Knowledge mode (‘last’, ‘cat’, ‘max’, ‘lstm’, None)

  • residual (bool, default: False) – Whether to use residual connections

  • improved (bool, optional) – Use improved GCN normalization

  • cached (bool, optional) – Cache normalized edge weights for static graphs

  • add_self_loops (bool, optional) – Add self-loops to the graph

  • normalize (bool, optional) – Apply symmetric normalization

  • rngs (Optional[Rngs], default: None) – Random number generators

Graph Convolutional Network model with configurable layers and normalization.

Example:

from jraphx.nn.models import GCN
import flax.nnx as nnx

model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    dropout_rate=0.5,
    norm="layer_norm",  # Options: "batch_norm", "layer_norm", "graph_norm", None
    jk=None,  # Options: "last", "cat", "max", "lstm", None
    rngs=nnx.Rngs(0)
)

# Forward pass
out = model(x, edge_index, batch=batch)
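
The forward pass above assumes that node features x, an edge_index array, and an optional batch assignment vector already exist. A minimal toy setup, assuming the usual [2, num_edges] edge_index layout:

import jax.numpy as jnp

# Toy graph: 4 nodes with 16-dimensional features and 4 directed edges
x = jnp.ones((4, 16))
edge_index = jnp.array([[0, 1, 2, 3],
                        [1, 2, 3, 0]])
batch = jnp.zeros(4, dtype=jnp.int32)  # all four nodes belong to graph 0

# The forward pass above then yields per-node outputs of shape [4, 10]
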
init_conv(in_features: int, out_features: int, rngs: flax.nnx.rnglib.Rngs | None = None, **kwargs) MessagePassing[source]

Initialize GCNConv layer.

Return type:

MessagePassing

GAT

class GAT(*args: Any, **kwargs: Any)[source]

Bases: BasicGNN

Graph Attention Network.

From “Graph Attention Networks” https://arxiv.org/abs/1710.10903 or “How Attentive are Graph Attention Networks?” https://arxiv.org/abs/2105.14491

Uses GATConv or GATv2Conv layers for message passing.

Parameters:
  • in_features (int) – Size of input features

  • hidden_features (int) – Size of hidden layers (per head if concat=True)

  • num_layers (int) – Number of GAT layers

  • out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)

  • heads (int, default: 1) – Number of attention heads

  • concat (bool, default: True) – Whether to concatenate or average multi-head outputs

  • v2 (bool, default: False) – Use GATv2Conv instead of GATConv

  • dropout_rate (float, default: 0.0) – Dropout probability

  • act (Optional[Callable], default: None) – Activation function

  • act_first (bool, default: False) – If True, apply activation before normalization

  • norm (Optional[str], default: None) – Normalization type (‘batch_norm’, ‘layer_norm’, ‘graph_norm’, None)

  • jk (Optional[str], default: None) – Jumping Knowledge mode (‘last’, ‘cat’, ‘max’, ‘lstm’, None)

  • residual (bool, default: False) – Whether to use residual connections

  • edge_dim (Optional[int], default: None) – Edge feature dimension

  • rngs (Optional[Rngs], default: None) – Random number generators

Graph Attention Network model with multi-head attention and configurable architecture.

Example:

from jraphx.nn.models import GAT
import flax.nnx as nnx

model = GAT(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    heads=8,
    v2=False,  # Use GATv2 if True
    dropout_rate=0.6,
    norm="layer_norm",
    jk="max",  # JumpingKnowledge aggregation
    rngs=nnx.Rngs(0)
)

out = model(x, edge_index, batch=batch)
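
GAT models are usually trained with relatively high dropout (0.6 here), so the model should be switched to deterministic mode for evaluation. A small sketch, assuming the standard flax.nnx train()/eval() helpers:

model.train()  # enable dropout for training steps
train_out = model(x, edge_index, batch=batch)

model.eval()   # disable dropout (and use running statistics if batch_norm is configured)
eval_out = model(x, edge_index, batch=batch)
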
init_conv(in_features: int, out_features: int, rngs: flax.nnx.rnglib.Rngs | None = None, **kwargs) MessagePassing[source]

Initialize GATConv or GATv2Conv layer.

Return type:

MessagePassing

GraphSAGE

class GraphSAGE(*args: Any, **kwargs: Any)[source]

Bases: BasicGNN

GraphSAGE: Inductive Representation Learning on Large Graphs.

From “Inductive Representation Learning on Large Graphs” https://arxiv.org/abs/1706.02216

Uses SAGEConv layers for message passing.

Parameters:
  • in_features (int) – Size of input features

  • hidden_features (int) – Size of hidden layers

  • num_layers (int) – Number of GraphSAGE layers

  • out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)

  • aggr (str, optional) – Aggregation method (‘mean’, ‘max’, ‘lstm’)

  • dropout_rate (float, default: 0.0) – Dropout probability

  • act (Optional[Callable], default: None) – Activation function

  • act_first (bool, default: False) – If True, apply activation before normalization

  • norm (Optional[str], default: None) – Normalization type (‘batch_norm’, ‘layer_norm’, ‘graph_norm’, None)

  • jk (Optional[str], default: None) – Jumping Knowledge mode (‘last’, ‘cat’, ‘max’, ‘lstm’, None)

  • residual (bool, default: False) – Whether to use residual connections

  • normalize (bool, optional) – Whether to L2-normalize output features

  • rngs (Optional[Rngs], default: None) – Random number generators

GraphSAGE model with multiple aggregation options.

Example:

from jraphx.nn.models import GraphSAGE
import flax.nnx as nnx

model = GraphSAGE(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    aggr="mean",  # Options: "mean", "max", "lstm"
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",  # Concatenate all layer outputs
    rngs=nnx.Rngs(0)
)

out = model(x, edge_index, batch=batch)
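
Because GraphSAGE is inductive, the trained parameters can be applied to graphs that were never seen during training, as long as the input feature size (16 here) matches. A sketch with a hypothetical larger graph, assuming the batch argument is optional for a single graph:

import jax.numpy as jnp

# A new, larger graph with the same 16-dimensional node features
x_new = jnp.ones((100, 16))
edge_index_new = jnp.stack([jnp.arange(100), (jnp.arange(100) + 1) % 100])  # simple ring graph

out_new = model(x_new, edge_index_new)  # per-node predictions for the unseen graph
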
init_conv(in_features: int, out_features: int, rngs: flax.nnx.rnglib.Rngs | None = None, **kwargs) MessagePassing[source]

Initialize SAGEConv layer.

Return type:

MessagePassing

GIN

class GIN(*args: Any, **kwargs: Any)[source]

Bases: BasicGNN

Graph Isomorphism Network.

From “How Powerful are Graph Neural Networks?” https://arxiv.org/abs/1810.00826

Uses GINConv layers with MLP aggregation for message passing.

Parameters:
  • in_features (int) – Size of input features

  • hidden_features (int) – Size of hidden layers

  • num_layers (int) – Number of GIN layers

  • out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)

  • dropout_rate (float, default: 0.0) – Dropout probability

  • act (Optional[Callable], default: None) – Activation function

  • act_first (bool, default: False) – If True, apply activation before normalization

  • norm (Optional[str], default: None) – Normalization type (‘batch_norm’, ‘layer_norm’, ‘graph_norm’, None)

  • jk (Optional[str], default: None) – Jumping Knowledge mode (‘last’, ‘cat’, ‘max’, ‘lstm’, None)

  • residual (bool, default: False) – Whether to use residual connections

  • train_eps (bool, optional) – Whether to learn the epsilon parameter

  • rngs (Optional[Rngs], default: None) – Random number generators

Graph Isomorphism Network model with MLP transformations.

Example:

from jraphx.nn.models import GIN
import flax.nnx as nnx

model = GIN(
    in_features=16,
    hidden_features=64,
    num_layers=5,
    out_features=10,
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",
    rngs=nnx.Rngs(0)
)

out = model(x, edge_index, batch=batch)
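
For graph-level tasks such as molecular property prediction, the per-node outputs still need to be pooled into one vector per graph. One way to do this directly in JAX with the batch assignment vector is segment summation (plain jax.ops.segment_sum, not a jraphx-specific pooling helper):

import jax

num_graphs = 2  # number of distinct graphs encoded in `batch`
graph_out = jax.ops.segment_sum(out, batch, num_segments=num_graphs)
# graph_out has shape [num_graphs, 10]: one summed embedding per graph
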
init_conv(in_features: int, out_features: int, rngs: flax.nnx.rnglib.Rngs | None = None, **kwargs) MessagePassing[source]

Initialize GINConv layer.

Return type:

MessagePassing

Base Classes

BasicGNN

class BasicGNN(*args: Any, **kwargs: Any)[source]

Bases: Module

An abstract class for implementing basic GNN models.

Parameters:
  • in_features (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • hidden_features (int) – Size of each hidden sample.

  • num_layers (int) – Number of message passing layers.

  • out_features (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_features. (default: None)

  • dropout_rate (float, optional) – Dropout probability. (default: 0.)

  • act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • norm (str, optional) – The normalization function to use ("batch_norm", "layer_norm", "graph_norm"). (default: None)

  • jk (str, optional) – The Jumping Knowledge mode ("last", "cat", "max", "lstm"). (default: None)

  • residual (bool, optional) – Whether to use residual connections between layers. (default: False)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

  • **kwargs (optional) – Additional arguments for the specific convolution layer.

Abstract base class for GNN models. Provides a common interface for building multi-layer GNNs with normalization, dropout, and JumpingKnowledge connections.

Subclassing Example:

from jraphx.nn.models import BasicGNN
from jraphx.nn.conv import MessagePassing

class MyCustomGNN(BasicGNN):
    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Return your custom message passing layer
        return MyCustomConv(in_features, out_features, rngs=rngs, **kwargs)
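
As a more complete sketch, init_conv can simply return one of the built-in layers with custom defaults. The GCNConv import path is an assumption here, based on the GCN model above being documented as built from GCNConv layers:

import flax.nnx as nnx
from jraphx.nn.models import BasicGNN
from jraphx.nn.conv import GCNConv  # assumed export path

class PlainGCN(BasicGNN):
    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Extra constructor kwargs are forwarded to every convolution layer
        return GCNConv(in_features, out_features, rngs=rngs, **kwargs)

model = PlainGCN(
    in_features=16,
    hidden_features=64,
    num_layers=2,
    rngs=nnx.Rngs(0),
)
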
init_conv(in_features: int, out_features: int, rngs: flax.nnx.rnglib.Rngs | None = None, **kwargs) MessagePassing[source]

Initialize convolution layer. To be implemented by subclasses.

Return type:

MessagePassing

Utility Models

MLP

class MLP(*args: Any, **kwargs: Any)[source]

Bases: Module

A Multi-Layer Perceptron (MLP) model.

There exist two ways to instantiate an MLP:

  1. By specifying explicit feature sizes, e.g.,

    mlp = MLP([16, 32, 64, 128], rngs=nnx.Rngs(0))
    

    creates a three-layer MLP with differently sized hidden layers.

  2. By specifying fixed hidden feature sizes over a number of layers, e.g.,

    mlp = MLP(in_features=16, hidden_features=32,
              out_features=128, num_layers=3, rngs=nnx.Rngs(0))
    

    creates a three-layer MLP with equally sized hidden layers.

Parameters:
  • feature_list (List[int] or int, optional) – List of input, intermediate and output features such that len(feature_list) - 1 denotes the number of layers of the MLP (default: None)

  • in_features (int, optional) – Size of each input sample. Will override feature_list. (default: None)

  • hidden_features (int, optional) – Size of each hidden sample. Will override feature_list. (default: None)

  • out_features (int, optional) – Size of each output sample. Will override feature_list. (default: None)

  • num_layers (int, optional) – The number of layers. Will override feature_list. (default: None)

  • dropout_rate (float, optional) – Dropout probability of each hidden embedding. (default: 0.)

  • act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)

  • act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)

  • norm (str or Callable, optional) – The normalization function to use. (default: None)

  • plain_last (bool, optional) – If set to False, will apply non-linearity, batch normalization and dropout to the last layer as well. (default: True)

  • bias (bool, optional) – If set to False, the module will not learn additive biases. (default: True)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Multi-layer perceptron with configurable layers, normalization, and dropout.

Example:

from jraphx.nn.models import MLP
import flax.nnx as nnx
import jax

# Using an explicit feature list
mlp = MLP(
    feature_list=[16, 64, 64, 32, 10],
    norm="layer_norm",
    bias=True,
    dropout_rate=0.5,
    act=jax.nn.relu,
    rngs=nnx.Rngs(0)
)

# Or using in/hidden/out feature sizes
mlp = MLP(
    in_features=16,
    hidden_features=64,
    out_features=10,
    num_layers=3,
    norm="batch_norm",
    dropout_rate=0.5,
    rngs=nnx.Rngs(0)
)

out = mlp(x)
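
The MLP operates row-wise, so any [num_items, in_features] array is a valid input. A toy input for the second configuration above:

import jax.numpy as jnp

x = jnp.ones((32, 16))  # 32 samples, 16 input features each
# mlp(x) then returns an array of shape [32, 10]
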
property in_features: int

Size of each input sample.

Return type:

int

property out_features: int

Size of each output sample.

Return type:

int

property num_layers: int

Number of layers.

Return type:

int

JumpingKnowledge

class JumpingKnowledge(*args: Any, **kwargs: Any)[source]

Bases: Module

The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper.

Jumping knowledge is performed based on either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)},\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right),\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM ("lstm").

Parameters:
  • mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").

  • num_features (int, optional) – The number of features per representation. Needs to be only set for LSTM-style aggregation. (default: None)

  • num_layers (int, optional) – The number of layers to aggregate. Needs to be only set for LSTM-style aggregation. (default: None)

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

JumpingKnowledge layer for aggregating representations from different GNN layers.

Example:

from jraphx.nn.models import JumpingKnowledge
import flax.nnx as nnx

# Concatenation mode
jk = JumpingKnowledge(mode="cat")

# Max pooling mode
jk = JumpingKnowledge(mode="max")

# LSTM aggregation mode (num_features and num_layers are required here)
jk = JumpingKnowledge(
    mode="lstm",
    num_features=64,
    num_layers=3,
    rngs=nnx.Rngs(0)
)

# Aggregate layer outputs
layer_outputs = [layer1_out, layer2_out, layer3_out]
final_out = jk(layer_outputs)
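
For three layer outputs of shape [num_nodes, 64], the "cat" and "max" modes reduce to simple array operations, which makes the resulting shapes easy to check. A sketch of the equivalent jax.numpy computations (not the module's internals):

import jax.numpy as jnp

xs = [layer1_out, layer2_out, layer3_out]          # each of shape [num_nodes, 64]

cat_out = jnp.concatenate(xs, axis=-1)             # "cat": [num_nodes, 192]
max_out = jnp.max(jnp.stack(xs, axis=0), axis=0)   # "max": [num_nodes, 64]
# "lstm" instead learns per-node attention weights over the three
# representations and returns their weighted sum of shape [num_nodes, 64].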

Model Selection Guide

Choosing the Right Model

GCN: Best for citation networks and semi-supervised learning tasks with homophilic graphs.

GAT: Excellent for graphs where edge importance varies. The attention mechanism learns which neighbors are most relevant.

GraphSAGE: Ideal for large-scale graphs and inductive learning scenarios where you need to generalize to unseen nodes.

GIN: Most expressive for distinguishing graph structures. Best for graph-level tasks like molecular property prediction.

Configuration Tips

Number of Layers:
  • 2-3 layers for most node classification tasks

  • 4-5 layers for graph-level tasks

  • Use JumpingKnowledge for deeper networks (see the configuration sketch after these tips)

Normalization:
  • batch_norm: Best for large batches and stable training

  • layer_norm: Works well with smaller batches

  • graph_norm: Specifically designed for graph data

JumpingKnowledge:
  • cat: Preserves all information but increases dimensionality

  • max: Good balance of expressiveness and efficiency

  • lstm: Most flexible but requires more parameters

Dropout:
  • 0.5-0.6 for training stability

  • Higher rates (0.6-0.8) for GAT models

  • Lower rates (0.2-0.5) for deeper models
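
Putting several of these tips together, a deeper node-classification model might combine extra layers with JumpingKnowledge, residual connections, graph normalization, and a lower dropout rate. A sketch only; the specific values are illustrative:

from jraphx.nn.models import GCN
import flax.nnx as nnx

deep_model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=6,          # deeper than the usual 2-3 layers
    out_features=10,
    dropout_rate=0.3,      # lower rate for a deeper model
    norm="graph_norm",
    jk="max",              # JumpingKnowledge helps deeper stacks
    residual=True,
    rngs=nnx.Rngs(0),
)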

Performance Comparison

Model Performance Characteristics

Model       Speed    Memory        Expressiveness   Best For
---------   ------   -----------   --------------   ---------------------
GCN         Fast     Low           Medium           Node classification
GAT         Medium   Medium-High   High             Heterophilic graphs
GraphSAGE   Fast     Low-Medium    Medium           Large-scale graphs
GIN         Fast     Low           Highest          Graph classification