Advanced Techniques
This guide covers advanced features and techniques for optimizing JraphX applications.
Memory-Efficient Training
Using JAX Scan for Sequential Processing
For large graphs or long training sequences, use jax.lax.scan to reduce memory consumption:
import jax
import jax.numpy as jnp
import optax
from flax import nnx

def efficient_training_loop(model, optimizer, batched_data):
    """Memory-efficient training using jax.lax.scan.

    batched_data is a pytree of stacked arrays (e.g. the x, edge_index and
    y fields of same-shaped graphs) whose leading axis indexes the training
    steps; see the preparation sketch below.
    """
    # Split the NNX objects into a static graph definition and a pytree of
    # state that jax.lax.scan can carry.
    graphdef, state = nnx.split((model, optimizer))

    def step_fn(state, batch):
        model, optimizer = nnx.merge(graphdef, state)

        def loss_fn(model):
            logits = model(batch["x"], batch["edge_index"])
            return optax.softmax_cross_entropy_with_integer_labels(
                logits, batch["y"]
            ).mean()

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)
        return nnx.state((model, optimizer)), loss

    # Run the whole training loop as a single compiled scan.
    state, losses = jax.lax.scan(step_fn, state, batched_data)

    # Write the final parameters and optimizer state back into the objects.
    nnx.update((model, optimizer), state)
    return model, optimizer, losses
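The scan expects its inputs pre-stacked along a leading step axis. Below is a minimal, hypothetical preparation helper (stack_for_scan is not part of JraphX), assuming every graph has already been padded to the same number of nodes and edges:

import jax
import jax.numpy as jnp

def stack_for_scan(data_list, num_epochs=1):
    """Stack same-shaped graphs into arrays with a leading step axis."""
    batched = {
        "x": jnp.stack([d.x for d in data_list]),
        "edge_index": jnp.stack([d.edge_index for d in data_list]),
        "y": jnp.stack([d.y for d in data_list]),
    }
    # Tile the step axis to cover several epochs.
    return jax.tree_util.tree_map(
        lambda a: jnp.tile(a, (num_epochs,) + (1,) * (a.ndim - 1)), batched
    )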
Gradient Checkpointing
For very deep GNN models, use gradient checkpointing to trade compute for memory:
import flax.nnx as nnx
from jax import checkpoint
from jraphx.nn.conv import GCNConv

class DeepGNN(nnx.Module):
    def __init__(self, num_layers, hidden_dim, rngs: nnx.Rngs):
        self.layers = [
            GCNConv(hidden_dim, hidden_dim, rngs=rngs)
            for _ in range(num_layers)
        ]

    def __call__(self, x, edge_index):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0:
                # Checkpoint every other layer: its activations are
                # recomputed during the backward pass instead of stored.
                x = checkpoint(layer)(x, edge_index)
            else:
                x = layer(x, edge_index)
            x = nnx.relu(x)
        return x
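A quick usage sketch with a toy graph (the sizes and RNG seed are arbitrary):

import jax.numpy as jnp
import flax.nnx as nnx

# Toy graph: 4 nodes with 64-dim features and 4 directed edges.
x = jnp.ones((4, 64))
edge_index = jnp.array([[0, 1, 2, 3],
                        [1, 2, 3, 0]])

model = DeepGNN(num_layers=8, hidden_dim=64, rngs=nnx.Rngs(0))
out = model(x, edge_index)  # shape: [4, 64]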
Vectorized Graph Processing with vmap
Process multiple graphs in parallel using JAX’s vmap:
import flax.nnx as nnx
from jraphx.data.vmap_batch import pad_graph_data

def process_single_graph(data, model):
    """Process a single graph."""
    return model(data.x, data.edge_index)

# Vectorize over the leading (graph) axis of the data; the model is shared.
process_batch = nnx.vmap(process_single_graph, in_axes=(0, None))

# Pad graphs to the same size so they can be stacked for vmap.
padded_graphs = pad_graph_data(graph_list, max_nodes=100, max_edges=200)

# Process all graphs in parallel.
outputs = process_batch(padded_graphs, model)
Custom vmap Patterns
def custom_vmap_aggregation(graphs, model):
    """Mean-pool each graph, then average the per-graph embeddings.

    graphs is a stacked/padded batch with a leading graph axis, as
    produced by pad_graph_data above; model is closed over.
    """
    def per_graph_op(graph):
        node_features = model(graph.x, graph.edge_index)
        # Mean-pool node features into a single graph embedding.
        return node_features.mean(axis=0)

    # Vectorize over the leading graph axis and apply.
    vmapped_op = nnx.vmap(per_graph_op)
    graph_features = vmapped_op(graphs)

    # Further processing across the whole batch of graphs.
    return graph_features.mean(axis=0)
Custom Message Passing Implementations
Implementing Edge-Conditioned Convolutions
import jax.numpy as jnp
import flax.nnx as nnx
from jraphx.nn.conv import MessagePassing

class EdgeConditionedConv(MessagePassing):
    """Message passing with edge features."""

    def __init__(self, in_features, out_features, edge_dim, rngs: nnx.Rngs):
        super().__init__(aggr='mean')
        self.node_mlp = nnx.Sequential(
            nnx.Linear(in_features * 2 + edge_dim, out_features, rngs=rngs),
            nnx.relu,
            nnx.Linear(out_features, out_features, rngs=rngs),
        )

    def message(self, x_i, x_j, edge_attr):
        # Concatenate target (x_i), source (x_j) and edge features.
        msg = jnp.concatenate([x_i, x_j, edge_attr], axis=-1)
        return self.node_mlp(msg)

    def __call__(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
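A small usage sketch (the feature sizes and toy graph below are arbitrary):

import jax.numpy as jnp
import flax.nnx as nnx

# Toy graph: 3 nodes, 3 directed edges, 4-dim edge features.
x = jnp.ones((3, 16))
edge_index = jnp.array([[0, 1, 2],
                        [1, 2, 0]])
edge_attr = jnp.ones((3, 4))

conv = EdgeConditionedConv(in_features=16, out_features=32, edge_dim=4,
                           rngs=nnx.Rngs(0))
out = conv(x, edge_index, edge_attr)  # shape: [3, 32]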
Implementing Attention Mechanisms
from jraphx.utils.scatter import scatter_max, scatter_sum

class CustomAttentionConv(MessagePassing):
    """Custom attention-based message passing."""

    def __init__(self, in_features, out_features, heads=4, rngs=None):
        super().__init__(aggr='add')
        self.heads = heads
        self.out_features = out_features
        self.W_q = nnx.Linear(in_features, heads * out_features, rngs=rngs)
        self.W_k = nnx.Linear(in_features, heads * out_features, rngs=rngs)
        self.W_v = nnx.Linear(in_features, heads * out_features, rngs=rngs)

    def message(self, x_i, x_j, edge_index_i, size_i):
        # Multi-head projections: [num_edges, heads, out_features].
        Q = self.W_q(x_i).reshape(-1, self.heads, self.out_features)
        K = self.W_k(x_j).reshape(-1, self.heads, self.out_features)
        V = self.W_v(x_j).reshape(-1, self.heads, self.out_features)

        # Per-edge, per-head attention scores.
        scores = (Q * K).sum(axis=-1) / jnp.sqrt(self.out_features)

        # Normalize scores per target node (segment softmax) rather than
        # over all edges at once, using the scatter utilities (assumed to
        # aggregate along the leading axis, as in the section below).
        scores_max = scatter_max(scores, edge_index_i, dim_size=size_i)
        scores = jnp.exp(scores - scores_max[edge_index_i])
        denom = scatter_sum(scores, edge_index_i, dim_size=size_i)
        alpha = scores / (denom[edge_index_i] + 1e-16)

        # Apply attention weights to the values.
        return (alpha[..., None] * V).reshape(-1, self.heads * self.out_features)

    def __call__(self, x, edge_index):
        return self.propagate(edge_index, x=x)
Performance Optimization
Efficient Scatter Operations
import jax.numpy as jnp
from jraphx.utils.scatter import scatter_sum, scatter_max

def efficient_aggregation(src, index, dim_size):
    """Compute several aggregations over the same index in one place."""
    sum_result = scatter_sum(src, index, dim_size=dim_size)
    max_result = scatter_max(src, index, dim_size=dim_size)

    # Derive the mean from the sum and per-segment counts instead of a
    # separate scatter_mean pass; guard against empty segments.
    counts = scatter_sum(jnp.ones_like(src), index, dim_size=dim_size)
    mean_result = sum_result / jnp.maximum(counts, 1)

    return {
        'sum': sum_result,
        'max': max_result,
        'mean': mean_result,
    }
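For intuition, a tiny worked example (values are arbitrary; the expected outputs assume the scatter utilities aggregate along the leading axis, as in PyTorch Geometric):

import jax.numpy as jnp

src = jnp.array([[1.0], [2.0], [3.0], [4.0]])  # four rows of features
index = jnp.array([0, 0, 1, 1])                # segment id of each row

aggs = efficient_aggregation(src, index, dim_size=2)
# aggs['sum']  -> [[3.0], [7.0]]
# aggs['max']  -> [[2.0], [4.0]]
# aggs['mean'] -> [[1.5], [3.5]]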
Working with Dynamic Graphs
Handling Variable-Size Graphs
import jax.numpy as jnp

def process_dynamic_graphs(graphs, max_nodes=1000):
    """Pad (or truncate) graphs to a fixed node count and build node masks."""
    def process_single(graph):
        current_nodes = graph.x.shape[0]
        if current_nodes < max_nodes:
            # Pad node features with zero rows up to max_nodes.
            pad_size = max_nodes - current_nodes
            x_padded = jnp.pad(
                graph.x,
                ((0, pad_size), (0, 0)),
                mode='constant'
            )
            # 1 for real nodes, 0 for padding.
            mask = jnp.concatenate([
                jnp.ones(current_nodes),
                jnp.zeros(pad_size)
            ])
        else:
            # Graphs larger than max_nodes are truncated here.
            x_padded = graph.x[:max_nodes]
            mask = jnp.ones(max_nodes)
        return x_padded, mask

    # Process each graph; shapes are concrete per graph, so plain Python
    # control flow is fine outside of jit.
    processed = [process_single(g) for g in graphs]
    return processed
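The mask is what keeps padded rows from polluting later aggregations. A minimal sketch of a masked mean-pool for one padded graph:

import jax.numpy as jnp

def masked_mean_pool(x_padded, mask):
    """Mean over real nodes only, ignoring zero-padded rows."""
    total = (x_padded * mask[:, None]).sum(axis=0)
    return total / jnp.maximum(mask.sum(), 1.0)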
Dynamic Edge Construction
import jax
import jax.numpy as jnp

def construct_knn_graph(x, k=10):
    """Dynamically construct a k-NN graph from node features."""
    # Pairwise squared Euclidean distances: [num_nodes, num_nodes].
    dist_matrix = jnp.sum((x[:, None] - x[None, :]) ** 2, axis=-1)

    # k smallest distances per node. Note that each node's zero distance
    # to itself is included, so the graph contains self-loops; use k + 1
    # and drop the first neighbor if that is not wanted.
    _, indices = jax.lax.top_k(-dist_matrix, k)

    # Construct the [2, num_nodes * k] edge index.
    num_nodes = x.shape[0]
    source = jnp.repeat(jnp.arange(num_nodes), k)
    target = indices.flatten()
    edge_index = jnp.stack([source, target])
    return edge_index
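For example, the constructed edges can be wrapped into a graph object (a sketch assuming jraphx's Data accepts x and edge_index keyword arguments, as in PyTorch Geometric):

import jax
import jax.numpy as jnp
from jraphx.data import Data

points = jax.random.normal(jax.random.key(0), (50, 3))  # 50 points in 3-D
edge_index = construct_knn_graph(points, k=10)
graph = Data(x=points, edge_index=edge_index)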
Distributed Training
Data Parallel Training
import jax
import optax
from flax import nnx
from jax import pmap

def distributed_train_step(model, optimizer, batch):
    """Single training step for data-parallel training."""
    def loss_fn(model):
        logits = model(batch.x, batch.edge_index)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch.y
        ).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Average gradients across devices before applying the update.
    grads = jax.tree_util.tree_map(lambda g: jax.lax.pmean(g, 'batch'), grads)
    optimizer.update(model, grads)
    return loss

# Parallelize across devices; model, optimizer state and batches must be
# replicated/sharded with a leading device axis.
parallel_train_step = pmap(distributed_train_step, axis_name='batch')
Model Parallel GNNs
def model_parallel_gnn(x, edge_index, num_devices=2):
    """Split GNN layers across devices (layer1..layer4 are defined elsewhere,
    with their parameters placed on the corresponding device)."""
    devices = jax.devices()[:num_devices]

    # First half of the layers on device 0.
    x = jax.device_put(x, devices[0])
    edge_index = jax.device_put(edge_index, devices[0])
    x = layer1(x, edge_index)
    x = layer2(x, edge_index)

    # Move the activations and run the second half on device 1.
    x = jax.device_put(x, devices[1])
    edge_index = jax.device_put(edge_index, devices[1])
    x = layer3(x, edge_index)
    x = layer4(x, edge_index)
    return x
Advanced Pooling Strategies
Differentiable Pooling
import flax.nnx as nnx
from jraphx.nn.conv import GCNConv
# to_dense_adj / to_edge_index are assumed PyG-style dense/sparse utilities.

class DiffPool(nnx.Module):
    """Differentiable pooling layer (single-graph sketch)."""

    def __init__(self, in_features, ratio=0.5, rngs=None):
        # The number of clusters is a hyperparameter; here it is tied to
        # the feature size via `ratio` for simplicity.
        num_clusters = int(in_features * ratio)
        self.pool_gnn = GCNConv(in_features, num_clusters, rngs=rngs)
        self.embed_gnn = GCNConv(in_features, in_features, rngs=rngs)

    def __call__(self, x, edge_index, batch=None):
        # Soft cluster assignments S: [num_nodes, num_clusters].
        s = self.pool_gnn(x, edge_index)
        s = nnx.softmax(s, axis=-1)

        # Embed nodes, then pool features: X' = S^T Z.
        z = self.embed_gnn(x, edge_index)
        x_pooled = s.T @ z

        # Coarsened adjacency A' = S^T A S (assumes a dense
        # [num_nodes, num_nodes] adjacency).
        adj = to_dense_adj(edge_index)
        adj_pooled = s.T @ adj @ s

        # Convert the dense coarsened adjacency back to an edge index.
        edge_index_pooled = to_edge_index(adj_pooled)
        return x_pooled, edge_index_pooled, s
Best Practices Summary
Memory Management

- Use jax.lax.scan for sequential operations
- Apply gradient checkpointing for deep models
- Batch graphs efficiently with padding

Performance

- JIT compile performance-critical functions
- Use static arguments for conditional logic (see the sketch below)
- Leverage vmap for parallel processing

Scalability

- Implement data parallel training with pmap
- Use model parallelism for very large models
- Consider dynamic batching for variable-size graphs

Debugging

- Use jax.debug.print inside JIT-compiled functions
- Check shapes with jax.debug.breakpoint
- Profile with jax.profiler
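A minimal sketch of the static-argument and jax.debug.print tips above (the function and its arguments are arbitrary examples):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=2)
def normalize(x, eps, use_l2):
    # use_l2 is static, so plain Python control flow is allowed here.
    jax.debug.print("mean |x| = {m}", m=jnp.abs(x).mean())
    if use_l2:
        norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
    else:
        norm = jnp.abs(x).sum(axis=-1, keepdims=True)
    return x / (norm + eps)

out = normalize(jnp.ones((4, 8)), 1e-6, True)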
See Also
jraphx.nn - Neural network layer reference
JAX Integration with JraphX - Advanced JAX integration tutorial
JAX Documentation - JAX performance guide