Introduction by Example
We briefly introduce the fundamental concepts of JraphX through self-contained examples.
For an introduction to Graph Machine Learning, we refer the interested reader to the Stanford CS224W: Machine Learning with Graphs lectures. For an introduction to JAX, see the JAX documentation.
At its core, JraphX provides the following main features, each introduced in the sections below.
Data Handling of Graphs
A graph is used to model pairwise relations (edges) between objects (nodes).
A single graph in JraphX is described by an instance of jraphx.data.Data, which holds the following attributes by default:
- data.x: Node feature matrix with shape [num_nodes, num_node_features] as a JAX array
- data.edge_index: Graph connectivity with shape [2, num_edges] as a JAX array with integer dtype
- data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] as a JAX array
- data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
- data.pos: Node position matrix with shape [num_nodes, num_dimensions] for 3D point clouds
None of these attributes are required. In fact, the Data object is not even restricted to these attributes. We can extend it to save the connectivity of triangles from a 3D mesh in a JAX array with shape [3, num_faces]. See the 3D Mesh Graphs example for a complete implementation.
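The sketch below illustrates the [3, num_faces] layout. It is independent of the 3D Mesh Graphs example, and the variable name `faces` is our own choice. It stores the two triangles of a square and derives an undirected edge_index from them in plain JAX. JraphX does not necessarily convert faces to edges this way.

```python
import jax.numpy as jnp

# A unit square split into two triangles, stored as [3, num_faces]:
# each column lists the three vertex indices of one triangle
faces = jnp.array([[0, 0],
                   [1, 2],
                   [2, 3]], dtype=jnp.int32)

# Triangle (a, b, c) contributes the edges a-b, b-c and c-a
src = jnp.concatenate([faces[0], faces[1], faces[2]])
dst = jnp.concatenate([faces[1], faces[2], faces[0]])

# Keep both directions, then drop the duplicates of the shared edge 0-2
pairs = jnp.stack([jnp.concatenate([src, dst]),
                   jnp.concatenate([dst, src])], axis=1)  # [2 * 3 * num_faces, 2]
edge_index = jnp.unique(pairs, axis=0).T                   # [2, num_edges]
print(edge_index.shape)  # (2, 10)
```

The square has five distinct undirected edges, so it yields ten directed ones.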
Note
JAX uses a functional programming paradigm where arrays are immutable. This means that operations on Data objects return new instances rather than modifying existing ones.
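Here is a minimal plain-JAX illustration of what immutability means at the array level. It uses only jax.numpy and none of the Data API.

```python
import jax.numpy as jnp

x = jnp.array([[-1.0], [0.0], [1.0]])

# Item assignment is rejected because JAX arrays are immutable
try:
    x[0, 0] = 5.0
    assignment_failed = False
except TypeError:
    assignment_failed = True
print(assignment_failed)  # True

# The functional alternative returns an updated copy and leaves x untouched
x_new = x.at[0, 0].set(5.0)
print(float(x[0, 0]), float(x_new[0, 0]))  # -1.0 5.0
```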
We show a simple example of an unweighted and undirected graph with three nodes and two undirected edges, stored as four directed edges. Each node contains exactly one feature:
import jax.numpy as jnp
from jraphx.data import Data
edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]], dtype=jnp.int32)
x = jnp.array([[-1.0], [0.0], [1.0]], dtype=jnp.float32)
data = Data(x=x, edge_index=edge_index)
print(data)
>>> Data(edge_index=[2, 4], x=[3, 1])
Note that edge_index, i.e., the array defining the source and target nodes of all edges, is not a list of index tuples. If you want to write your indices this way, transpose them before passing them to the data constructor:
import jax.numpy as jnp
from jraphx.data import Data
edge_index = jnp.array([[0, 1],
                        [1, 0],
                        [1, 2],
                        [2, 1]], dtype=jnp.int32)
x = jnp.array([[-1.0], [0.0], [1.0]], dtype=jnp.float32)
data = Data(x=x, edge_index=edge_index.T)
print(data)
>>> Data(edge_index=[2, 4], x=[3, 1])
Although the graph has only two edges, we need to define four index tuples to account for both directions of each edge.
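If you list each undirected edge only once, you can generate the reverse direction instead of typing it. The following is a minimal plain-JAX sketch. JraphX may provide a dedicated utility for this, but none is assumed here.

```python
import jax.numpy as jnp

# Each undirected edge listed once; shape [2, num_undirected_edges]
undirected = jnp.array([[0, 1],
                        [1, 2]], dtype=jnp.int32)

# Append a copy with source and target rows swapped to store both directions
edge_index = jnp.concatenate([undirected, jnp.flip(undirected, axis=0)], axis=1)
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 2, 0, 1]]
```

The column order differs from the hand-written example, but both arrays describe the same graph.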
Note
You can print out your data object anytime and receive information about its attributes and their shapes.
Note that the elements in edge_index must only hold indices in the range {0, ..., num_nodes - 1}. This is needed because we want our final data representation to be as compact as possible, e.g., we want to index the source and destination node features of the first edge (0, 1) via x[0] and x[1], respectively.
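JAX indexing does not raise an error for out-of-range indices; see the JAX "sharp bits" notes on out-of-bounds indexing. An explicit range check can therefore catch malformed graphs early. The plain-JAX sketch below runs such a check and then gathers the endpoint features of every edge at once.

```python
import jax.numpy as jnp

x = jnp.array([[-1.0], [0.0], [1.0]], dtype=jnp.float32)
edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]], dtype=jnp.int32)
num_nodes = x.shape[0]

# Check that every index lies in {0, ..., num_nodes - 1}
in_range = bool(jnp.all((edge_index >= 0) & (edge_index < num_nodes)))
print(in_range)  # True

# Gather source and target node features for all edges in one step
x_src = x[edge_index[0]]  # shape [num_edges, num_node_features]
x_dst = x[edge_index[1]]
print(x_src[0].tolist(), x_dst[0].tolist())  # [-1.0] [0.0] for the first edge (0, 1)
```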
Besides holding a number of node-level, edge-level or graph-level attributes, Data provides a number of useful utility functions, e.g.:
print(data.keys())
>>> ['x', 'edge_index']
print(data['x'])
>>> [[-1.]
     [ 0.]
     [ 1.]]
for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data
'edge_attr' in data
>>> False
data.num_nodes
>>> 3
data.num_edges
>>> 4
data.num_node_features
>>> 1
data.has_isolated_nodes()
>>> False
data.has_self_loops()
>>> False
data.is_directed
>>> False
You can find a complete list of all methods at jraphx.data.Data.
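These checks boil down to simple array operations. The following plain-JAX sketch shows what has_self_loops(), has_isolated_nodes() and an is_directed-style check compute for the example graph. It illustrates the logic only and is not JraphX's implementation.

```python
import jax.numpy as jnp

edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]], dtype=jnp.int32)
num_nodes = 3
src, dst = edge_index[0], edge_index[1]

# Self-loop: an edge whose source equals its target
has_self_loops = bool(jnp.any(src == dst))

# Isolated node: never appears as a source or a target
touched = jnp.zeros(num_nodes, dtype=bool).at[edge_index.reshape(-1)].set(True)
has_isolated_nodes = not bool(jnp.all(touched))

# Undirected: every edge (i, j) has a reverse edge (j, i).
# Encode each edge as the integer i * num_nodes + j to compare them as sets.
edge_ids = src * num_nodes + dst
reverse_ids = dst * num_nodes + src
is_undirected = bool(jnp.all(jnp.isin(reverse_ids, edge_ids)))

print(has_self_loops, has_isolated_nodes, is_undirected)  # False False True
```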
Working with JAX Arrays
JraphX is designed to work seamlessly with JAX arrays and the JAX ecosystem. Unlike PyTorch tensors, JAX arrays are immutable and operations are functional. Here are some key concepts:
JAX arrays can be created from Python lists or NumPy arrays:
import jax.numpy as jnp
from jraphx.data import Data
# Create JAX arrays for graph data
node_features = jnp.array([[1.0, 0.5], [0.0, 1.0], [0.5, 0.0]], dtype=jnp.float32)
edge_indices = jnp.array([[0, 1, 2], [1, 2, 0]], dtype=jnp.int32)
data = Data(x=node_features, edge_index=edge_indices)
print(data.x.shape)
>>> (3, 2)
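The example above starts from Python lists. You can convert to and from NumPy with jnp.asarray and np.asarray. In this minimal sketch, the NumPy edge list stands in for data loaded from elsewhere.

```python
import numpy as np
import jax.numpy as jnp

# Edge list prepared with NumPy, e.g. loaded from a file
edges_np = np.array([[0, 1, 2], [1, 2, 0]], dtype=np.int64)

# Convert to a JAX array; request int32 explicitly, since JAX uses
# 32-bit types unless the jax_enable_x64 flag is set
edge_index = jnp.asarray(edges_np, dtype=jnp.int32)

# Convert back to NumPy, e.g. for plotting or other libraries
edges_back = np.asarray(edge_index)
print(edge_index.dtype, edges_back.shape)  # int32 (2, 3)
```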
JraphX integrates well with JAX's transformation system. You can use jax.jit to compile functions for better performance:
import jax
import jax.numpy as jnp
from jraphx.utils import degree
def compute_node_degrees(edge_index, num_nodes):
    """Compute node degrees using JAX."""
    return degree(edge_index[1], num_nodes)
# JIT compile with static num_nodes argument
jit_compute_degrees = jax.jit(compute_node_degrees, static_argnums=(1,))
degrees = jit_compute_degrees(data.edge_index, data.x.shape[0])
print(degrees)
>>> [1. 1. 1.]
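num_nodes has to be static because the length of the degree vector depends on it, and jit needs output shapes at trace time. Each distinct value also triggers a new compilation. The plain-JAX sketch below computes in-degrees with jax.ops.segment_sum. We do not assume that jraphx.utils.degree is built the same way.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_nodes")
def in_degree(edge_index, num_nodes):
    """Count incoming edges per node (target row of edge_index)."""
    ones = jnp.ones(edge_index.shape[1], dtype=jnp.float32)
    # num_segments fixes the output shape, so it must be a static Python int
    return jax.ops.segment_sum(ones, edge_index[1], num_segments=num_nodes)

edge_index = jnp.array([[0, 1, 2], [1, 2, 0]], dtype=jnp.int32)
print(in_degree(edge_index, num_nodes=3))  # [1. 1. 1.]
```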
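For processing multiple graphs efficiently, you can use jax.vmap. Since jax.vmap maps over a leading array axis, the graphs must have identical shapes and be stacked first:
# Create multiple graphs with identical shapes
graphs = [Data(x=jnp.ones((3, 2)), edge_index=jnp.array([[0, 1], [1, 0]]))
          for _ in range(5)]
# Stack the node features along a new leading (batch) axis
xs = jnp.stack([g.x for g in graphs])  # shape [5, 3, 2]
# Process the node features of a single graph
def process_single_graph(x):
    return jnp.sum(x)
# vmap over the stacked batch of graphs
batched_process = jax.vmap(process_single_graph)
results = batched_process(xs)
print(results)
>>> [6. 6. 6. 6. 6.]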
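The same idea extends to connectivity, as long as every graph has the same number of nodes and edges. This sketch assumes such equally sized (or padded) graphs. It stacks both node features and edge_index arrays and vmaps a per-graph function over them in plain JAX, without going through Data. Graphs of different sizes are usually combined with the disjoint-union batching described in the Mini-batches section instead.

```python
import jax
import jax.numpy as jnp

num_graphs, num_nodes = 5, 3

# vmap maps over a leading axis, so graphs must share shapes and be stacked
xs = jnp.ones((num_graphs, num_nodes, 2))                                  # [5, 3, 2]
edge_indices = jnp.tile(jnp.array([[0, 1], [1, 0]]), (num_graphs, 1, 1))  # [5, 2, 2]

def readout(x, edge_index):
    """Per-graph computation: node in-degrees and the sum of node features."""
    deg = jax.ops.segment_sum(jnp.ones(edge_index.shape[1]), edge_index[1],
                              num_segments=num_nodes)
    return deg, jnp.sum(x)

degs, sums = jax.vmap(readout)(xs, edge_indices)
print(degs.shape, sums.shape)  # (5, 3) (5,)
```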
Mini-batches
Neural networks are usually trained in a batch-wise fashion. JraphX achieves parallelization over a mini-batch by creating sparse block diagonal adjacency matrices (defined by edge_index) and concatenating feature and target matrices in the node dimension. This composition allows a differing number of nodes and edges across the examples in one batch.
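Written as matrices, consider n graphs with adjacency matrices A_1, ..., A_n, node feature matrices X_1, ..., X_n and node-level targets Y_1, ..., Y_n. The mini-batch then looks as follows:

```latex
\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix},
\qquad
\mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix},
\qquad
\mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}
```

All off-diagonal blocks are zero, so message passing never mixes nodes from different graphs. Because edge_index stores only the non-zero entries, the block diagonal matrix is never materialized.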
JraphX provides a jraphx.data.Batch class that handles this concatenation process. Let's learn about it in an example:
import jax.numpy as jnp
from jraphx.data import Data, Batch
from jraphx.nn.pool import global_mean_pool
# Create some example graphs
graphs = []
for i in range(3):
    x = jnp.ones((4, 2), dtype=jnp.float32) * (i + 1)
    edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=jnp.int32)
    graphs.append(Data(x=x, edge_index=edge_index))
# Create a batch from multiple graphs
batch = Batch.from_data_list(graphs)
print(batch)
>>> Batch(batch=[12], edge_index=[2, 12], x=[12, 2])
print(batch.num_graphs)
>>> 3
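The following simplified plain-JAX sketch shows the disjoint-union idea behind such a batch. It concatenates node features, shifts each graph's edge indices by the number of nodes that precede it, and records which graph each node came from. It is not JraphX's implementation and ignores all other attributes.

```python
import jax.numpy as jnp

# Three graphs of different sizes: (node features, edge_index) pairs
graphs = [
    (jnp.ones((2, 2)), jnp.array([[0, 1], [1, 0]], dtype=jnp.int32)),
    (jnp.ones((3, 2)), jnp.array([[0, 1, 2], [1, 2, 0]], dtype=jnp.int32)),
    (jnp.ones((1, 2)), jnp.zeros((2, 0), dtype=jnp.int32)),  # single isolated node
]

xs, edge_indices, batch_ids = [], [], []
offset = 0
for graph_id, (x, edge_index) in enumerate(graphs):
    xs.append(x)
    # Shift node indices so they point into the concatenated node matrix
    edge_indices.append(edge_index + offset)
    batch_ids.append(jnp.full((x.shape[0],), graph_id, dtype=jnp.int32))
    offset += x.shape[0]

x = jnp.concatenate(xs, axis=0)                     # [6, 2]
edge_index = jnp.concatenate(edge_indices, axis=1)  # [2, 5]
batch = jnp.concatenate(batch_ids)                  # [6]
print(edge_index.tolist())  # [[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]]
print(batch.tolist())       # [0, 0, 1, 1, 1, 2]
```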
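jraphx.data.Batch inherits from jraphx.data.Data and contains an additional attribute called batch. batch is a vector of shape [num_nodes] that maps each node to the index of its graph in the batch, i.e., batch = [0, ..., 0, 1, ..., 1, ..., n - 1, ..., n - 1] for a batch of n graphs. In the example above, it equals [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]. You can use it to, e.g., average node features in the node dimension for each graph individually: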
from jraphx.utils import scatter
# Average node features per graph
graph_embeddings = scatter(batch.x, batch.batch, dim_size=batch.num_graphs, dim=0, reduce='mean')
print(graph_embeddings.shape)
>>> (3, 2) # 3 graphs, 2 features each
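Per-graph averaging is a segment-wise sum divided by a segment-wise count. The sketch below writes out that computation with jax.ops.segment_sum for the batch above. It illustrates what mean pooling computes and is not JraphX's scatter code.

```python
import jax
import jax.numpy as jnp

# Node features of a batch with 3 graphs of 4 nodes each (graph i has value i + 1)
x = jnp.repeat(jnp.array([1.0, 2.0, 3.0]), 4)[:, None] * jnp.ones((1, 2))  # [12, 2]
batch = jnp.repeat(jnp.arange(3), 4)                                        # [12]
num_graphs = 3

# Mean pooling = per-graph sum divided by per-graph node count
sums = jax.ops.segment_sum(x, batch, num_segments=num_graphs)               # [3, 2]
counts = jax.ops.segment_sum(jnp.ones(x.shape[0]), batch, num_segments=num_graphs)
graph_embeddings = sums / jnp.maximum(counts, 1.0)[:, None]  # guard empty graphs
print(graph_embeddings.tolist())  # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```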
You can learn more about the internal batching procedure of JraphX, e.g., how to modify its behavior, here.
For documentation of scatter operations, see jraphx.utils.scatter.
Using Graph Convolution Layers
JraphX provides various graph neural network layers:
import jax.numpy as jnp
import flax.nnx as nnx
from jraphx.data import Data
from jraphx.nn.conv import GCNConv, GATConv, SAGEConv
# Example graph: 3 nodes with 3 features each (matching in_features=3 below)
data = Data(x=jnp.ones((3, 3), dtype=jnp.float32),
            edge_index=jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=jnp.int32))
# Initialize random number generator
rngs = nnx.Rngs(42)
# Graph Convolutional Network (GCN)
gcn = GCNConv(in_features=3, out_features=16, rngs=rngs)
out = gcn(data.x, data.edge_index)
# Graph Attention Network (GAT)
gat = GATConv(in_features=3, out_features=16, heads=4, rngs=rngs)
out = gat(data.x, data.edge_index)
# GraphSAGE
sage = SAGEConv(in_features=3, out_features=16, rngs=rngs)
out = sage(data.x, data.edge_index)
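These layers all build on message passing: each node gathers messages from its neighbors along edge_index and aggregates them. The plain-JAX sketch below shows one round with a single explicit weight matrix and mean aggregation, roughly in the spirit of a SAGE-style neighbor mean. The function name is hypothetical. The real layers add more, such as biases, root weights, normalization or attention.

```python
import jax
import jax.numpy as jnp

def mean_aggregate(x, edge_index, weight):
    """One round of message passing with mean aggregation (illustrative only)."""
    src, dst = edge_index[0], edge_index[1]
    num_nodes = x.shape[0]
    messages = x[src] @ weight  # one message per edge
    summed = jax.ops.segment_sum(messages, dst, num_segments=num_nodes)
    counts = jax.ops.segment_sum(jnp.ones_like(dst, dtype=x.dtype), dst,
                                 num_segments=num_nodes)
    return summed / jnp.maximum(counts, 1.0)[:, None]  # nodes without neighbors get zeros

x = jnp.eye(3)  # 3 nodes with one-hot features
edge_index = jnp.array([[0, 1, 1, 2], [1, 0, 2, 1]])
weight = jax.random.normal(jax.random.key(0), (3, 16))
out = mean_aggregate(x, edge_index, weight)
print(out.shape)  # (3, 16)
```

Transforming before gathering (x[src] @ weight) keeps the sketch short. Computing x @ weight once per node and then gathering is cheaper when there are many more edges than nodes.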
Building a Complete GNN Model
Combine multiple layers to create a complete GNN model:
import jax
import flax.nnx as nnx
from jraphx.nn.conv import GCNConv
from jraphx.nn.pool import global_mean_pool
class GNN(nnx.Module):
    def __init__(self, in_features, hidden_features, out_features, rngs):
        self.conv1 = GCNConv(in_features, hidden_features, rngs=rngs)
        self.conv2 = GCNConv(hidden_features, hidden_features, rngs=rngs)
        self.conv3 = GCNConv(hidden_features, out_features, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
    def __call__(self, x, edge_index, batch=None):
        # First GCN layer
        x = self.conv1(x, edge_index)
        x = nnx.relu(x)
        x = self.dropout(x)
        # Second GCN layer
        x = self.conv2(x, edge_index)
        x = nnx.relu(x)
        x = self.dropout(x)
        # Third GCN layer
        x = self.conv3(x, edge_index)
        # Global pooling (for graph-level prediction)
        if batch is not None:
            x = global_mean_pool(x, batch)
        return x
# Create model
model = GNN(in_features=3, hidden_features=64, out_features=10, rngs=nnx.Rngs(42))
# Forward pass
output = model(data.x, data.edge_index)
Model Inspection with nnx.tabulate
JraphX models can be inspected with NNX's nnx.tabulate:
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.models import GAT
# Create model
model = GAT(in_features=32, hidden_features=64, out_features=16,
            heads=4, num_layers=2, rngs=nnx.Rngs(42))
x = jnp.ones((50, 32))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
# Inspect complete model structure and parameters
print(nnx.tabulate(model, x, edge_index, depth=2))
This shows the layer hierarchy, parameter counts, input/output shapes, and memory usage, which helps you understand complex GNN architectures before training.
Train/Eval Modes
NNX provides train/eval mode switching for models with dropout or batch normalization:
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.models import GraphSAGE
# Create model with dropout
model = GraphSAGE(in_features=16, hidden_features=32, out_features=8,
                  num_layers=2, dropout_rate=0.5, rngs=nnx.Rngs(42))
model.train()  # Set to training mode
# Example inputs with 16 features per node to match in_features
x = jnp.ones((50, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
# Create evaluation model that shares weights
eval_model = nnx.merge(*nnx.split(model))  # Same weights, different behavior
eval_model.eval()  # Set to evaluation mode
# Both models share weights but behave differently
train_out = model(x, edge_index)  # Uses dropout
eval_out = eval_model(x, edge_index)  # No dropout
# Weights stay synchronized automatically - no copying needed!
print("Weights shared:", jnp.allclose(
    model.convs[0].linear.kernel.value,
    eval_model.convs[0].linear.kernel.value
))
>>> Weights shared: True
For more details, see the Flax documentation for nnx.Module.train() and nnx.Module.eval().
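To make "uses dropout" versus "no dropout" concrete, here is a plain-JAX sketch of inverted dropout, the scheme Flax's dropout layers follow. It is not Flax's source code, and the function name is ours. In training mode, kept units are rescaled by 1 / (1 - rate). In evaluation mode, the input passes through unchanged.

```python
import jax
import jax.numpy as jnp

def dropout(x, rate, key, deterministic):
    """Inverted dropout: identity in eval mode, random masking in train mode."""
    if deterministic or rate == 0.0:
        return x
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # Scale kept units so the expected activation matches eval mode
    return jnp.where(keep, x / (1.0 - rate), 0.0)

x = jnp.ones((4, 8))
key = jax.random.key(0)
train_out = dropout(x, 0.5, key, deterministic=False)  # entries are 0.0 or 2.0
eval_out = dropout(x, 0.5, key, deterministic=True)    # identical to x
```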
Training a GNN
Here’s a simple training loop example:
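import jax
import jax.numpy as jnp
import optax
from flax import nnx
# Placeholder inputs for the GraphSAGE model above (16 input features, 8 classes)
x = jnp.ones((50, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
labels = jax.nn.one_hot(jnp.arange(50) % 8, 8)  # one-hot targets of shape [50, 8]
# Create optimizer
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.01), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, x, edge_index, labels):
    def loss_fn(model):
        logits = model(x, edge_index)
        # softmax_cross_entropy expects one-hot (or probability) targets
        loss = optax.softmax_cross_entropy(logits, labels).mean()
        return loss
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
# Set training mode once, outside the compiled step
model.train()
# Training loop
for epoch in range(100):
    loss = train_step(model, optimizer, x, edge_index, labels)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")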
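optax.softmax_cross_entropy expects one-hot (or probability) targets. The node classification example later on this page uses optax.softmax_cross_entropy_with_integer_labels instead. Written out in plain JAX, both compute the same per-example loss. Because log_softmax is idempotent, feeding log-probabilities instead of logits also leaves the loss unchanged.

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, one_hot_labels):
    # Cross-entropy with one-hot targets, written out explicitly
    return -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
int_labels = jnp.array([0, 2])
loss_onehot = softmax_cross_entropy(logits, jax.nn.one_hot(int_labels, 3))

# Integer-label form: pick the log-probability of the true class directly
log_probs = jax.nn.log_softmax(logits, axis=-1)
loss_int = -jnp.take_along_axis(log_probs, int_labels[:, None], axis=-1)[:, 0]

# Feeding log-probabilities instead of logits gives the same loss
loss_from_log_probs = softmax_cross_entropy(log_probs, jax.nn.one_hot(int_labels, 3))
```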
Data Preprocessing with JAX
JraphX leverages JAX’s functional programming approach for data preprocessing. You can create pure functions to preprocess your data:
import jax
import jax.numpy as jnp
from jraphx.data import Data
from jraphx.utils import add_self_loops
@jax.jit
def preprocess_graph(data):
    """Add self-loops and normalize features."""
    # Add self-loops
    edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.shape[0])
    # Normalize node features to unit L2 norm (all-zero rows would become NaN)
    x_normalized = data.x / jnp.linalg.norm(data.x, axis=1, keepdims=True)
    return data.replace(x=x_normalized, edge_index=edge_index)
# Apply preprocessing
original_data = Data(x=jnp.ones((3, 2)), edge_index=jnp.array([[0, 1], [1, 2]]))
processed_data = preprocess_graph(original_data)
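The following plain-JAX sketch shows what adding self-loops amounts to: one (i, i) edge per node. It also uses a normalization that avoids dividing by zero for all-zero feature rows. The function name add_self_loops_and_normalize is hypothetical. We assume, without relying on it, that JraphX's add_self_loops appends loops the same way.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_self_loops_and_normalize(x, edge_index):
    num_nodes = x.shape[0]  # shapes are static under jit
    loops = jnp.arange(num_nodes, dtype=edge_index.dtype)
    edge_index = jnp.concatenate([edge_index, jnp.stack([loops, loops])], axis=1)
    norms = jnp.linalg.norm(x, axis=1, keepdims=True)
    x = x / jnp.where(norms > 0, norms, 1.0)  # leave all-zero rows unchanged
    return x, edge_index

x = jnp.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
edge_index = jnp.array([[0, 1], [1, 2]], dtype=jnp.int32)
x_norm, edge_index_loops = add_self_loops_and_normalize(x, edge_index)
print(edge_index_loops.tolist())  # [[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]]
```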
For more complex preprocessing pipelines, you can compose functions:
from flax import nnx
def add_positional_encoding(data, rngs, dim=16):
    """Add random positional encoding to nodes."""
    pos_enc = rngs.normal((data.x.shape[0], dim))  # Rngs shorthand method (Flax 0.11.2+)
    x_with_pos = jnp.concatenate([data.x, pos_enc], axis=1)
    return data.replace(x=x_with_pos)
def preprocessing_pipeline(data, rngs):
    """Full preprocessing pipeline."""
    data = preprocess_graph(data)
    data = add_positional_encoding(data, rngs)
    return data
# Apply full pipeline with random number generator
rngs = nnx.Rngs(42)  # Can also use: rngs = nnx.Rngs(0, params=1)
final_data = preprocessing_pipeline(original_data, rngs)
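nnx.Rngs manages JAX's explicit PRNG keys for you. For comparison, here is a plain-JAX sketch of the same random positional encoding using a raw key; the function name is hypothetical. Reusing a key reproduces exactly the same numbers, so you should split off a fresh subkey for every new random draw.

```python
import jax
import jax.numpy as jnp

def add_random_positional_encoding(x, key, dim=16):
    """Append dim random features per node (hypothetical helper)."""
    pos_enc = jax.random.normal(key, (x.shape[0], dim))
    return jnp.concatenate([x, pos_enc], axis=1)

key = jax.random.key(42)
key, subkey = jax.random.split(key)  # consume a fresh subkey per random draw
x = jnp.ones((3, 2))
x_pe = add_random_positional_encoding(x, subkey)
x_pe_again = add_random_positional_encoding(x, subkey)  # same key, same encoding
```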
Learning Methods on Graphs
After learning about data handling and preprocessing in JraphX, it’s time to implement our first graph neural network!
We will use a simple GCN layer implemented with JAX and Flax NNX. For a high-level explanation on GCN, have a look at its blog post.
Let’s create some example graph data:
import jax.numpy as jnp
from jraphx.data import Data
# Create a simple graph with 4 nodes, 3 features per node, 3 classes
x = jnp.array([[1.0, 0.5, 0.2], [0.8, 1.0, 0.1], [0.3, 0.7, 1.0], [0.9, 0.2, 0.8]], dtype=jnp.float32)
edge_index = jnp.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=jnp.int32) # Undirected edges
y = jnp.array([0, 0, 1, 1], dtype=jnp.int32) # Node labels
data = Data(x=x, edge_index=edge_index, y=y)
print(f"Graph: {data.num_nodes} nodes, {data.num_edges} edges")
>>> Graph: 4 nodes, 6 edges
Now let’s implement a two-layer GCN using Flax NNX:
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.conv import GCNConv
class GCN(nnx.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, *, rngs: nnx.Rngs):
        self.conv1 = GCNConv(in_features, hidden_features, rngs=rngs)
        self.conv2 = GCNConv(hidden_features, out_features, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)
    def __call__(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = nnx.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return nnx.log_softmax(x, axis=1)
# Create model
model = GCN(in_features=3, hidden_features=16, out_features=3, rngs=nnx.Rngs(42))
The model defines two GCNConv layers which get called in sequence. Note that the non-linearity is not integrated in the conv calls and hence needs to be applied afterwards (consistent with the JraphX design). Here, we use ReLU as our intermediate non-linearity and output a log-softmax distribution over classes.
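A GCN layer is usually defined by the propagation rule of Kipf and Welling, X' = D^(-1/2) (A + I) D^(-1/2) X W, where D is the degree matrix of A + I. We assume that GCNConv follows this rule, possibly with extras such as a bias or edge weights. The plain-JAX sketch below computes the rule edge by edge and cross-checks it against the dense matrix formula on the example graph. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp

def gcn_propagate(x, edge_index, weight):
    """X' = D^-1/2 (A + I) D^-1/2 X W, computed edge-wise (illustrative sketch)."""
    num_nodes = x.shape[0]
    loops = jnp.arange(num_nodes, dtype=edge_index.dtype)
    src = jnp.concatenate([edge_index[0], loops])  # add self-loops: A + I
    dst = jnp.concatenate([edge_index[1], loops])
    # Degree of every node in A + I, counted over incoming edges
    deg = jax.ops.segment_sum(jnp.ones(src.shape[0], dtype=x.dtype), dst,
                              num_segments=num_nodes)
    coef = jax.lax.rsqrt(deg[src] * deg[dst])  # 1 / sqrt(d_j * d_i) for edge j -> i
    h = x @ weight                             # transform features, then propagate
    return jax.ops.segment_sum(coef[:, None] * h[src], dst, num_segments=num_nodes)

x = jnp.array([[1.0, 0.5, 0.2], [0.8, 1.0, 0.1], [0.3, 0.7, 1.0], [0.9, 0.2, 0.8]])
edge_index = jnp.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
weight = jax.random.normal(jax.random.key(0), (3, 16))
out = gcn_propagate(x, edge_index, weight)

# Dense version of the same formula on this small graph
a_hat = jnp.zeros((4, 4)).at[edge_index[1], edge_index[0]].set(1.0) + jnp.eye(4)
d_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(a_hat.sum(axis=1)))
dense = d_inv_sqrt @ a_hat @ d_inv_sqrt @ (x @ weight)
```

The edge-wise form never builds the dense adjacency matrix, which is what makes it practical for large sparse graphs.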
Let’s create a simple training function using JAX:
import optax
def loss_fn(model, data, train_mask):
    """Compute cross-entropy loss on training nodes."""
    logits = model(data)
    # Per-node loss averaged over training nodes via the mask; boolean
    # indexing such as logits[train_mask] is not allowed inside nnx.jit
    losses = optax.softmax_cross_entropy_with_integer_labels(logits, data.y)
    return jnp.sum(losses * train_mask) / jnp.sum(train_mask)
# Setup optimizer
optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)
# Training loop
train_mask = jnp.array([True, True, False, False])  # First 2 nodes for training
test_mask = jnp.array([False, False, True, True])  # Last 2 nodes for testing
@nnx.jit
def train_step(model, optimizer, data, train_mask):
    def loss_fn_inner(model):
        return loss_fn(model, data, train_mask)
    loss, grads = nnx.value_and_grad(loss_fn_inner)(model)
    optimizer.update(model, grads)
    return loss
# Train for a few epochs
model.train()
for epoch in range(200):
    loss = train_step(model, optimizer, data, train_mask)
    if epoch % 50 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')
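Boolean indexing such as logits[train_mask] fails inside a jax.jit or nnx.jit function. The result's shape depends on the mask's values, which are unknown at trace time. Weighting by the mask keeps every shape static. The following plain-JAX sketch shows both behaviors.

```python
import jax
import jax.numpy as jnp

per_node_loss = jnp.array([0.2, 0.4, 1.0, 3.0])
train_mask = jnp.array([True, True, False, False])

@jax.jit
def masked_mean(values, mask):
    # Shapes stay static: unselected entries are zeroed instead of dropped
    return jnp.sum(jnp.where(mask, values, 0.0)) / jnp.maximum(jnp.sum(mask), 1)

@jax.jit
def boolean_index_mean(values, mask):
    # values[mask] has a data-dependent shape, which jit cannot trace
    return values[mask].mean()

print(masked_mean(per_node_loss, train_mask))  # about 0.3, the mean of the first two entries

try:
    boolean_index_mean(per_node_loss, train_mask)
    jit_rejected_boolean_index = False
except jax.errors.NonConcreteBooleanIndexError:
    jit_rejected_boolean_index = True
print(jit_rejected_boolean_index)  # True
```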
Finally, we can evaluate our model:
@nnx.jit
def evaluate(model, data, test_mask):
    """Evaluate model accuracy on test nodes."""
    logits = model(data)
    pred = jnp.argmax(logits, axis=1)
    # Count correct test predictions without boolean indexing (jit-safe)
    correct = jnp.sum((pred == data.y) & test_mask)
    accuracy = correct / jnp.sum(test_mask)
    return accuracy
model.eval()
test_accuracy = evaluate(model, data, test_mask)
print(f'Test Accuracy: {test_accuracy:.4f}')
>>> Test Accuracy: 0.5000 # Small dataset, results may vary
This is all it takes to implement your first graph neural network with JraphX!
The key advantages of using JAX/Flax NNX are automatic differentiation, JIT compilation for speed, and functional programming patterns.
The easiest way to learn more about Graph Neural Networks is to browse jraphx.nn and experiment with different layer combinations.
Exercises
1. What does edge_index.T do in JAX? How is it different from PyTorch's edge_index.t().contiguous()?
2. Create a function that generates a random graph with n nodes and m edges using JAX arrays. Make sure the function is JIT-compilable.
3. What does each number of the following output mean?
   print(batch)
   >>> Batch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
4. Implement a preprocessing function using @jax.jit that adds self-loops to a graph and normalizes node features. Test it on a simple graph.
5. Create a batched version of the GCN model that can process multiple graphs simultaneously using nnx.vmap.