JraphX GNN Cheatsheet

This cheatsheet provides an overview of all available Graph Neural Network layers in JraphX and their supported features.

Legend:

  • edge_weight: If checked (✓), supports message passing with one-dimensional edge weight information, e.g., GCNConv(...)(x, edge_index, edge_weight).

  • edge_attr: If checked (✓), supports message passing with multi-dimensional edge feature information, e.g., GATConv(...)(x, edge_index, edge_attr).

  • bipartite: If checked (✓), supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes.

  • JIT-ready: If checked (✓), the layer is fully compatible with @jax.jit compilation for optimal performance.

  • vmap-ready: If checked (✓), the layer can be efficiently vectorized over multiple graphs using nnx.vmap.
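To make the first three columns concrete, here is a minimal sketch of one sum-aggregation message-passing step in plain JAX. It shows where a scalar `edge_weight` (shape `[E]`) and a vector `edge_attr` (shape `[E, D]`) enter the computation, and how a bipartite graph aggregates `num_src` source nodes into a different number of target nodes. The helper `propagate` and its `edge_proj` matrix are hypothetical illustrations, not JraphX API. The sketch also assumes the PyTorch Geometric convention that row 0 of `edge_index` holds source nodes and row 1 holds target nodes.

```python
import jax.numpy as jnp
from jax.ops import segment_sum


def propagate(x_src, edge_index, num_dst, edge_weight=None, edge_attr=None, edge_proj=None):
    """Sum-aggregate messages along edges (row 0 = source, row 1 = target).

    x_src:       [num_src, F] source-node features
    edge_index:  [2, E] integer indices
    num_dst:     number of target nodes (equals num_src for ordinary graphs)
    edge_weight: optional [E] scalar per edge
    edge_attr:   optional [E, D] feature vector per edge
    edge_proj:   [D, F] matrix mapping edge_attr into the message space
    """
    src, dst = edge_index[0], edge_index[1]
    messages = x_src[src]  # gather source features: [E, F]
    if edge_weight is not None:
        messages = messages * edge_weight[:, None]  # scale each message
    if edge_attr is not None:
        messages = messages + edge_attr @ edge_proj  # inject edge features
    return segment_sum(messages, dst, num_segments=num_dst)


# Ordinary graph with scalar edge weights
x = jnp.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
out_w = propagate(x, edge_index, 3, edge_weight=jnp.array([0.5, 1.0, 2.0]))

# Same graph with 4-dimensional edge attributes
out_a = propagate(x, edge_index, 3, edge_attr=jnp.ones((3, 4)), edge_proj=jnp.full((4, 2), 0.25))

# Bipartite: 4 source nodes aggregated into 2 target nodes
x_src = jnp.arange(8.0).reshape(4, 2)
out_b = propagate(x_src, jnp.array([[0, 1, 3], [0, 0, 1]]), 2)
```

Target-node features never appear here because this aggregation reads only the sources. A real bipartite layer would typically also project the target features with their own weights, which is why the two sides may have different dimensionalities.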

Graph Neural Network Operators

Name | edge_weight | edge_attr | bipartite | JIT-ready | vmap-ready

GCNConv (Paper)

GATConv (Paper)

GATv2Conv (Paper)

SAGEConv (Paper)

GINConv (Paper)

EdgeConv (Paper)

TransformerConv (Paper)

Pre-built Models

JraphX provides several pre-built GNN models that combine multiple layers:

Name | JIT-ready | vmap-ready

GCN

GAT

GraphSAGE

GIN

MLP

JumpingKnowledge

Normalization Layers

Name | JIT-ready | vmap-ready

BatchNorm

LayerNorm

GraphNorm

Pooling Operations

Name | JIT-ready | vmap-ready

global_add_pool()

global_mean_pool()

global_max_pool()

TopKPooling

SAGPooling
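The three global_*_pool() readouts reduce node features to one vector per graph. Each one groups nodes by the `batch` vector, which maps every node to its graph index. The sketch below reproduces their semantics in plain JAX. `add_pool`, `mean_pool` and `max_pool` are hypothetical stand-ins, not JraphX's implementation. The graph count is passed explicitly because `jax.ops.segment_sum` and `segment_max` need a static `num_segments` to produce a fixed output shape under `jax.jit`. TopKPooling and SAGPooling are not covered here.

```python
import jax.numpy as jnp
from jax.ops import segment_max, segment_sum


def add_pool(x, batch, num_graphs):
    # Sum node features that share the same graph id
    return segment_sum(x, batch, num_segments=num_graphs)


def mean_pool(x, batch, num_graphs):
    sums = segment_sum(x, batch, num_segments=num_graphs)
    counts = segment_sum(jnp.ones((x.shape[0], 1), x.dtype), batch, num_segments=num_graphs)
    return sums / jnp.maximum(counts, 1.0)  # guard against empty graphs


def max_pool(x, batch, num_graphs):
    # Empty segments come back as -inf, the identity for max
    return segment_max(x, batch, num_segments=num_graphs)


x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = jnp.array([0, 0, 1])  # nodes 0-1 belong to graph 0, node 2 to graph 1
```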

Quick Usage Examples

Basic layer usage:

import jax.numpy as jnp
from flax import nnx
from jraphx.nn.conv import GCNConv, GATConv, EdgeConv
from jraphx.data import Data
from jraphx.nn.models import MLP

# Create graph data
x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
data = Data(x=x, edge_index=edge_index)

# GCN layer (supports edge weights)
gcn = GCNConv(16, 32, rngs=nnx.Rngs(42))
gcn_out = gcn(data.x, data.edge_index)

# GAT layer (supports edge attributes)
gat = GATConv(16, 32, heads=4, rngs=nnx.Rngs(42))
gat_out = gat(data.x, data.edge_index)

# EdgeConv layer (wraps a neural network applied to each edge's features)
edge_mlp = MLP([32, 32, 32], rngs=nnx.Rngs(42))  # 2*16 -> 32 -> 32 (concatenated node-pair features)
edge_conv = EdgeConv(edge_mlp, aggr='max')
edge_out = edge_conv(data.x, data.edge_index)
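The MLP's input width of 2*16 comes from how EdgeConv (from the DGCNN paper) builds edge features. For every edge it concatenates the center node's features `x_i` with the difference `x_j - x_i`, applies the network, and max-aggregates the results per node. Below is a plain-JAX sketch of that computation. It assumes JraphX follows the PyTorch Geometric formulation, and the function `edge_conv` is a hypothetical illustration rather than the layer's source. A single dense layer stands in for the MLP. The choice to return 0 for nodes without incoming edges is this sketch's own.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_max


def edge_conv(x, edge_index, h):
    src, dst = edge_index[0], edge_index[1]
    x_i = x[dst]  # features of the receiving (center) node
    x_j = x[src]  # features of the neighbor
    edge_features = jnp.concatenate([x_i, x_j - x_i], axis=-1)  # [E, 2F]
    out = segment_max(h(edge_features), dst, num_segments=x.shape[0])
    # Nodes with no incoming edges get -inf from segment_max; map them to 0
    return jnp.where(out == -jnp.inf, 0.0, out)


# Identity "network" makes the 2F-dimensional edge features visible
x = jnp.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
feats = edge_conv(x, edge_index, lambda e: e)

# One dense layer in place of the MLP: 2*16 = 32 inputs -> 32 outputs
w = jax.random.normal(jax.random.key(0), (32, 32))
out = edge_conv(jnp.ones((10, 16)), edge_index, lambda e: jax.nn.relu(e @ w))
```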

Pre-built model usage:

from jraphx.nn.models import GCN

# Create multi-layer GCN
model = GCN(
    in_features=16,
    hidden_features=64,
    out_features=7,
    num_layers=3,
    dropout=0.1,
    rngs=nnx.Rngs(42)
)

# Forward pass
predictions = model(data.x, data.edge_index)

Pooling for graph-level tasks:

from jraphx.nn.pool import global_mean_pool
from jraphx.data import Batch

# Create batch of graphs
graphs = [data, data, data]  # 3 identical graphs for demo
batch = Batch.from_data_list(graphs)

# Get node-level features
node_features = model(batch.x, batch.edge_index)

# Pool to graph-level representations
graph_features = global_mean_pool(node_features, batch.batch)
print(f"Graph features: {graph_features.shape}")  # (3, 7): one 7-dim vector per graph
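A single model call can process all three graphs because batching merges them into one large, disconnected graph. The sketch below shows the idea, assuming JraphX's Batch.from_data_list follows the PyTorch Geometric convention. Node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph each node came from. The `batch` vector is exactly what global_mean_pool consumes. `collate` is a hypothetical helper, not the library's implementation.

```python
import jax.numpy as jnp


def collate(graphs):
    """Merge (x, edge_index) pairs into one disjoint graph plus a batch vector."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edge_indices.append(edge_index + offset)  # shift node ids past earlier graphs
        batch.append(jnp.full((x.shape[0],), graph_id, dtype=jnp.int32))
        offset += x.shape[0]
    return (
        jnp.concatenate(xs, axis=0),
        jnp.concatenate(edge_indices, axis=1),
        jnp.concatenate(batch),
    )


x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
big_x, big_edge_index, batch = collate([(x, edge_index)] * 3)
```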

JAX-Specific Optimizations

JraphX layers are designed to take full advantage of JAX’s capabilities:

  • JIT Compilation: All layers can be compiled with nnx.jit, Flax's module-aware wrapper around jax.jit

  • Vectorization: Use nnx.vmap to process multiple graphs of identical shape in parallel

  • Automatic Differentiation: Full support for jax.grad and optimization libraries like Optax

  • XLA Backend: Automatically optimized for your hardware (CPU/GPU/TPU)

Performance example:

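import jax.numpy as jnp
import optax
from flax import nnx

# JIT compile for speed; nnx.jit accepts nnx.Module arguments directly
@nnx.jit
def fast_gnn_inference(model, x, edge_index):
    return model(x, edge_index)

# Vectorize over multiple graphs (fixed-size): in_axes=None broadcasts
# the model, 0 maps over the leading batch axis of the inputs
@nnx.vmap(in_axes=(None, 0, 0))
def batch_gnn_inference(model, x_batch, edge_index_batch):
    return model(x_batch, edge_index_batch)

# Use with optimization libraries
optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)

# nnx.jit (unlike plain jax.jit) propagates the in-place parameter
# updates made by optimizer.update back to the caller's model
@nnx.jit
def train_step(model, optimizer, data, targets):
    def loss_fn(model):
        preds = model(data.x, data.edge_index)
        return jnp.mean((preds - targets) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss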
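The "(fixed-size)" note matters because vmap maps over a leading axis of identically shaped arrays. Graphs with different node or edge counts must therefore be padded to common sizes before they can be stacked. The sketch below shows one common approach in plain JAX: zero-pad the node features, fill unused edge slots with a dummy index, and carry an edge mask so that padded edges contribute nothing. `MAX_NODES`, `MAX_EDGES`, `pad_graph` and `mean_aggregate` are hypothetical names, not JraphX API. The layers themselves may handle padding differently.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

MAX_NODES, MAX_EDGES = 8, 6  # hypothetical padding budget


def pad_graph(x, edge_index):
    num_nodes, num_edges = x.shape[0], edge_index.shape[1]
    x_pad = jnp.zeros((MAX_NODES, x.shape[1]), x.dtype).at[:num_nodes].set(x)
    # Unused edge slots point at the last node slot; the mask zeroes them out
    ei_pad = jnp.full((2, MAX_EDGES), MAX_NODES - 1, dtype=edge_index.dtype)
    ei_pad = ei_pad.at[:, :num_edges].set(edge_index)
    edge_mask = jnp.arange(MAX_EDGES) < num_edges
    return x_pad, ei_pad, edge_mask


def mean_aggregate(x, edge_index, edge_mask):
    src, dst = edge_index[0], edge_index[1]
    w = edge_mask.astype(x.dtype)  # 1 for real edges, 0 for padding
    sums = segment_sum(x[src] * w[:, None], dst, num_segments=x.shape[0])
    counts = segment_sum(w, dst, num_segments=x.shape[0])
    return sums / jnp.maximum(counts, 1.0)[:, None]


batched_aggregate = jax.jit(jax.vmap(mean_aggregate))

graphs = [
    (jnp.array([[1.0], [2.0], [3.0]]), jnp.array([[0, 1], [1, 2]])),
    (jnp.array([[5.0], [7.0]]), jnp.array([[1], [0]])),
]
padded = [pad_graph(x, ei) for x, ei in graphs]
x_batch, ei_batch, mask_batch = (jnp.stack(parts) for parts in zip(*padded))
out = batched_aggregate(x_batch, ei_batch, mask_batch)  # [2, MAX_NODES, 1]
```

Padding trades some wasted computation for a single compiled program that serves every graph within the budget.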

Random Number Generation (Flax 0.11.2)

Use modern Flax 0.11.2 Rngs shorthand methods for cleaner code:

from flax import nnx

# Create Rngs with named key streams
rngs = nnx.Rngs(0, params=1, dropout=2)

# Old JAX approach:
# noise = jax.random.normal(jax.random.key(42), (10, 16))

# New Flax shorthand (much cleaner!):
noise = rngs.normal((10, 16))                    # Default key
features = rngs.params.uniform((10, 16))         # Params key
dropout_mask = rngs.dropout.bernoulli(0.5, (10,))  # Dropout key

For more details, see the Flax randomness guide.

Missing Features

For a complete list of PyTorch Geometric features not yet implemented in JraphX, see Missing Features in JraphX.