Data Structures & Utilities Cheatsheet
This cheatsheet covers JraphX’s core data structures and utility functions for working with graphs.
> **Note:** JraphX is focused on providing GNN layers and utilities for JAX, not datasets. For datasets, you'll typically load data from external sources (files, other libraries) and convert them to JraphX format.
Core Data Structures
| Class | Purpose | Key Methods |
|---|---|---|
| `Data` | Single graph representation | `num_nodes`, `num_edges`, `replace()` |
| `Batch` | Multiple graphs in a batch | `from_data_list()`, `num_graphs` |
Data Attributes
| Attribute | Shape | Description | Required |
|---|---|---|---|
| `x` | `[num_nodes, num_features]` | Node feature matrix | Optional |
| `edge_index` | `[2, num_edges]` | Edge connectivity in COO format | Optional |
| `edge_attr` | `[num_edges, num_edge_features]` | Edge feature matrix | Optional |
| `y` | `[num_nodes, *]` or `[1, *]` | Labels/targets | Optional |
| `pos` | `[num_nodes, 3]` | Node positions (3D point clouds) | Optional |
| `batch` | `[num_nodes]` | Batch assignment vector | Auto-generated |
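As a quick illustration of these shapes, here is a two-node graph with every optional attribute populated (a minimal sketch; only `x`, `edge_index`, `edge_attr`, and `y` are exercised elsewhere in this cheatsheet, and `pos` follows the shape convention above):

```python
import jax.numpy as jnp
from jraphx.data import Data

# Two nodes joined by a single directed edge, with all optional attributes set
data = Data(
    x=jnp.array([[1.0, 2.0], [3.0, 4.0]]),              # [num_nodes, num_features]
    edge_index=jnp.array([[0], [1]]),                    # [2, num_edges]
    edge_attr=jnp.array([[0.5]]),                        # [num_edges, num_edge_features]
    y=jnp.array([0, 1]),                                 # node-level labels
    pos=jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),  # [num_nodes, 3]
)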
Graph Utility Functions
| Function | Purpose | JIT-ready |
|---|---|---|
| `add_self_loops` | Add self-loop edges to graph | ✓ |
| `remove_self_loops` | Remove self-loop edges from graph | ✓ |
| `degree` | Compute node degrees | ✓ |
| `in_degree` | Compute in-degrees (directed graphs) | ✓ |
| `out_degree` | Compute out-degrees (directed graphs) | ✓ |
| `coalesce` | Remove duplicate edges | ✓ |
| `to_undirected` | Convert directed to undirected graph | ✓ |
| `to_dense_adj` | Convert edge_index to dense adjacency | ✓ |
| `dense_to_sparse` | Convert dense adjacency to edge_index | ✓ |
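`add_self_loops` and `degree` are demonstrated under Quick Examples below. The sketch here exercises the conversion helpers, assuming they mirror their PyG namesakes' signatures and return values (not confirmed elsewhere in this cheatsheet):

```python
import jax.numpy as jnp
from jraphx.utils import to_undirected, to_dense_adj, dense_to_sparse

edge_index = jnp.array([[0, 1], [1, 2]])  # directed edges 0->1 and 1->2

# Add reverse edges so every edge appears in both directions
edge_index_und = to_undirected(edge_index)

# Round-trip through a dense adjacency matrix
adj = to_dense_adj(edge_index_und)         # dense adjacency
edge_index_back, _ = dense_to_sparse(adj)  # back to COO format
```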
Scatter Operations
| Function | Purpose | JIT-ready |
|---|---|---|
| `scatter_add` | Scatter-add operation for aggregation | ✓ |
| `scatter_mean` | Scatter-mean operation for aggregation | ✓ |
| `scatter_max` | Scatter-max operation for aggregation | ✓ |
| `scatter_min` | Scatter-min operation for aggregation | ✓ |
| `scatter_std` | Scatter-std operation for aggregation | ✓ |
| `scatter_logsumexp` | Scatter-logsumexp for numerical stability | ✓ |
| `scatter_softmax` | Scatter-softmax for attention mechanisms | ✓ |
| `scatter_log_softmax` | Scatter-log-softmax for attention | ✓ |
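Scatter operations aggregate per-edge values into per-node slots, which is the core of message passing. A sketch of mean aggregation of neighbor features, assuming `scatter_mean` takes `(src, index, dim_size)` like common scatter implementations (the exact signature is an assumption):

```python
import jax.numpy as jnp
from jraphx.utils import scatter_mean

x = jnp.array([[1.0], [2.0], [3.0]])            # node features
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # triangle graph

src, dst = edge_index
messages = x[src]                              # gather source features per edge
out = scatter_mean(messages, dst, dim_size=3)  # average incoming messages per node
```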
Quick Examples
Creating a simple graph:
```python
import jax.numpy as jnp
from jraphx.data import Data

# Create node features and edges
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # Triangle graph

data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
```
Batching multiple graphs:
```python
from jraphx.data import Batch

# Create a list of graphs
graphs = [
    Data(x=jnp.ones((3, 2)), edge_index=jnp.array([[0, 1], [1, 2]])),
    Data(x=jnp.ones((4, 2)), edge_index=jnp.array([[0, 1], [2, 3]])),
]

batch = Batch.from_data_list(graphs)
print(f"Batch has {batch.num_graphs} graphs")
```
Using utilities:
```python
from jraphx.utils import add_self_loops, degree

# Add self-loops
edge_index_with_loops, _ = add_self_loops(edge_index, num_nodes=3)

# Compute degrees by counting how often each node appears as an edge target
degrees = degree(edge_index[1], num_nodes=3)
print(f"Node degrees: {degrees}")
```
JIT compilation:
```python
import jax
from jraphx.utils import add_self_loops

@jax.jit
def process_graph(data):
    edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.shape[0])
    return edge_index

processed = process_graph(data)
```
Working with PyTorch Geometric Datasets
You can easily use PyTorch Geometric datasets with JraphX by converting the data format:
Loading a PyG dataset:
```python
from torch_geometric.datasets import Planetoid
import jax.numpy as jnp
from jraphx.data import Data

def pyg_to_jraphx(pyg_data):
    """Convert a PyG Data object to a JraphX Data object."""
    return Data(
        x=jnp.array(pyg_data.x.numpy()),
        edge_index=jnp.array(pyg_data.edge_index.numpy()),
        y=jnp.array(pyg_data.y.numpy()) if pyg_data.y is not None else None,
        edge_attr=jnp.array(pyg_data.edge_attr.numpy()) if pyg_data.edge_attr is not None else None,
    )

# Load the Cora citation network (PyG ships it via the Planetoid loader)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
pyg_data = dataset[0]  # Single-graph dataset

# Convert to JraphX format
jraphx_data = pyg_to_jraphx(pyg_data)
print(f"Converted graph: {jraphx_data.num_nodes} nodes, {jraphx_data.num_edges} edges")
```
Batch processing multiple PyG graphs:
```python
from torch_geometric.datasets import TUDataset
from jraphx.data import Batch

# Load a graph classification dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Convert the first 10 graphs to JraphX format
jraphx_graphs = []
for i in range(10):
    jraphx_graphs.append(pyg_to_jraphx(dataset[i]))

# Create a batch for JraphX processing
batch = Batch.from_data_list(jraphx_graphs)
print(f"Batch contains {batch.num_graphs} graphs")
```
Training with a PyG dataset:
```python
import jax.numpy as jnp
import optax
from flax import nnx
from jraphx.nn.models import GCN
from jraphx.nn.pool import global_mean_pool

# Set up a model for graph classification
model = GCN(
    in_features=dataset.num_node_features,
    hidden_features=64,
    out_features=dataset.num_classes,
    num_layers=3,
    rngs=nnx.Rngs(42),
)
optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, batch, targets):
    def loss_fn(model):
        # Node-level predictions
        node_predictions = model(batch.x, batch.edge_index)
        # Pool node embeddings to graph-level predictions
        graph_predictions = global_mean_pool(node_predictions, batch.batch)
        # Mean cross-entropy over the graphs in the batch
        return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
            graph_predictions, targets
        ))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss

# Training loop
for epoch in range(100):
    # Take a fixed batch of graphs (batch size 32)
    indices = range(min(32, len(dataset)))
    batch_graphs = [pyg_to_jraphx(dataset[i]) for i in indices]
    batch = Batch.from_data_list(batch_graphs)
    targets = jnp.array([dataset[i].y.item() for i in indices])

    loss = train_step(model, optimizer, batch, targets)
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')
```
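After training you will usually want graph-level accuracy; a minimal evaluation sketch reusing the same forward-and-pool path (the `evaluate` helper is illustrative, not a JraphX API):

```python
def evaluate(model, batch, targets):
    """Fraction of correctly classified graphs (illustrative helper)."""
    node_predictions = model(batch.x, batch.edge_index)
    graph_predictions = global_mean_pool(node_predictions, batch.batch)
    predicted = jnp.argmax(graph_predictions, axis=-1)
    return jnp.mean(predicted == targets)

accuracy = evaluate(model, batch, targets)
print(f"Training-batch accuracy: {accuracy:.2%}")
```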
Note on Normalization:
JraphX normalization layers (BatchNorm, LayerNorm) follow Flax NNX conventions with the `use_running_average` parameter:
```python
from flax import nnx
from jraphx.nn.norm import BatchNorm

# Create graph-aware batch normalization
bn = BatchNorm(in_features=64, rngs=nnx.Rngs(0))

# Training mode: train() sets use_running_average=False on BatchNorm layers
bn.train()
x_train = bn(x, batch=batch)  # Computes statistics from the current batch

# Evaluation mode: eval() sets use_running_average=True
bn.eval()
x_eval = bn(x, batch=batch)  # Uses the stored running statistics

# Manual control (overrides the module state):
x_manual = bn(x, batch=batch, use_running_average=False)  # Force batch stats
```
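`LayerNorm` keeps no running statistics, so it behaves identically in both modes; a sketch assuming the same `in_features`/`rngs` constructor convention as `BatchNorm` above:

```python
from jraphx.nn.norm import LayerNorm

ln = LayerNorm(in_features=64, rngs=nnx.Rngs(0))
x_norm = ln(x)  # No train/eval distinction: statistics come from x itself
```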
Common PyG datasets for JraphX:
| Dataset | Type | Size | Use Case |
|---|---|---|---|
| Cora | Citation Network | 2,708 nodes | Node classification |
| ENZYMES | Graph Classification | 600 graphs | Graph classification |
| Karate Club | Social Network | 34 nodes | Community detection |
| QM7 | Molecular | 7,165 molecules | Graph regression |
| Reddit | Social Network | 232K nodes | Large-scale node classification |
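All of these load through the usual PyG dataset classes and convert with the `pyg_to_jraphx` helper defined above; for example, Karate Club ships with PyG directly (it has no edge features, so `edge_attr` simply stays `None`):

```python
from torch_geometric.datasets import KarateClub

karate = KarateClub()
jraphx_karate = pyg_to_jraphx(karate[0])
print(f"{jraphx_karate.num_nodes} nodes")  # 34
```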
Memory-efficient dataset loading:
```python
def lazy_pyg_to_jraphx_converter(dataset):
    """Generator that converts PyG graphs to JraphX format lazily."""
    for i in range(len(dataset)):
        yield pyg_to_jraphx(dataset[i])

# Use with large datasets to avoid holding every converted graph in memory
large_dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')

# Process in batches
batch_size = 32
for batch_start in range(0, len(large_dataset), batch_size):
    batch_end = min(batch_start + batch_size, len(large_dataset))
    batch_graphs = [
        pyg_to_jraphx(large_dataset[i])
        for i in range(batch_start, batch_end)
    ]
    batch = Batch.from_data_list(batch_graphs)
    # Process batch...
    print(f"Processed batch {batch_start // batch_size + 1}")
```
This integration allows you to leverage the extensive PyG dataset collection while using JraphX’s JAX-optimized graph neural networks.
Data Augmentation with Flax 0.11.2
Use the new Rngs shorthand methods for data augmentation and preprocessing:
```python
import jax
import jax.numpy as jnp
from flax import nnx

def augment_graph_data(data, rngs):
    """Augment graph data using the Rngs shorthand methods."""
    # Traditional approach:
    # noise = jax.random.normal(rngs(), data.x.shape) * 0.1
    # Shorthand methods (Flax 0.11.2):
    noise = rngs.normal(data.x.shape) * 0.1
    x_noisy = data.x + noise

    # Add a few random edges as a structural perturbation
    num_perturb = 5
    new_edges = jnp.stack([
        rngs.params.randint((num_perturb,), 0, data.num_nodes),
        rngs.params.randint((num_perturb,), 0, data.num_nodes),
    ])

    # Combine original and new edges
    edge_index_augmented = jnp.concatenate([data.edge_index, new_edges], axis=1)

    return data.replace(x=x_noisy, edge_index=edge_index_augmented)

# Usage with multiple named key streams
rngs = nnx.Rngs(0, params=1, dropout=2)
augmented_data = augment_graph_data(original_data, rngs)
```
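Edge dropout is another common augmentation; a sketch using the plain `jax.random` API with the `dropout` stream (`drop_edges` and its `drop_prob` parameter are illustrative, not a JraphX utility):

```python
def drop_edges(data, rngs, drop_prob=0.2):
    """Keep each edge independently with probability 1 - drop_prob (illustrative)."""
    num_edges = data.edge_index.shape[1]
    keep = jax.random.bernoulli(rngs.dropout(), 1.0 - drop_prob, (num_edges,))
    # Boolean masking changes the array shape, so run this outside jit
    return data.replace(edge_index=data.edge_index[:, keep])

sparser_data = drop_edges(original_data, rngs)
```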