Data Structures & Utilities Cheatsheet

This cheatsheet covers JraphX’s core data structures and utility functions for working with graphs.

Note

JraphX is focused on providing GNN layers and utilities for JAX, not datasets. For datasets, you’ll typically load data from external sources (files, other libraries) and convert them to JraphX format.

Core Data Structures

Class   Purpose                       Key Methods
------  ----------------------------  -----------------------------------------
Data    Single graph representation   num_nodes, num_edges, keys(), __getitem__
Batch   Multiple graphs in a batch    from_data_list(), num_graphs, batch
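
For example, using the accessors listed above:

import jax.numpy as jnp
from jraphx.data import Data

data = Data(x=jnp.ones((3, 2)), edge_index=jnp.array([[0, 1], [1, 2]]))
print(data.keys())      # Names of the attributes set on this graph
print(data['x'].shape)  # Dict-style access via __getitem__: (3, 2)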

Data Attributes

Attribute   Shape                              Description                             Required
----------  ---------------------------------  --------------------------------------  --------------
x           [num_nodes, num_features]          Node feature matrix                     Optional
edge_index  [2, num_edges]                     Edge connectivity in COO format         Optional
edge_attr   [num_edges, num_edge_features]     Edge feature matrix                     Optional
y           [num_nodes, *] or [num_graphs, *]  Labels/targets                          Optional
pos         [num_nodes, num_dimensions]        Node positions (e.g., 3D point clouds)  Optional
batch       [num_nodes]                        Batch assignment vector                 Auto-generated
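
All attributes are optional keyword arguments to Data; a small example following the shapes in the table:

import jax.numpy as jnp
from jraphx.data import Data

data = Data(
    x=jnp.ones((4, 8)),                            # [num_nodes, num_features]
    edge_index=jnp.array([[0, 1, 2], [1, 2, 3]]),  # [2, num_edges], COO format
    edge_attr=jnp.ones((3, 2)),                    # [num_edges, num_edge_features]
    y=jnp.array([0, 1, 0, 1]),                     # [num_nodes] node labels
    pos=jnp.zeros((4, 3)),                         # [num_nodes, num_dimensions]
)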

Graph Utility Functions

Function             Purpose                                 JIT-ready
-------------------  --------------------------------------  ---------
add_self_loops()     Add self-loop edges to a graph          ✓
remove_self_loops()  Remove self-loop edges from a graph     ✓
degree()             Compute node degrees                    ✓
in_degree()          Compute in-degrees (directed graphs)    ✓
out_degree()         Compute out-degrees (directed graphs)   ✓
coalesce()           Remove duplicate edges                  ✓
to_undirected()      Convert a directed graph to undirected  ✓
to_dense_adj()       Convert edge_index to dense adjacency   ✓
to_edge_index()      Convert dense adjacency to edge_index   ✓
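
For example (a sketch; the keyword arguments for coalesce and to_undirected are assumptions modeled on PyTorch Geometric):

import jax.numpy as jnp
from jraphx.utils import coalesce, to_undirected, to_dense_adj

edge_index = jnp.array([[0, 1, 1], [1, 2, 2]])  # Contains a duplicate 1->2 edge

edge_index = coalesce(edge_index, num_nodes=3)       # Remove the duplicate
edge_index = to_undirected(edge_index, num_nodes=3)  # Add reverse edges
adj = to_dense_adj(edge_index)                       # Dense adjacency matrix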

Scatter Operations

Function               Purpose                                    JIT-ready
---------------------  -----------------------------------------  ---------
scatter_add()          Scatter-add aggregation                    ✓
scatter_mean()         Scatter-mean aggregation                   ✓
scatter_max()          Scatter-max aggregation                    ✓
scatter_min()          Scatter-min aggregation                    ✓
scatter_std()          Scatter-std aggregation                    ✓
scatter_logsumexp()    Scatter-logsumexp for numerical stability  ✓
scatter_softmax()      Scatter-softmax for attention mechanisms   ✓
scatter_log_softmax()  Scatter-log-softmax for attention          ✓
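
These aggregate values into output slots selected by an index vector. A minimal sketch, assuming a PyG-style scatter_add(src, index, dim_size=...) signature:

import jax.numpy as jnp
from jraphx.utils import scatter_add

src = jnp.array([[1.0], [2.0], [3.0]])  # One message per edge
index = jnp.array([0, 0, 1])            # Destination node of each edge
out = scatter_add(src, index, dim_size=2)
# out == [[3.0], [3.0]]: node 0 receives 1+2, node 1 receives 3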

Quick Examples

Creating a simple graph:

import jax.numpy as jnp
from jraphx.data import Data

# Create node features and edges
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # Triangle graph

data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")

Batching multiple graphs:

from jraphx.data import Batch

# Create list of graphs
graphs = [
    Data(x=jnp.ones((3, 2)), edge_index=jnp.array([[0, 1], [1, 2]])),
    Data(x=jnp.ones((4, 2)), edge_index=jnp.array([[0, 1], [2, 3]])),
]

batch = Batch.from_data_list(graphs)
print(f"Batch has {batch.num_graphs} graphs")

Using utilities:

from jraphx.utils import add_self_loops, degree

# Add self-loops
edge_index_with_loops, _ = add_self_loops(edge_index, num_nodes=3)

# Compute node degrees (counting the target row gives in-degrees)
degrees = degree(edge_index[1], num_nodes=3)
print(f"Node degrees: {degrees}")

JIT compilation:

import jax
from jraphx.utils import add_self_loops

@jax.jit
def process_graph(data):
    edge_index, _ = add_self_loops(data.edge_index, data.x.shape[0])
    return edge_index

processed = process_graph(data)
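
This works because data.x.shape[0] is a concrete Python integer during tracing (JAX array shapes are static). If you instead pass num_nodes as a separate argument, mark it static so the compiled function can specialize on it (a general JAX pattern; add_loops is just an illustrative name):

import functools
import jax
from jraphx.utils import add_self_loops

@functools.partial(jax.jit, static_argnums=1)
def add_loops(edge_index, num_nodes):
    # num_nodes determines output shapes, so it must be static under jit
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
    return edge_index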

Working with PyTorch Geometric Datasets

You can use PyTorch Geometric datasets with JraphX by converting the tensors to JAX arrays:

Loading a PyG dataset:

import torch
from torch_geometric.datasets import Planetoid
import jax.numpy as jnp
from jraphx.data import Data

def pyg_to_jraphx(pyg_data):
    """Convert PyG Data to JraphX Data."""
    return Data(
        x=jnp.array(pyg_data.x.numpy()),
        edge_index=jnp.array(pyg_data.edge_index.numpy()),
        y=jnp.array(pyg_data.y.numpy()) if pyg_data.y is not None else None,
        edge_attr=jnp.array(pyg_data.edge_attr.numpy()) if pyg_data.edge_attr is not None else None,
    )

# Load Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
pyg_data = dataset[0]  # Single graph dataset

# Convert to JraphX format
jraphx_data = pyg_to_jraphx(pyg_data)
print(f"Converted graph: {jraphx_data.num_nodes} nodes, {jraphx_data.num_edges} edges")

Batch processing multiple PyG graphs:

from torch_geometric.datasets import TUDataset
from jraphx.data import Batch

# Load graph classification dataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Convert first 10 graphs to JraphX format
jraphx_graphs = []
for i in range(10):
    pyg_graph = dataset[i]
    jraphx_graph = pyg_to_jraphx(pyg_graph)
    jraphx_graphs.append(jraphx_graph)

# Create batch for JraphX processing
batch = Batch.from_data_list(jraphx_graphs)
print(f"Batch contains {batch.num_graphs} graphs")

Training with a PyG dataset:

import jax
import optax
from flax import nnx
from jraphx.nn.models import GCN
from jraphx.nn.pool import global_mean_pool

# Setup model for graph classification
model = GCN(
    in_features=dataset.num_node_features,
    hidden_features=64,
    out_features=dataset.num_classes,
    num_layers=3,
    rngs=nnx.Rngs(42)
)

optimizer = nnx.Optimizer(model, optax.adam(0.01), wrt=nnx.Param)

@nnx.jit  # Use nnx.jit (not jax.jit) so Module and Optimizer state are traced correctly
def train_step(model, optimizer, batch, targets):
    def loss_fn(model):
        # Node-level predictions
        node_predictions = model(batch)
        # Pool to graph-level
        graph_predictions = global_mean_pool(node_predictions, batch.batch)
        # Compute loss
        return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
            graph_predictions, targets
        ))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss

# Training loop
for epoch in range(100):
    # Take the first 32 graphs as a fixed batch (shuffle per epoch in practice)
    indices = range(min(32, len(dataset)))
    batch_graphs = [pyg_to_jraphx(dataset[i]) for i in indices]
    batch = Batch.from_data_list(batch_graphs)
    targets = jnp.array([dataset[i].y.item() for i in indices])

    loss = train_step(model, optimizer, batch, targets)
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')
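
Evaluating the trained model (a sketch; eval_indices is a hypothetical held-out split):

@nnx.jit
def eval_step(model, batch, targets):
    graph_predictions = global_mean_pool(model(batch), batch.batch)
    return jnp.mean(jnp.argmax(graph_predictions, axis=-1) == targets)

model.eval()  # Switch dropout/BatchNorm to inference behavior
eval_indices = range(32, 64)  # Hypothetical held-out split
eval_batch = Batch.from_data_list([pyg_to_jraphx(dataset[i]) for i in eval_indices])
eval_targets = jnp.array([dataset[i].y.item() for i in eval_indices])
print(f"Accuracy: {eval_step(model, eval_batch, eval_targets):.3f}")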

Note on Normalization:

JraphX normalization layers (BatchNorm, LayerNorm) follow Flax NNX conventions; BatchNorm tracks running statistics via the use_running_average parameter:

from jraphx.nn.norm import BatchNorm

# Create graph-aware batch normalization
bn = BatchNorm(in_features=64, rngs=nnx.Rngs(0))

# Training mode: train() sets use_running_average=False on the module
bn.train()
x_train = bn(x, batch=batch)  # Computes statistics from the current batch

# Evaluation mode: eval() sets use_running_average=True
bn.eval()
x_eval = bn(x, batch=batch)   # Uses the stored running statistics

# Manual control (overrides the module state):
x_manual = bn(x, batch=batch, use_running_average=False)  # Force batch stats
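
LayerNorm follows the same constructor convention but normalizes each node independently, so it keeps no running statistics and needs no train()/eval() switching (a minimal sketch, assuming the same in_features/rngs arguments):

from jraphx.nn.norm import LayerNorm

ln = LayerNorm(in_features=64, rngs=nnx.Rngs(0))
x_norm = ln(x)  # Identical behavior in training and evaluation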

Common PyG datasets for JraphX:

Dataset      Type              Size             Use Case
-----------  ----------------  ---------------  -------------------------------
Cora         Citation Network  2,708 nodes      Node classification
ENZYMES      Bioinformatics    600 graphs       Graph classification
Karate Club  Social Network    34 nodes         Community detection
QM7          Molecular         7,165 molecules  Graph regression
Reddit       Social Network    232K nodes       Large-scale node classification

Memory-efficient dataset loading:

def lazy_pyg_to_jraphx_converter(dataset):
    """Generator that converts PyG graphs to JraphX format lazily."""
    for i in range(len(dataset)):
        yield pyg_to_jraphx(dataset[i])

# Use with large datasets to avoid memory issues
large_dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')

# Process in batches
batch_size = 32
for batch_start in range(0, len(large_dataset), batch_size):
    batch_end = min(batch_start + batch_size, len(large_dataset))
    batch_graphs = [
        pyg_to_jraphx(large_dataset[i])
        for i in range(batch_start, batch_end)
    ]
    batch = Batch.from_data_list(batch_graphs)
    # Process batch...
    print(f"Processed batch {batch_start//batch_size + 1}")

This integration allows you to leverage the extensive PyG dataset collection while using JraphX’s JAX-optimized graph neural networks.

Data Augmentation with Flax 0.11.2

Use the new Rngs shorthand methods for data augmentation and preprocessing:

from flax import nnx

def augment_graph_data(data, rngs):
    """Augment graph data using new Rngs shorthand methods."""

    # Add random noise to node features (traditional approach)
    # noise = jax.random.normal(rngs(), data.x.shape) * 0.1

    # Use shorthand methods instead (Flax 0.11.2)
    noise = rngs.normal(data.x.shape) * 0.1
    x_noisy = data.x + noise

    # Add random edge perturbations
    num_perturb = 5
    new_edges = jnp.stack([
        rngs.params.randint((num_perturb,), 0, data.num_nodes),
        rngs.params.randint((num_perturb,), 0, data.num_nodes)
    ])

    # Combine original and new edges
    edge_index_augmented = jnp.concatenate([data.edge_index, new_edges], axis=1)

    return data.replace(x=x_noisy, edge_index=edge_index_augmented)

# Usage with multiple key streams
rngs = nnx.Rngs(0, params=1, dropout=2)  # Named key streams
augmented_data = augment_graph_data(original_data, rngs)