jraphx.nn.models
This module contains pre-built GNN model architectures that can be used out-of-the-box for common graph learning tasks.
Pre-built GNN Models
These models provide complete architectures with multiple layers, normalization, dropout, and optional features like JumpingKnowledge connections.
GCN
- class GCN(*args: Any, **kwargs: Any)[source]
Bases: BasicGNN
Graph Convolutional Network.
From “Semi-supervised Classification with Graph Convolutional Networks” https://arxiv.org/abs/1609.02907
Uses GCNConv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GCN layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - improved – Use improved GCN normalization
  - cached – Cache normalized edge weights for static graphs
  - add_self_loops – Add self-loops to the graph
  - normalize – Apply symmetric normalization
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Convolutional Network model with configurable layers and normalization.
Example:
from jraphx.nn.models import GCN
import flax.nnx as nnx

model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    dropout_rate=0.5,
    norm="layer_norm",  # Options: "batch_norm", "layer_norm", "graph_norm", None
    jk=None,            # Options: "cat", "max", "lstm", None
    rngs=nnx.Rngs(0),
)

# Forward pass
out = model(x, edge_index, batch=batch)
GAT
- class GAT(*args: Any, **kwargs: Any)[source]
Bases: BasicGNN
Graph Attention Network.
From “Graph Attention Networks” https://arxiv.org/abs/1710.10903 or “How Attentive are Graph Attention Networks?” https://arxiv.org/abs/2105.14491
Uses GATConv or GATv2Conv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers (per head if concat=True)
  - num_layers (int) – Number of GAT layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - heads (int, default: 1) – Number of attention heads
  - concat (bool, default: True) – Whether to concatenate or average multi-head outputs
  - v2 (bool, default: False) – Use GATv2Conv instead of GATConv
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - edge_dim (Optional[int], default: None) – Edge feature dimension
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Attention Network model with multi-head attention and configurable architecture.
Example:
from jraphx.nn.models import GAT
import flax.nnx as nnx

model = GAT(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    heads=8,
    v2=False,  # Use GATv2Conv if True
    dropout_rate=0.6,
    norm="layer_norm",
    jk="max",  # JumpingKnowledge aggregation
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
GraphSAGE
- class GraphSAGE(*args: Any, **kwargs: Any)[source]
Bases: BasicGNN
GraphSAGE: Inductive Representation Learning on Large Graphs.
From “Inductive Representation Learning on Large Graphs” https://arxiv.org/abs/1706.02216
Uses SAGEConv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GraphSAGE layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - aggr – Aggregation method ('mean', 'max', 'lstm')
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - normalize – Whether to L2-normalize output features
  - rngs (Optional[Rngs], default: None) – Random number generators
GraphSAGE model with multiple aggregation options.
Example:
from jraphx.nn.models import GraphSAGE
import flax.nnx as nnx

model = GraphSAGE(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    aggr="mean",  # Options: "mean", "max", "lstm"
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",  # Concatenate all layer outputs
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
GIN
- class GIN(*args: Any, **kwargs: Any)[source]
Bases: BasicGNN
Graph Isomorphism Network.
From “How Powerful are Graph Neural Networks?” https://arxiv.org/abs/1810.00826
Uses GINConv layers with MLP aggregation for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GIN layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - train_eps – Whether to learn the epsilon parameter
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Isomorphism Network model with MLP transformations.
Example:
from jraphx.nn.models import GIN
import flax.nnx as nnx

model = GIN(
    in_features=16,
    hidden_features=64,
    num_layers=5,
    out_features=10,
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
Base Classes
BasicGNN
- class BasicGNN(*args: Any, **kwargs: Any)[source]
Bases: Module
An abstract class for implementing basic GNN models.
- Parameters:
  - in_features (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.
  - hidden_features (int) – Size of each hidden sample.
  - num_layers (int) – Number of message passing layers.
  - out_features (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_features. (default: None)
  - dropout_rate (float, optional) – Dropout probability. (default: 0.)
  - act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)
  - act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)
  - norm (str, optional) – The normalization function to use ("batch_norm", "layer_norm", "graph_norm"). (default: None)
  - jk (str, optional) – The Jumping Knowledge mode ("last", "cat", "max", "lstm"). (default: None)
  - residual (bool, optional) – Whether to use residual connections between layers. (default: False)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
  - **kwargs (optional) – Additional arguments for the specific convolution layer.
Abstract base class for GNN models. Provides a common interface for building multi-layer GNNs with normalization, dropout, and JumpingKnowledge connections.
Subclassing Example:
from jraphx.nn.models import BasicGNN
from jraphx.nn.conv import MessagePassing

class MyCustomGNN(BasicGNN):
    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Return your custom message passing layer
        return MyCustomConv(in_features, out_features, rngs=rngs, **kwargs)
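A more concrete (hypothetical) variant of the same hook is sketched below. It assumes GCNConv can be imported from jraphx.nn.conv and accepts (in_features, out_features, rngs=...) like the layers used by the pre-built models above:

from jraphx.nn.models import BasicGNN
from jraphx.nn.conv import GCNConv  # assumed import path
import flax.nnx as nnx

class PlainGCN(BasicGNN):
    """Illustrative BasicGNN subclass that plugs in the built-in GCNConv layer."""

    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Extra keyword arguments passed to the model are forwarded to each layer.
        return GCNConv(in_features, out_features, rngs=rngs, **kwargs)

# Used like the pre-built models above
model = PlainGCN(
    in_features=16,
    hidden_features=64,
    num_layers=2,
    out_features=7,
    rngs=nnx.Rngs(0),
)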
Utility Models
MLP
- class MLP(*args: Any, **kwargs: Any)[source]
Bases: Module
A Multi-Layer Perceptron (MLP) model.
There exist two ways to instantiate an MLP:
- By specifying explicit feature sizes, e.g.,
mlp = MLP([16, 32, 64, 128], rngs=nnx.Rngs(0))
creates a three-layer MLP with differently sized hidden layers.
- By specifying fixed hidden feature sizes over a number of layers, e.g.,
mlp = MLP(in_features=16, hidden_features=32, out_features=128, num_layers=3, rngs=nnx.Rngs(0))
creates a three-layer MLP with equally sized hidden layers.
- Parameters:
  - feature_list (List[int] or int, optional) – List of input, intermediate and output features such that len(feature_list) - 1 denotes the number of layers of the MLP. (default: None)
  - in_features (int, optional) – Size of each input sample. Will override feature_list. (default: None)
  - hidden_features (int, optional) – Size of each hidden sample. Will override feature_list. (default: None)
  - out_features (int, optional) – Size of each output sample. Will override feature_list. (default: None)
  - num_layers (int, optional) – The number of layers. Will override feature_list. (default: None)
  - dropout_rate (float, optional) – Dropout probability of each hidden embedding. (default: 0.)
  - act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)
  - act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)
  - norm (str or Callable, optional) – The normalization function to use. (default: None)
  - plain_last (bool, optional) – If set to False, will apply non-linearity, batch normalization and dropout to the last layer as well. (default: True)
  - bias (bool, optional) – If set to False, the module will not learn additive biases. (default: True)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
Multi-layer perceptron with configurable layers, normalization, and dropout.
Example:
from jraphx.nn.models import MLP
import flax.nnx as nnx
import jax

# Using an explicit feature list
mlp = MLP(
    feature_list=[16, 64, 64, 32, 10],
    norm="layer_norm",
    bias=True,
    dropout_rate=0.5,
    act=jax.nn.relu,
    rngs=nnx.Rngs(0),
)

# Or using in/hidden/out features
mlp = MLP(
    in_features=16,
    hidden_features=64,
    out_features=10,
    num_layers=3,
    norm="batch_norm",
    dropout_rate=0.5,
    rngs=nnx.Rngs(0),
)

out = mlp(x)
JumpingKnowledge
- class JumpingKnowledge(*args: Any, **kwargs: Any)[source]
Bases: Module
The Jumping Knowledge layer aggregation module from the “Representation Learning on Graphs with Jumping Knowledge Networks” paper.
Jumping knowledge is performed based on either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)},\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right),\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM ("lstm").
- Parameters:
  - mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").
  - num_features (int, optional) – The number of features per representation. Needs to be set only for LSTM-style aggregation. (default: None)
  - num_layers (int, optional) – The number of layers to aggregate. Needs to be set only for LSTM-style aggregation. (default: None)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
JumpingKnowledge layer for aggregating representations from different GNN layers.
Example:
from jraphx.nn.models import JumpingKnowledge
import flax.nnx as nnx

# Concatenation mode
jk = JumpingKnowledge(mode="cat")

# Max pooling mode
jk = JumpingKnowledge(mode="max")

# LSTM aggregation mode
jk = JumpingKnowledge(
    mode="lstm",
    num_features=64,
    num_layers=3,
    rngs=nnx.Rngs(0),
)

# Aggregate layer outputs
layer_outputs = [layer1_out, layer2_out, layer3_out]
final_out = jk(layer_outputs)
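As a quick dimensionality check, the sketch below assumes three per-layer outputs of shape [num_nodes, 64] and the list-based call signature shown above: "cat" concatenates along the feature axis, while "max" keeps the per-layer width.

import jax.numpy as jnp
from jraphx.nn.models import JumpingKnowledge

num_nodes = 5
# Three hypothetical layer outputs, each of shape [num_nodes, 64]
layer_outputs = [jnp.full((num_nodes, 64), t) for t in (1.0, 2.0, 3.0)]

jk_cat = JumpingKnowledge(mode="cat")
jk_max = JumpingKnowledge(mode="max")

print(jk_cat(layer_outputs).shape)  # expected (5, 192): feature widths are concatenated
print(jk_max(layer_outputs).shape)  # expected (5, 64): element-wise max keeps the width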
Model Selection Guide
Choosing the Right Model
GCN: Best for citation networks and semi-supervised learning tasks with homophilic graphs.
GAT: Excellent for graphs where edge importance varies. The attention mechanism learns which neighbors are most relevant.
GraphSAGE: Ideal for large-scale graphs and inductive learning scenarios where you need to generalize to unseen nodes.
GIN: Most expressive for distinguishing graph structures. Best for graph-level tasks like molecular property prediction.
Configuration Tips
- Number of Layers:
2-3 layers for most node classification tasks
4-5 layers for graph-level tasks
Use JumpingKnowledge for deeper networks
- Normalization:
batch_norm: Best for large batches and stable training
layer_norm: Works well with smaller batches
graph_norm: Specifically designed for graph data
- JumpingKnowledge:
cat: Preserves all information but increases dimensionality
max: Good balance of expressiveness and efficiency
lstm: Most flexible but requires more parameters
- Dropout:
0.5-0.6 for training stability
Higher rates (0.6-0.8) for GAT models
Lower rates (0.2-0.5) for deeper models
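Pulling these tips together, the sketch below shows two illustrative configurations built from the models documented above. The class arguments mirror the earlier examples; the specific sizes are assumptions to treat as starting points, not tuned settings.

from jraphx.nn.models import GCN, GIN
import flax.nnx as nnx

# Shallow node-classification setup: 2 layers, layer_norm, dropout 0.5
node_model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=2,
    out_features=7,
    dropout_rate=0.5,
    norm="layer_norm",
    rngs=nnx.Rngs(0),
)

# Deeper graph-level setup: 5 layers, batch_norm, JumpingKnowledge over all depths
graph_model = GIN(
    in_features=16,
    hidden_features=64,
    num_layers=5,
    out_features=2,
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",
    rngs=nnx.Rngs(0),
)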
Performance Comparison
| Model | Speed | Memory | Expressiveness | Best For |
|---|---|---|---|---|
| GCN | Fast | Low | Medium | Node classification |
| GAT | Medium | Medium-High | High | Heterophilic graphs |
| GraphSAGE | Fast | Low-Medium | Medium | Large-scale graphs |
| GIN | Fast | Low | Highest | Graph classification |