JAX JIT Compilation
JAX provides Just-In-Time (JIT) compilation through XLA to optimize and accelerate your JraphX models. JIT compilation can provide significant performance improvements by optimizing the entire computation graph. If you are unfamiliar with JAX JIT, we recommend reading the official “JAX JIT tutorial” first.
JIT-Compiling GNN Models
All JraphX layers and models are designed to be JIT-compatible out of the box. Here’s how to JIT-compile a simple GNN model:
import jax
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.models import GCN
from jraphx.data import Data
# Create model and data
model = GCN(
    in_features=16,
    hidden_features=64,
    out_features=7,
    num_layers=3,
    rngs=nnx.Rngs(42)
)
data = Data(
    x=jnp.ones((100, 16)),
    edge_index=jnp.array([[0, 1, 2], [1, 2, 0]])
)
# JIT compile the forward pass
@jax.jit
def predict(model, x, edge_index):
    return model(x, edge_index)
# First call compiles, subsequent calls are fast
predictions = predict(model, data.x, data.edge_index)
print(f"Predictions shape: {predictions.shape}")
JIT-Compiling Training Steps
For optimal performance, JIT-compile your entire training step:
import optax
# Setup optimizer
optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)
@jax.jit
def train_step(model, optimizer, x, edge_index, targets, train_indices):
    """JIT-compiled training step."""

    def loss_fn(model):
        predictions = model(x, edge_index)
        # Use concrete indices instead of a boolean mask for JIT compatibility
        train_predictions = predictions[train_indices]
        train_targets = targets[train_indices]
        return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
            train_predictions, train_targets
        ))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
# Training loop with JIT compilation
targets = jnp.array([0, 1, 2, 0, 1, 2, 0] * 14 + [0, 1])  # 100 targets
train_indices = jnp.arange(80) # First 80 nodes for training (concrete indices)
for epoch in range(100):
    loss = train_step(model, optimizer, data.x, data.edge_index, targets, train_indices)
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')
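The same concrete-index pattern extends to evaluation. Below is a minimal sketch of a JIT-compiled evaluation step; the held-out test_indices split and the eval_step function are illustrative assumptions, not part of JraphX:
test_indices = jnp.arange(80, 100)  # last 20 nodes held out for evaluation (hypothetical split)

@jax.jit
def eval_step(model, x, edge_index, targets, test_indices):
    """JIT-compiled evaluation step using fixed-size index arrays."""
    predictions = model(x, edge_index)
    predicted_classes = jnp.argmax(predictions[test_indices], axis=-1)
    return jnp.mean(predicted_classes == targets[test_indices])

accuracy = eval_step(model, data.x, data.edge_index, targets, test_indices)
print(f"Test accuracy: {accuracy:.2%}")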
Custom Layer JIT Compatibility
When creating custom JraphX layers, ensure they are JIT-compatible by following these guidelines:
Use only JAX operations: Avoid Python control flow in favor of jax.lax operations such as jax.lax.cond() and jax.lax.scan() (a short jax.lax.cond sketch follows the example below)
Static shapes: Ensure array shapes are statically known when possible
Pure functions: No side effects or global state modifications
from jraphx.nn.conv import MessagePassing
class CustomGNNLayer(MessagePassing):
    def __init__(self, in_features, out_features, *, rngs: nnx.Rngs):
        super().__init__(aggr='mean')
        self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

    def __call__(self, x, edge_index):
        # All operations here must be JAX-compatible
        x = self.linear(x)
        # Use JAX operations for conditionals
        x = jnp.where(x > 0, x, 0.0)  # ReLU activation
        # Standard message passing
        return self.propagate(edge_index, x)
# This layer is automatically JIT-compatible
@jax.jit
def forward_with_custom_layer(x, edge_index):
    layer = CustomGNNLayer(16, 32, rngs=nnx.Rngs(42))
    return layer(x, edge_index)
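If a layer needs a branch whose condition depends on runtime values, a Python if on a traced array will not compile. The following is a minimal, hypothetical sketch of the jax.lax.cond alternative mentioned in the guidelines above (the scale_if_large helper is illustrative, not a JraphX API):
def scale_if_large(x):
    # A Python `if x.mean() > 1.0:` would fail under jit because the condition
    # is a traced value; jax.lax.cond keeps the branch compilable.
    return jax.lax.cond(
        x.mean() > 1.0,
        lambda v: v / v.mean(),  # branch taken when the mean is large
        lambda v: v,             # identity branch otherwise
        x,
    )

scaled = jax.jit(scale_if_large)(data.x * 2.0)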
Performance Benefits
JIT compilation provides several benefits for JraphX models:
Speed: 2-10x faster execution after compilation
Memory: Optimized memory usage patterns
Optimization: XLA performs advanced optimizations like operator fusion
Parallelization: Automatic vectorization where possible
Benchmarking JIT vs non-JIT:
import time
# Non-JIT version
def slow_predict(model, x, edge_index):
    return model(x, edge_index)
# JIT version
fast_predict = jax.jit(slow_predict)
# Warm up JIT (compilation happens here)
_ = fast_predict(model, data.x, data.edge_index)
# Benchmark (block_until_ready() ensures we time the computation itself, not just JAX's asynchronous dispatch)
start = time.time()
for _ in range(100):
    out = slow_predict(model, data.x, data.edge_index)
out.block_until_ready()
slow_time = time.time() - start

start = time.time()
for _ in range(100):
    out = fast_predict(model, data.x, data.edge_index)
out.block_until_ready()
fast_time = time.time() - start

print(f"Speed improvement: {slow_time / fast_time:.2f}x")
Best Practices
JIT the training step: Compile the entire training step for maximum benefit
Warm up on dummy data: Compile before timing-critical sections
Static shapes: Use fixed-size arrays when possible for better optimization
Batch processing: JIT works especially well with batched operations
Avoid Python loops: Use jax.lax.scan() or nnx.vmap() instead (a jax.lax.scan sketch follows the batching examples below)
# Good: JIT-friendly batch processing
@jax.jit
def process_batch(model, batch_x, batch_edge_index):
    # Broadcast the model (in_axes=None) and map over the leading batch axis of the inputs
    return nnx.vmap(lambda m, x, e: m(x, e), in_axes=(None, 0, 0))(model, batch_x, batch_edge_index)
# Better: Use JraphX Batch for variable-size graphs
@jax.jit
def process_jraphx_batch(model, batch):
    return model(batch.x, batch.edge_index)
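For repeated computation over many steps, jax.lax.scan keeps the loop inside a single compiled program. A minimal sketch, assuming a hypothetical sequence of node-feature snapshots that share the graph structure from the earlier example:
feature_sequence = jnp.ones((10, 100, 16))  # hypothetical: 10 snapshots of node features

@jax.jit
def rolling_predictions(model, feature_sequence, edge_index):
    def step(carry, x_t):
        preds = model(x_t, edge_index)
        return carry + preds.sum(), preds  # running total as carry, per-step outputs are stacked
    total, all_preds = jax.lax.scan(step, jnp.zeros(()), feature_sequence)
    return total, all_preds

total, all_preds = rolling_predictions(model, feature_sequence, data.edge_index)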
Common Pitfalls
Dynamic shapes: Avoid operations that change array shapes based on data (see the sketch after this list)
Python conditionals: Use jnp.where() instead of if statements
Global state: Avoid modifying global variables inside JIT functions
Device transfers: Minimize data movement between devices within JIT functions
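To make the dynamic-shape pitfall concrete: boolean-mask indexing such as predictions[mask] yields an array whose size depends on the data, which JIT cannot compile. A minimal sketch of a fixed-shape alternative, assuming a hypothetical boolean mask over the 100 nodes:
@jax.jit
def masked_mean(predictions, mask):
    # predictions[mask] would raise an error under jit because the result's shape
    # depends on how many mask entries are True; instead, zero out the excluded
    # rows and normalize, keeping every shape static.
    masked = jnp.where(mask[:, None], predictions, 0.0)
    return masked.sum() / (mask.sum() * predictions.shape[1])

mean_val = masked_mean(predictions, jnp.arange(100) < 80)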
For more information on JAX JIT compilation, see the JAX documentation.