JraphX GNN Cheatsheet
This cheatsheet provides an overview of all available Graph Neural Network layers in JraphX and their supported features.
Legend:
edge_weight: If checked (✓), the layer supports message passing with one-dimensional edge weight information, e.g., GCNConv(...)(x, edge_index, edge_weight) (see the sketch after this list).
edge_attr: If checked (✓), the layer supports message passing with multi-dimensional edge feature information, e.g., GATConv(...)(x, edge_index, edge_attr).
bipartite: If checked (✓), the layer supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes.
JIT-ready: If checked (✓), the layer is fully compatible with @jax.jit compilation for optimal performance.
vmap-ready: If checked (✓), the layer can be efficiently vectorized over multiple graphs using nnx.vmap.
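The legend's edge_weight and edge_attr call patterns look like the sketch below. It reuses the constructor arguments from the examples further down; how each layer is told the edge feature size at construction time is not specified here and is an assumption to verify against the API reference.
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.conv import GCNConv, GATConv

x = jnp.ones((4, 16))                                 # 4 nodes, 16 features each
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # 4 directed edges
edge_weight = jnp.array([0.5, 1.0, 0.25, 2.0])        # one scalar weight per edge
edge_attr = jnp.ones((4, 8))                          # one 8-dim feature vector per edge

# edge_weight is passed as the third positional argument, as in the legend.
gcn = GCNConv(16, 32, rngs=nnx.Rngs(0))
out_weighted = gcn(x, edge_index, edge_weight)

# edge_attr is passed as the third positional argument, as in the legend.
# (Assumption: the layer accepts 8-dim edge features without extra configuration.)
gat = GATConv(16, 32, heads=4, rngs=nnx.Rngs(0))
out_attr = gat(x, edge_index, edge_attr)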
Graph Neural Network Operators
Name | edge_weight | edge_attr | bipartite | JIT-ready | vmap-ready
---|---|---|---|---|---
[Per-layer entries not preserved; see the Legend above for column definitions and the examples below for GCNConv, GATConv, and EdgeConv usage.]
Pre-built Models
JraphX provides several pre-built GNN models that combine multiple layers, such as the MLP and GCN models used in the examples below. All pre-built models are JIT-ready and vmap-ready.
Normalization Layers
All normalization layers are JIT-ready and vmap-ready.
Pooling Operations
All pooling operations (including global_mean_pool, used in the examples below) are JIT-ready and vmap-ready.
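For intuition, global mean pooling reduces node features to one vector per graph according to a batch assignment vector. The sketch below reproduces that computation with plain JAX segment ops; it is an illustration only, and JraphX's own global_mean_pool (shown in the pooling example further down) is what you should use in practice.
import jax
import jax.numpy as jnp

def mean_pool_sketch(x, batch, num_graphs):
    # Sum node features per graph, then divide by each graph's node count.
    sums = jax.ops.segment_sum(x, batch, num_segments=num_graphs)
    counts = jax.ops.segment_sum(jnp.ones((x.shape[0], 1)), batch, num_segments=num_graphs)
    return sums / counts

x = jnp.arange(12.0).reshape(6, 2)     # 6 nodes with 2 features each
batch = jnp.array([0, 0, 0, 1, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-5 to graph 1
print(mean_pool_sketch(x, batch, num_graphs=2).shape)  # (2, 2)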
Quick Usage Examples
Basic layer usage:
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.conv import GCNConv, GATConv, EdgeConv
from jraphx.data import Data
from jraphx.nn.models import MLP
# Create graph data
x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
data = Data(x=x, edge_index=edge_index)
# GCN layer (supports edge weights)
gcn = GCNConv(16, 32, rngs=nnx.Rngs(42))
gcn_out = gcn(data.x, data.edge_index)
# GAT layer (supports edge attributes)
gat = GATConv(16, 32, heads=4, rngs=nnx.Rngs(42))
gat_out = gat(data.x, data.edge_index)
# EdgeConv layer (requires neural network module)
edge_mlp = MLP([32, 32, 32], rngs=nnx.Rngs(42)) # 2*16 -> 32 -> 32
edge_conv = EdgeConv(edge_mlp, aggr='max')
edge_out = edge_conv(data.x, data.edge_index)
Pre-built model usage:
from jraphx.nn.models import GCN
# Create multi-layer GCN
model = GCN(
    in_features=16,
    hidden_features=64,
    out_features=7,
    num_layers=3,
    dropout=0.1,
    rngs=nnx.Rngs(42),
)
# Forward pass
predictions = model(data.x, data.edge_index)
Pooling for graph-level tasks:
from jraphx.nn.pool import global_mean_pool
from jraphx.data import Batch
# Create batch of graphs
graphs = [data, data, data] # 3 identical graphs for demo
batch = Batch.from_data_list(graphs)
# Get node-level features
node_features = model(batch.x, batch.edge_index)
# Pool to graph-level representations
graph_features = global_mean_pool(node_features, batch.batch)
print(f"Graph features: {graph_features.shape}") # [3, feature_dim]
JAX-Specific Optimizations
JraphX layers are designed to take full advantage of JAX’s capabilities:
JIT Compilation: All layers support
@jax.jit
for optimal performanceVectorization: Use
nnx.vmap
to process multiple graphs in parallelAutomatic Differentiation: Full support for
jax.grad
and optimization libraries like OptaxXLA Backend: Automatically optimized for your hardware (CPU/GPU/TPU)
Performance example:
import jax

# JIT compile for speed
@jax.jit
def fast_gnn_inference(model, x, edge_index):
    return model(x, edge_index)

# Vectorize over multiple graphs (fixed-size)
@nnx.vmap
def batch_gnn_inference(x_batch, edge_index_batch):
    return model(x_batch, edge_index_batch)

# Use with optimization libraries
import optax

optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, data, targets):
    def loss_fn(model):
        preds = model(data.x, data.edge_index)
        return jnp.mean((preds - targets) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
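A hypothetical driver loop for the train_step above; targets is assumed to be a (num_nodes, out_features) array of regression targets matching the model and data defined earlier.
# Hypothetical regression targets: one 7-dim vector per node (10 nodes, out_features=7).
targets = jnp.zeros((10, 7))

for epoch in range(100):
    loss = train_step(model, optimizer, data, targets)
    if epoch % 10 == 0:
        print(f"epoch {epoch}: loss {float(loss):.4f}")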
Random Number Generation (Flax 0.11.2)
Use the modern Flax 0.11.2 Rngs shorthand methods for cleaner code:
# Create Rngs with named key streams
rngs = nnx.Rngs(0, params=1, dropout=2)
# Old JAX approach:
# noise = random.normal(random.key(42), (10, 16))
# New Flax shorthand (much cleaner!):
noise = rngs.normal((10, 16)) # Default key
features = rngs.params.uniform((10, 16)) # Params key
dropout_mask = rngs.dropout.bernoulli(0.5, (10,)) # Dropout key
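The same named streams plug directly into module construction. In the sketch below, one rngs object seeds the GCN model from the earlier example; by Flax convention the params stream is consumed for parameter initialization and the dropout stream for dropout masks (the hyperparameters are simply the ones used above).
from flax import nnx
from jraphx.nn.models import GCN

# One Rngs object with separate streams for parameters and dropout.
rngs = nnx.Rngs(0, params=1, dropout=2)
model = GCN(
    in_features=16,
    hidden_features=64,
    out_features=7,
    num_layers=3,
    dropout=0.1,
    rngs=rngs,
)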
For more details, see the Flax randomness guide.
Missing Features
For a complete list of PyTorch Geometric features not yet implemented in JraphX, see Missing Features in JraphX.