jraphx.nn.models
This module contains pre-built GNN model architectures that can be used out-of-the-box for common graph learning tasks.
Pre-built GNN Models
These models provide complete architectures with multiple layers, normalization, dropout, and optional features like JumpingKnowledge connections.
GCN
- class GCN(*args: Any, **kwargs: Any)[source]
Bases:
BasicGNN
Graph Convolutional Network.
From “Semi-supervised Classification with Graph Convolutional Networks” https://arxiv.org/abs/1609.02907
Uses GCNConv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GCN layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - improved – Use improved GCN normalization
  - cached – Cache normalized edge weights for static graphs
  - add_self_loops – Add self-loops to the graph
  - normalize – Apply symmetric normalization
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Convolutional Network model with configurable layers and normalization.
Example:
from jraphx.nn.models import GCN
import flax.nnx as nnx

model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    dropout_rate=0.5,
    norm="layer_norm",  # Options: "batch_norm", "layer_norm", "graph_norm", None
    jk=None,            # Options: "cat", "max", "lstm", None
    rngs=nnx.Rngs(0),
)

# Forward pass
out = model(x, edge_index, batch=batch)
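Because GCN (like the other BasicGNN models below) is a regular flax.nnx Module, it can be trained with a standard NNX/Optax loop. The sketch below is a minimal, hedged example: it assumes optax is installed and follows the nnx.Optimizer / nnx.value_and_grad pattern from the Flax NNX documentation, whose exact signatures can vary between Flax versions; x, edge_index and labels are placeholders.

import flax.nnx as nnx
import optax

# Assumes `model` is the GCN built above and `labels` holds integer class ids.
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model, optimizer, x, edge_index, labels):
    def loss_fn(model):
        logits = model(x, edge_index)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, labels
        ).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss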
GAT
- class GAT(*args: Any, **kwargs: Any)[source]
Bases:
BasicGNN
Graph Attention Network.
From “Graph Attention Networks” https://arxiv.org/abs/1710.10903 or “How Attentive are Graph Attention Networks?” https://arxiv.org/abs/2105.14491
Uses GATConv or GATv2Conv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers (per head if concat=True)
  - num_layers (int) – Number of GAT layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - heads (int, default: 1) – Number of attention heads
  - concat (bool, default: True) – Whether to concatenate or average multi-head outputs
  - v2 (bool, default: False) – Use GATv2Conv instead of GATConv
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - edge_dim (Optional[int], default: None) – Edge feature dimension
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Attention Network model with multi-head attention and configurable architecture.
Example:
from jraphx.nn.models import GAT
import flax.nnx as nnx

model = GAT(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    heads=8,
    v2=False,           # Use GATv2 if True
    dropout_rate=0.6,
    norm="layer_norm",
    jk="max",           # JumpingKnowledge aggregation
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
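Since hidden_features is documented above as a per-head size when concat=True, each hidden layer in this configuration produces hidden_features * heads = 512 features per node, while the final projection still maps to out_features. A quick shape check, with x and edge_index as illustrative placeholders and batch omitted for a single graph (an assumption about the call signature):

import jax.numpy as jnp

num_nodes = 100
x = jnp.ones((num_nodes, 16))                      # 16 input features
edge_index = jnp.zeros((2, 300), dtype=jnp.int32)  # placeholder edges

out = model(x, edge_index)
# Hidden layers: 64 features per head * 8 heads = 512 features per node
# (per the hidden_features description above); the last layer maps to out_features.
print(out.shape)  # expected: (100, 10)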
GraphSAGE
- class GraphSAGE(*args: Any, **kwargs: Any)[source]
Bases:
BasicGNN
GraphSAGE: Inductive Representation Learning on Large Graphs.
From “Inductive Representation Learning on Large Graphs” https://arxiv.org/abs/1706.02216
Uses SAGEConv layers for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GraphSAGE layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - aggr – Aggregation method ('mean', 'max', 'lstm')
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - normalize – Whether to L2-normalize output features
  - rngs (Optional[Rngs], default: None) – Random number generators
GraphSAGE model with multiple aggregation options.
Example:
from jraphx.nn.models import GraphSAGE
import flax.nnx as nnx

model = GraphSAGE(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    aggr="mean",        # Options: "mean", "max", "lstm"
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",           # Concatenate all layer outputs
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
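Because GraphSAGE is inductive, the same trained model can be applied to graphs whose nodes it never saw during training. A small sketch with placeholder arrays (shapes are illustrative; batch is omitted for single graphs, which is an assumption about the call signature):

import jax.numpy as jnp

# Graph seen during training
x_train = jnp.ones((200, 16))
edge_index_train = jnp.zeros((2, 800), dtype=jnp.int32)

# A new, unseen graph with a different number of nodes
x_new = jnp.ones((50, 16))
edge_index_new = jnp.zeros((2, 120), dtype=jnp.int32)

emb_train = model(x_train, edge_index_train)  # (200, 10)
emb_new = model(x_new, edge_index_new)        # (50, 10), no retraining needed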
GIN
- class GIN(*args: Any, **kwargs: Any)[source]
Bases:
BasicGNN
Graph Isomorphism Network.
From “How Powerful are Graph Neural Networks?” https://arxiv.org/abs/1810.00826
Uses GINConv layers with MLP aggregation for message passing.
- Parameters:
  - in_features (int) – Size of input features
  - hidden_features (int) – Size of hidden layers
  - num_layers (int) – Number of GIN layers
  - out_features (Optional[int], default: None) – Size of output (if None, uses hidden_features)
  - dropout_rate (float, default: 0.0) – Dropout probability
  - act (Optional[Callable], default: None) – Activation function
  - act_first (bool, default: False) – If True, apply activation before normalization
  - norm (Optional[str], default: None) – Normalization type ('batch_norm', 'layer_norm', None)
  - jk (Optional[str], default: None) – Jumping Knowledge mode ('last', 'cat', 'max', 'lstm', None)
  - residual (bool, default: False) – Whether to use residual connections
  - train_eps – Whether to learn the epsilon parameter
  - rngs (Optional[Rngs], default: None) – Random number generators
Graph Isomorphism Network model with MLP transformations.
Example:
from jraphx.nn.models import GIN
import flax.nnx as nnx

model = GIN(
    in_features=16,
    hidden_features=64,
    num_layers=5,
    out_features=10,
    dropout_rate=0.5,
    norm="batch_norm",
    jk="cat",
    rngs=nnx.Rngs(0),
)

out = model(x, edge_index, batch=batch)
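GIN is typically paired with a graph-level readout for tasks such as molecular property prediction. jraphx may ship its own pooling utilities, but the sketch below sticks to plain JAX segment operations so it only assumes the model call shown above and a batch vector of per-node graph ids.

import jax
import jax.numpy as jnp

node_out = model(x, edge_index, batch=batch)  # (num_nodes, out_features)
num_graphs = int(batch.max()) + 1             # batch: (num_nodes,) graph ids

# Mean-pool node embeddings per graph to get graph-level outputs.
graph_sum = jax.ops.segment_sum(node_out, batch, num_segments=num_graphs)
counts = jax.ops.segment_sum(jnp.ones((node_out.shape[0], 1)), batch,
                             num_segments=num_graphs)
graph_out = graph_sum / counts                # (num_graphs, out_features)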
Base Classes
BasicGNN
- class BasicGNN(*args: Any, **kwargs: Any)[source]
Bases:
Module
An abstract class for implementing basic GNN models.
- Parameters:
  - in_features (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.
  - hidden_features (int) – Size of each hidden sample.
  - num_layers (int) – Number of message passing layers.
  - out_features (int, optional) – If not set to None, will apply a final linear transformation to convert hidden node embeddings to output size out_features. (default: None)
  - dropout_rate (float, optional) – Dropout probability. (default: 0.0)
  - act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)
  - act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)
  - norm (str, optional) – The normalization function to use ("batch_norm", "layer_norm", "graph_norm"). (default: None)
  - jk (str, optional) – The Jumping Knowledge mode ("last", "cat", "max", "lstm"). (default: None)
  - residual (bool, optional) – Whether to use residual connections between layers. (default: False)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
  - **kwargs (optional) – Additional arguments for the specific convolution layer.
Abstract base class for GNN models. Provides a common interface for building multi-layer GNNs with normalization, dropout, and JumpingKnowledge connections.
Subclassing Example:
from jraphx.nn.models import BasicGNN
from jraphx.nn.conv import MessagePassing


class MyCustomGNN(BasicGNN):
    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Return your custom MessagePassing layer here
        return MyCustomConv(in_features, out_features, rngs=rngs, **kwargs)
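For a more concrete (but still hedged) sketch, a subclass can simply return one of the existing jraphx convolution layers from init_conv. The GCNConv constructor arguments below mirror the parameters listed for the GCN model above and are an assumption about its exact signature.

from jraphx.nn.conv import GCNConv
from jraphx.nn.models import BasicGNN


class PlainGCN(BasicGNN):
    """A BasicGNN variant that always builds plain, uncached GCNConv layers."""

    def init_conv(self, in_features, out_features, rngs=None, **kwargs):
        # Argument names follow the GCN model documentation above; the exact
        # GCNConv signature is assumed, not taken from its own docstring.
        return GCNConv(
            in_features,
            out_features,
            improved=False,
            cached=False,
            rngs=rngs,
        )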
Utility Models
MLP
- class MLP(*args: Any, **kwargs: Any)[source]
Bases:
Module
A Multi-Layer Perceptron (MLP) model.
There are two ways to instantiate an MLP:

1. By specifying explicit feature sizes, e.g.,

   mlp = MLP([16, 32, 64, 128], rngs=nnx.Rngs(0))

   creates a three-layer MLP with differently sized hidden layers.

2. By specifying a fixed hidden feature size over a number of layers, e.g.,

   mlp = MLP(in_features=16, hidden_features=32, out_features=128, num_layers=3, rngs=nnx.Rngs(0))

   creates a three-layer MLP with equally sized hidden layers.
- Parameters:
  - feature_list (List[int] or int, optional) – List of input, intermediate and output features such that len(feature_list) - 1 denotes the number of layers of the MLP. (default: None)
  - in_features (int, optional) – Size of each input sample. Will override feature_list. (default: None)
  - hidden_features (int, optional) – Size of each hidden sample. Will override feature_list. (default: None)
  - out_features (int, optional) – Size of each output sample. Will override feature_list. (default: None)
  - num_layers (int, optional) – The number of layers. Will override feature_list. (default: None)
  - dropout_rate (float, optional) – Dropout probability of each hidden embedding. (default: 0.0)
  - act (Callable, optional) – The non-linear activation function to use. (default: jax.nn.relu)
  - act_first (bool, optional) – If set to True, activation is applied before normalization. (default: False)
  - norm (str or Callable, optional) – The normalization function to use. (default: None)
  - plain_last (bool, optional) – If set to False, will apply non-linearity, batch normalization and dropout to the last layer as well. (default: True)
  - bias (bool, optional) – If set to False, the module will not learn additive biases. (default: True)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
Multi-layer perceptron with configurable layers, normalization, and dropout.
Example:
from jraphx.nn.models import MLP
import flax.nnx as nnx
import jax

# Using a feature list
mlp = MLP(
    feature_list=[16, 64, 64, 32, 10],
    norm="layer_norm",
    bias=True,
    dropout_rate=0.5,
    act=jax.nn.relu,
    rngs=nnx.Rngs(0),
)

# Or using in/hidden/out features
mlp = MLP(
    in_features=16,
    hidden_features=64,
    out_features=10,
    num_layers=3,
    norm="batch_norm",
    dropout_rate=0.5,
    rngs=nnx.Rngs(0),
)

out = mlp(x)
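As noted in the parameter list, plain_last controls whether the final layer also receives activation, normalization, and dropout. A small sketch of a prediction head that applies them to every layer (parameter values are illustrative):

from jraphx.nn.models import MLP
import flax.nnx as nnx

head = MLP(
    in_features=64,
    hidden_features=64,
    out_features=10,
    num_layers=2,
    norm="layer_norm",
    dropout_rate=0.2,
    plain_last=False,  # also normalize/activate/drop out the last layer
    rngs=nnx.Rngs(0),
)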
JumpingKnowledge
- class JumpingKnowledge(*args: Any, **kwargs: Any)[source]
Bases:
Module
The Jumping Knowledge layer aggregation module from “Representation Learning on Graphs with Jumping Knowledge Networks” https://arxiv.org/abs/1806.03536
Jumping knowledge is performed based on either concatenation ("cat")

\[\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)},\]

max pooling ("max")

\[\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right),\]

or weighted summation

\[\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}\]

with attention scores \(\alpha_v^{(t)}\) obtained from a bi-directional LSTM ("lstm").

- Parameters:
  - mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").
  - num_features (int, optional) – The number of features per representation. Only needs to be set for LSTM-style aggregation. (default: None)
  - num_layers (int, optional) – The number of layers to aggregate. Only needs to be set for LSTM-style aggregation. (default: None)
  - rngs (Optional[Rngs], default: None) – Random number generators for initialization.
JumpingKnowledge layer for aggregating representations from different GNN layers.
Example:
from jraphx.nn.models import JumpingKnowledge
import flax.nnx as nnx

# Concatenation mode
jk = JumpingKnowledge(mode="cat")

# Max pooling mode
jk = JumpingKnowledge(mode="max")

# LSTM aggregation mode (num_features and num_layers are required here)
jk = JumpingKnowledge(
    mode="lstm",
    num_features=64,
    num_layers=3,
    rngs=nnx.Rngs(0),
)

# Aggregate layer outputs
layer_outputs = [layer1_out, layer2_out, layer3_out]
final_out = jk(layer_outputs)
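In practice the module is fed the list of per-layer node embeddings produced inside a GNN's forward pass. A hand-rolled sketch of that pattern follows; the GCNConv construction and call signatures are assumptions based on the conv layers referenced elsewhere in these docs, and x / edge_index are placeholders.

import flax.nnx as nnx
from jraphx.nn.conv import GCNConv
from jraphx.nn.models import JumpingKnowledge

rngs = nnx.Rngs(0)
convs = [
    GCNConv(16, 64, rngs=rngs),
    GCNConv(64, 64, rngs=rngs),
    GCNConv(64, 64, rngs=rngs),
]
jk = JumpingKnowledge(mode="max")

xs = []
h = x  # placeholder node features
for conv in convs:
    h = conv(h, edge_index)
    xs.append(h)  # collect each layer's output

out = jk(xs)  # (num_nodes, 64) for "max"; "cat" would give (num_nodes, 3 * 64)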
Model Selection Guide
Choosing the Right Model
GCN: Best for citation networks and semi-supervised learning tasks with homophilic graphs.
GAT: Excellent for graphs where edge importance varies. The attention mechanism learns which neighbors are most relevant.
GraphSAGE: Ideal for large-scale graphs and inductive learning scenarios where you need to generalize to unseen nodes.
GIN: Most expressive for distinguishing graph structures. Best for graph-level tasks like molecular property prediction.
Configuration Tips
- Number of Layers:
  - 2-3 layers for most node classification tasks
  - 4-5 layers for graph-level tasks
  - Use JumpingKnowledge for deeper networks
- Normalization:
  - batch_norm: Best for large batches and stable training
  - layer_norm: Works well with smaller batches
  - graph_norm: Specifically designed for graph data
- JumpingKnowledge:
  - cat: Preserves all information but increases dimensionality
  - max: Good balance of expressiveness and efficiency
  - lstm: Most flexible but requires more parameters
- Dropout:
  - 0.5-0.6 for training stability
  - Higher rates (0.6-0.8) for GAT models
  - Lower rates (0.2-0.5) for deeper models
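Putting several of these tips together, a deeper graph-level model might look like the following sketch (one reasonable combination, not a prescription):

import flax.nnx as nnx
from jraphx.nn.models import GIN

graph_model = GIN(
    in_features=16,
    hidden_features=64,
    num_layers=5,        # 4-5 layers for graph-level tasks
    out_features=10,
    dropout_rate=0.3,    # lower dropout for a deeper model
    norm="graph_norm",   # normalization designed for graph data
    jk="max",            # good balance of expressiveness and efficiency
    rngs=nnx.Rngs(0),
)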
Performance Comparison
| Model     | Speed  | Memory      | Expressiveness | Best For             |
|-----------|--------|-------------|----------------|----------------------|
| GCN       | Fast   | Low         | Medium         | Node classification  |
| GAT       | Medium | Medium-High | High           | Heterophilic graphs  |
| GraphSAGE | Fast   | Low-Medium  | Medium         | Large-scale graphs   |
| GIN       | Fast   | Low         | Highest        | Graph classification |