Advanced Techniques

This guide covers advanced features and techniques for optimizing JraphX applications.

Memory-Efficient Training

Using JAX Scan for Sequential Processing

For large graphs or long training sequences, use jax.lax.scan to reduce memory consumption:

import jax
import jax.numpy as jnp
import optax
import flax.nnx as nnx

def efficient_training_loop(model, optimizer, batches, num_epochs):
    """Memory-efficient training using scan.

    `batches` is a tuple of arrays (x, edge_index, y), each stacked along a
    leading step axis. Every step must have the same shape because
    jax.lax.scan requires uniform inputs across iterations.
    """
    # Split the NNX objects into a static graph definition and a pytree
    # state that jax.lax.scan can carry.
    graphdef, state = nnx.split((model, optimizer))

    def train_step(state, batch):
        model, optimizer = nnx.merge(graphdef, state)
        x, edge_index, y = batch

        def loss_fn(model):
            logits = model(x, edge_index)
            return optax.softmax_cross_entropy_with_integer_labels(
                logits, y
            ).mean()

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)

        _, state = nnx.split((model, optimizer))
        return state, loss

    # Scan over the batches once per epoch; only one step's activations
    # are kept in memory at a time.
    all_losses = []
    for _ in range(num_epochs):
        state, losses = jax.lax.scan(train_step, state, batches)
        all_losses.append(losses)

    model, optimizer = nnx.merge(graphdef, state)
    return model, optimizer, jnp.concatenate(all_losses)
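
Because scan requires identically shaped inputs at every step, the stacked arrays have to come from pre-padded mini-batches. A minimal preparation sketch, assuming each batch has already been padded to the same number of nodes and edges (the padded_batches name is illustrative):

import jax.numpy as jnp

# padded_batches: a list of Batch objects with identical shapes
batches = (
    jnp.stack([b.x for b in padded_batches]),
    jnp.stack([b.edge_index for b in padded_batches]),
    jnp.stack([b.y for b in padded_batches]),
)

model, optimizer, losses = efficient_training_loop(
    model, optimizer, batches, num_epochs=100
)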

Gradient Checkpointing

For very deep GNN models, use gradient checkpointing to trade compute for memory:

from jax import checkpoint
import flax.nnx as nnx
from jraphx.nn.conv import GCNConv

class DeepGNN(nnx.Module):
    def __init__(self, num_layers, hidden_dim, rngs):
        self.layers = [
            GCNConv(hidden_dim, hidden_dim, rngs=rngs)
            for _ in range(num_layers)
        ]

    def __call__(self, x, edge_index):
        for i, layer in enumerate(self.layers):
            # Checkpoint every other layer
            if i % 2 == 0:
                x = checkpoint(layer)(x, edge_index)
            else:
                x = layer(x, edge_index)
            x = nnx.relu(x)
        return x

Vectorized Graph Processing with vmap

Process multiple graphs in parallel using JAX’s vmap:

import flax.nnx as nnx
from jraphx.data.vmap_batch import pad_graph_data

def process_single_graph(data, model):
    """Process a single graph."""
    return model(data.x, data.edge_index)

# Vectorize over a batch of graphs
process_batch = nnx.vmap(process_single_graph, in_axes=(0, None))

# Pad graphs to same size for vmap
padded_graphs = pad_graph_data(graph_list, max_nodes=100, max_edges=200)

# Process all graphs in parallel
outputs = process_batch(padded_graphs, model)

Custom vmap Patterns

def custom_vmap_aggregation(graphs, model):
    """Custom vmap pattern for graph aggregation."""

    # Per-graph operation; the model is passed explicitly so nnx.vmap
    # can broadcast it instead of capturing it in a closure.
    def per_graph_op(model, graph):
        node_features = model(graph.x, graph.edge_index)
        # Mean-pool node features into a single graph embedding
        return node_features.mean(axis=0)

    # Vectorize over the leading graph axis; broadcast the model
    vmapped_op = nnx.vmap(per_graph_op, in_axes=(None, 0))
    graph_features = vmapped_op(model, graphs)

    # Aggregate the per-graph embeddings across the whole batch
    return graph_features.mean(axis=0)

Custom Message Passing Implementations

Implementing Edge-Conditioned Convolutions

import jax.numpy as jnp
import flax.nnx as nnx
from jraphx.nn.conv import MessagePassing

class EdgeConditionedConv(MessagePassing):
    """Message passing with edge features."""

    def __init__(self, in_features, out_features, edge_dim, rngs):
        super().__init__(aggr='mean')
        self.node_mlp = nnx.Sequential(
            nnx.Linear(in_features * 2 + edge_dim, out_features, rngs=rngs),
            nnx.relu,
            nnx.Linear(out_features, out_features, rngs=rngs)
        )

    def message(self, x_i, x_j, edge_attr):
        # Concatenate source, target, and edge features
        msg = jnp.concatenate([x_i, x_j, edge_attr], axis=-1)
        return self.node_mlp(msg)

    def __call__(self, x, edge_index, edge_attr):
        return self.propagate(
            edge_index, x=x, edge_attr=edge_attr
        )
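
A minimal usage sketch; the feature sizes and the data object (a Data instance with an edge_attr field) are illustrative:

# Illustrative sizes: 16-dim node features, 4-dim edge features, 32-dim output
conv = EdgeConditionedConv(16, 32, edge_dim=4, rngs=nnx.Rngs(0))
out = conv(data.x, data.edge_index, data.edge_attr)  # [num_nodes, 32]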

Implementing Attention Mechanisms

from jraphx.utils.scatter import scatter_max, scatter_sum

class CustomAttentionConv(MessagePassing):
    """Custom attention-based message passing."""

    def __init__(self, in_features, out_features, heads=4, rngs=None):
        super().__init__(aggr='add')
        self.heads = heads
        self.out_features = out_features

        self.W_q = nnx.Linear(in_features, heads * out_features, rngs=rngs)
        self.W_k = nnx.Linear(in_features, heads * out_features, rngs=rngs)
        self.W_v = nnx.Linear(in_features, heads * out_features, rngs=rngs)

    def message(self, x_i, x_j, edge_index_i, size_i):
        # edge_index_i indexes the target node of each edge; size_i is the
        # number of target nodes, as supplied by propagate
        Q = self.W_q(x_i).reshape(-1, self.heads, self.out_features)
        K = self.W_k(x_j).reshape(-1, self.heads, self.out_features)
        V = self.W_v(x_j).reshape(-1, self.heads, self.out_features)

        # Scaled dot-product attention scores, one per edge and head
        scores = (Q * K).sum(axis=-1) / jnp.sqrt(self.out_features)

        # Normalize per target node (segment softmax), not across all edges
        scores = scores - scatter_max(scores, edge_index_i, dim_size=size_i)[edge_index_i]
        exp_scores = jnp.exp(scores)
        norm = scatter_sum(exp_scores, edge_index_i, dim_size=size_i)[edge_index_i]
        alpha = exp_scores / (norm + 1e-16)

        # Apply attention weights to the values
        return (alpha[..., None] * V).reshape(-1, self.heads * self.out_features)

    def __call__(self, x, edge_index):
        return self.propagate(edge_index, x=x)

Performance Optimization

Efficient Scatter Operations

import jax.numpy as jnp
from jraphx.utils.scatter import scatter_sum, scatter_max

def efficient_aggregation(src, index, dim_size):
    """Combine multiple scatter operations efficiently."""

    # Compute multiple aggregations in single pass
    sum_result = scatter_sum(src, index, dim_size=dim_size)
    max_result = scatter_max(src, index, dim_size=dim_size)

    # Derive the mean from the sum instead of running a separate scatter pass;
    # clamp the counts to avoid division by zero for empty segments
    counts = scatter_sum(jnp.ones_like(src), index, dim_size=dim_size)
    mean_result = sum_result / jnp.maximum(counts, 1)

    return {
        'sum': sum_result,
        'max': max_result,
        'mean': mean_result
    }

Working with Dynamic Graphs

Handling Variable-Size Graphs

def process_dynamic_graphs(graphs):
    """Process graphs with varying sizes."""

    def process_single(graph):
        # Pad to maximum size if needed
        max_nodes = 1000
        current_nodes = graph.x.shape[0]

        if current_nodes < max_nodes:
            pad_size = max_nodes - current_nodes
            x_padded = jnp.pad(
                graph.x,
                ((0, pad_size), (0, 0)),
                mode='constant'
            )
            mask = jnp.concatenate([
                jnp.ones(current_nodes),
                jnp.zeros(pad_size)
            ])
        else:
            # Truncate oversized graphs (nodes beyond max_nodes are dropped)
            x_padded = graph.x[:max_nodes]
            mask = jnp.ones(max_nodes)

        return x_padded, mask

    # Process each graph
    processed = [process_single(g) for g in graphs]
    return processed

Dynamic Edge Construction

def construct_knn_graph(x, k=10):
    """Dynamically construct a k-NN graph from node features."""

    num_nodes = x.shape[0]

    # Compute pairwise squared Euclidean distances
    dist_matrix = jnp.sum((x[:, None] - x[None, :]) ** 2, axis=-1)

    # Exclude self-distances so a node is never its own neighbor
    dist_matrix = jnp.where(
        jnp.eye(num_nodes, dtype=bool), jnp.inf, dist_matrix
    )

    # Find the k nearest neighbors of each node
    _, indices = jax.lax.top_k(-dist_matrix, k)

    # Build the edge index: one edge from each node to each of its neighbors
    source = jnp.repeat(jnp.arange(num_nodes), k)
    target = indices.flatten()

    edge_index = jnp.stack([source, target])
    return edge_index

Distributed Training

Data Parallel Training

import jax
from jax import pmap
import flax.nnx as nnx

def distributed_train_step(graphdef, state, batch):
    """Single training step for data-parallel training.

    The model and optimizer are passed as a split (graphdef, state) pair so
    that pmap only sees JAX-compatible pytrees.
    """
    model, optimizer = nnx.merge(graphdef, state)

    def loss_fn(model):
        logits = model(batch.x, batch.edge_index)
        loss = compute_loss(logits, batch.y)  # user-defined loss function
        return loss.mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Average gradients across devices
    grads = jax.tree_util.tree_map(lambda g: jax.lax.pmean(g, 'batch'), grads)

    optimizer.update(model, grads)
    _, state = nnx.split((model, optimizer))
    return state, loss

# Parallelize across devices; the graph definition is static, while the
# state and batch carry a leading device axis
parallel_train_step = pmap(
    distributed_train_step, axis_name='batch', static_broadcasted_argnums=0
)

Model Parallel GNNs

def model_parallel_gnn(x, edge_index, num_devices=2):
    """Split GNN layers across devices.

    layer1 .. layer4 are assumed to be pre-built GNN layers (e.g. GCNConv
    modules) whose parameters live on the corresponding devices.
    """

    devices = jax.devices()[:num_devices]

    # Run the first half of the layers on the first device
    with jax.default_device(devices[0]):
        x = layer1(x, edge_index)
        x = layer2(x, edge_index)

    # Run the second half on the second device
    with jax.default_device(devices[1]):
        x = layer3(x, edge_index)
        x = layer4(x, edge_index)

    return x

Advanced Pooling Strategies

Differentiable Pooling

from jraphx.nn.conv import GCNConv

class DiffPool(nnx.Module):
    """Differentiable pooling layer."""

    def __init__(self, in_features, ratio=0.5, rngs=None):
        # The number of clusters is tied to the feature dimension for
        # simplicity; pool_gnn predicts soft cluster assignments and
        # embed_gnn embeds the nodes before pooling.
        num_clusters = int(in_features * ratio)
        self.pool_gnn = GCNConv(in_features, num_clusters, rngs=rngs)
        self.embed_gnn = GCNConv(in_features, in_features, rngs=rngs)

    def __call__(self, x, edge_index, batch=None):
        # Compute soft cluster assignments [num_nodes, num_clusters]
        s = self.pool_gnn(x, edge_index)
        s = nnx.softmax(s, axis=-1)

        # Embed nodes, then pool them with the soft assignments
        x_embed = self.embed_gnn(x, edge_index)
        x_pooled = s.T @ x_embed

        # Compute the pooled adjacency; to_dense_adj and to_edge_index are
        # assumed to be available dense/sparse conversion helpers
        adj = to_dense_adj(edge_index)
        adj_pooled = s.T @ adj @ s

        # Convert the pooled adjacency back to an edge index
        edge_index_pooled = to_edge_index(adj_pooled)

        return x_pooled, edge_index_pooled, s

Best Practices Summary

  1. Memory Management
     - Use jax.lax.scan for sequential operations
     - Apply gradient checkpointing for deep models
     - Batch graphs efficiently with padding

  2. Performance
     - JIT compile performance-critical functions (see the sketch after this list)
     - Use static arguments for conditional logic
     - Leverage vmap for parallel processing

  3. Scalability
     - Implement data-parallel training with pmap
     - Use model parallelism for very large models
     - Consider dynamic batching for variable-size graphs

  4. Debugging
     - Use jax.debug.print inside JIT-compiled functions
     - Check shapes with jax.debug.breakpoint
     - Profile with jax.profiler
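
A minimal sketch of the JIT and debugging recommendations, assuming a model with the usual (x, edge_index) signature; the forward name and the add_self_loops flag are illustrative static arguments, not part of the JraphX API:

from functools import partial

import jax
import jax.numpy as jnp
import flax.nnx as nnx

@partial(nnx.jit, static_argnames=('add_self_loops',))
def forward(model, x, edge_index, add_self_loops=True):
    # Static arguments can drive Python-level control flow under JIT
    if add_self_loops:
        num_nodes = x.shape[0]
        loops = jnp.stack([jnp.arange(num_nodes), jnp.arange(num_nodes)])
        edge_index = jnp.concatenate([edge_index, loops], axis=1)

    out = model(x, edge_index)
    # jax.debug.print works inside JIT-compiled functions
    jax.debug.print("output mean: {m}", m=out.mean())
    return out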

See Also