Working with Graph Data
This guide covers how to work with graph data structures in JraphX. JraphX uses `flax.struct.dataclass` for its data structures, making them fully compatible with JAX transformations like `jit`, `vmap`, `grad`, and `pmap`.
Core Data Classes
| Class | Description |
|---|---|
| `Data` | A data object representing a single graph. |
| `Batch` | A batch of graphs represented as a single large disconnected graph. |
The Data Class
The `Data` class is the fundamental data structure for representing graphs in JraphX.
import jax.numpy as jnp
from jraphx import Data

# A directed 3-cycle (0 -> 1 -> 2 -> 0) with one scalar feature per node.
data = Data(
    x=jnp.array([[1.0], [2.0], [3.0]]),            # node features, [num_nodes, num_features]
    edge_index=jnp.array([[0, 1, 2], [1, 2, 0]]),  # source/target rows, [2, num_edges]
    y=jnp.array([0]),                              # graph-level label
)

# Inspect the derived size properties.
for label, value in (
    ("Number of nodes", data.num_nodes),
    ("Number of edges", data.num_edges),
    ("Number of features", data.num_node_features),
):
    print(f"{label}: {value}")
Key Characteristics
Immutability
Data objects are immutable to ensure functional purity:
# Fields are read-only, so direct assignment fails:
#     data.x = new_x  # raises an error
# Derive a modified copy with replace() instead; `data` itself is untouched.
new_data = data.replace(x=2 * data.x)
JAX Compatibility
Data objects work seamlessly with JAX transformations:
import jax


# jax.jit: trace and compile once, then reuse the compiled kernel.
@jax.jit
def process_graph(data):
    """Return the sum of all node features of ``data``."""
    node_features = data.x
    return node_features.sum()


# jax.vmap: map process_graph over a leading batch axis of every array leaf.
batched_process = jax.vmap(process_graph)

# jax.device_put: place every array leaf on the default backend's first
# device (a GPU only when the default backend is a GPU backend).
first_device = jax.devices()[0]
data_on_gpu = jax.device_put(data, first_device)
PyTree Operations
As registered PyTrees, Data objects support tree operations:
# Apply a function to every array leaf of the graph.
# Only floating-point leaves are cast: integer leaves such as edge_index must
# stay integral, otherwise they can no longer be used as node indices.
# Fields set to None are empty subtrees, so jax.tree.map never visits them.
def _to_float32(leaf):
    if jnp.issubdtype(jnp.result_type(leaf), jnp.floating):
        return jnp.asarray(leaf, dtype=jnp.float32)
    return leaf


data_float32 = jax.tree.map(_to_float32, data)
Graph Batching
The `Batch` class efficiently combines multiple graphs into a single disconnected graph:
from jraphx import Batch, Data

# Two small graphs: a single edge 0 -> 1, and a path 0 -> 1 -> 2.
graph1 = Data(
    x=jnp.array([[1.0], [2.0]]),
    edge_index=jnp.array([[0], [1]]),
    y=jnp.array([0]),
)
graph2 = Data(
    x=jnp.array([[3.0], [4.0], [5.0]]),
    edge_index=jnp.array([[0, 1], [1, 2]]),
    y=jnp.array([1]),
)

# Merge both graphs into one disconnected graph.
batch = Batch.from_data_list([graph1, graph2])
print(f"Batched nodes: {batch.num_nodes}")  # 2 + 3 = 5
print(f"Batched edges: {batch.num_edges}")  # 1 + 2 = 3
print(f"Batch vector: {batch.batch}")       # graph id of each node: [0, 0, 1, 1, 1]
The batch vector indicates which graph each node belongs to, enabling proper pooling operations:
from jraphx.nn.pool import global_mean_pool

# `model` stands for a node-level GNN defined elsewhere (not shown here) that
# maps (x, edge_index) to per-node embeddings.
node_embeddings = model(batch.x, batch.edge_index)

# Average the node embeddings of each graph; batch.batch says which graph owns each node.
graph_embeddings = global_mean_pool(node_embeddings, batch.batch)
print(f"Graph embeddings shape: {graph_embeddings.shape}")  # (num_graphs=2, hidden_dim)
Extending the Data and Batch Classes
For domain-specific attributes, we'll subclass both the base `Data` and `Batch` classes.
The `Data` subclass simply declares the additional fields. The corresponding `Batch` subclass declares the same fields and also specifies, through class attributes, how each field is batched.
from flax.struct import dataclass
from typing import Optional
import jraphx
@dataclass
class FaceData(Data):
    """Single 3D mesh graph that additionally carries triangle faces.

    Every extra field is optional and defaults to None.
    """

    face: Optional[jnp.ndarray] = None        # triangle vertex indices, [3, num_faces]
    pos: Optional[jnp.ndarray] = None         # 3D vertex positions
    normal: Optional[jnp.ndarray] = None      # one normal vector per face
    face_color: Optional[jnp.ndarray] = None  # one color per face
@dataclass
class FaceBatch(jraphx.Batch):
    """Batch of FaceData meshes merged into one disconnected graph."""

    face: Optional[jnp.ndarray] = None
    pos: Optional[jnp.ndarray] = None
    normal: Optional[jnp.ndarray] = None
    face_color: Optional[jnp.ndarray] = None

    # Batching configuration. These are plain class attributes, not dataclass fields.
    NODE_INDEX_FIELDS = {"face"}                             # node indices, adjusted when batching
    ELEMENT_LEVEL_FIELDS = {"pos", "normal", "face_color"}  # concatenated across meshes
    _DATA_CLASS = FaceData                                   # type rebuilt by to_data_list()

    def __repr__(self) -> str:
        """Delegate to jraphx.Batch's shape-based repr instead of the dataclass-generated one."""
        return jraphx.Batch.__repr__(self)
# Two toy meshes. Each `face` array is written one triangle per row and then
# transposed into the [3, num_faces] layout.
mesh1 = FaceData(
    x=jnp.ones((4, 3)),                                    # 4 vertices
    face=jnp.array([[0, 1, 2], [1, 2, 3]]).T,              # 2 triangles
    normal=jnp.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),  # one normal per triangle
    face_color=jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),  # red, green
)
mesh2 = FaceData(
    x=jnp.ones((3, 3)),                                    # 3 vertices
    face=jnp.array([[0, 1, 2]]).T,                         # 1 triangle
    normal=jnp.array([[1.0, 0.0, 0.0]]),
    face_color=jnp.array([[0.0, 0.0, 1.0]]),               # blue
)

# Round trip: batch the meshes, then split them back apart.
batch = FaceBatch.from_data_list([mesh1, mesh2])
meshes = batch.to_data_list()  # list of FaceData (FaceBatch._DATA_CLASS)
The batching system provides three configuration options:
- `NODE_INDEX_FIELDS`: fields containing node indices that need adjustment during batching (like `edge_index` or `face`).
- `ELEMENT_LEVEL_FIELDS`: fields that are element-level features aligned with a node index field (concatenated during batching).
- `GRAPH_LEVEL_FIELDS`: fields that are per-graph attributes (stacked, not concatenated).
Example: Molecular Graphs
@dataclass
class MolecularData(Data):
    """Single molecule: atoms as nodes, plus bond- and molecule-level attributes."""

    bond_index: Optional[jnp.ndarray] = None   # bond connectivity, same layout as edge_index
    bond_type: Optional[jnp.ndarray] = None    # one feature row per bond
    atom_charge: Optional[jnp.ndarray] = None  # one charge per atom (node)
    mol_weight: Optional[float] = None         # scalar property of the whole molecule
@dataclass
class MolecularBatch(jraphx.Batch):
    """Batch of MolecularData graphs merged into one disconnected graph."""

    bond_index: Optional[jnp.ndarray] = None
    bond_type: Optional[jnp.ndarray] = None
    atom_charge: Optional[jnp.ndarray] = None
    mol_weight: Optional[jnp.ndarray] = None  # one entry per molecule once batched

    # Batching configuration. These are plain class attributes, not dataclass fields.
    NODE_INDEX_FIELDS = {"bond_index"}                   # node indices, adjusted when batching
    ELEMENT_LEVEL_FIELDS = {"bond_type", "atom_charge"}  # concatenated across molecules
    GRAPH_LEVEL_FIELDS = {"mol_weight"}                  # stacked: one value per molecule
    _DATA_CLASS = MolecularData                          # type rebuilt by to_data_list()

    def __repr__(self) -> str:
        """Delegate to jraphx.Batch's shape-based repr instead of the dataclass-generated one."""
        return jraphx.Batch.__repr__(self)
# Create molecules. The node feature is the atomic number (C=6, H=1, O=8).
mol1 = MolecularData(
    x=jnp.array([[6.], [1.], [1.]]),  # C, H, H (methylene, CH2)
    edge_index=jnp.array([[0, 0], [1, 2]]),
    bond_index=jnp.array([[0, 0], [1, 2]]),
    bond_type=jnp.array([[1.], [1.]]),  # two single C-H bonds
    atom_charge=jnp.array([0., 0., 0.]),
    mol_weight=14.03,  # 12.011 + 2 * 1.008 g/mol (16.04 would be methane, CH4)
)
mol2 = MolecularData(
    x=jnp.array([[8.], [1.]]),  # O, H (hydroxyl, OH)
    edge_index=jnp.array([[0], [1]]),
    bond_index=jnp.array([[0], [1]]),
    bond_type=jnp.array([[1.]]),  # single O-H bond
    atom_charge=jnp.array([-0.5, 0.5]),
    mol_weight=17.01,  # 15.999 + 1.008 g/mol
)

# Batch molecules
batch = MolecularBatch.from_data_list([mol1, mol2])

# Graph-level properties are stacked: one value per molecule.
print(f"Molecular weights: {batch.mol_weight}")  # [14.03 17.01]
Working with PyTorch Geometric
When converting from PyTorch Geometric datasets, create a custom Data class:
@dataclass
class PyGData(Data):
    """Data subclass with the extra attributes common in PyTorch Geometric datasets.

    All extra fields are optional and default to None.
    """

    # Node split masks.
    train_mask: Optional[jnp.ndarray] = None
    val_mask: Optional[jnp.ndarray] = None
    test_mask: Optional[jnp.ndarray] = None
    # Per-edge features.
    edge_attr: Optional[jnp.ndarray] = None
def from_pyg(pyg_data):
    """Convert a PyTorch Geometric ``Data`` object to a JraphX ``PyGData``.

    Each tensor attribute is detached, moved to the CPU and converted to a JAX
    array. Attributes that are missing or ``None`` on ``pyg_data`` become
    ``None`` on the result, so graphs without masks, labels or edge features
    convert without errors.

    Args:
        pyg_data: A ``torch_geometric.data.Data`` instance.

    Returns:
        A ``PyGData`` holding JAX copies of the tensors.
    """

    def _to_jax(name):
        # PyG reports some absent attributes (e.g. edge_attr) as None instead of
        # raising, so a hasattr() check alone would still call .numpy() on None.
        tensor = getattr(pyg_data, name, None)
        if tensor is None:
            return None
        # detach()/cpu() make .numpy() work for grad-tracking or GPU tensors.
        return jnp.array(tensor.detach().cpu().numpy())

    return PyGData(
        x=_to_jax('x'),
        edge_index=_to_jax('edge_index'),
        y=_to_jax('y'),
        train_mask=_to_jax('train_mask'),
        val_mask=_to_jax('val_mask'),
        test_mask=_to_jax('test_mask'),
        edge_attr=_to_jax('edge_attr'),
    )
Common Patterns
Device Management
# Move every array leaf of the graph to the first GPU.
# jax.devices('gpu') fails if no GPU backend is available.
device = jax.devices('gpu')[0]
data_gpu = jax.device_put(data, device)

# Check device placement. jax.Array.devices() returns the set of devices that
# hold the array; the older `.device()` method call fails on recent JAX releases.
print(f"Data on device: {data_gpu.x.devices()}")
Preprocessing
from functools import partial


@partial(jax.jit, donate_argnums=0)
def normalize_features(data: Data) -> Data:
    """Standardize each node-feature column to zero mean and unit variance.

    ``data`` is donated to XLA (``donate_argnums=0``), so it should not be
    reused after the call. A small epsilon keeps constant columns from
    dividing by zero.
    """
    features = data.x  # [num_nodes, num_node_features]
    centered = features - features.mean(axis=0, keepdims=True)
    scale = features.std(axis=0, keepdims=True) + 1e-6
    return data.replace(x=centered / scale)


# Apply normalization
data_normalized = normalize_features(data)
Data Augmentation
from functools import partial


@partial(jax.jit, donate_argnums=0)
def add_noise(data: Data, rng: jax.Array, noise_level: float = 0.1) -> Data:
    """Add zero-mean Gaussian noise to the node features.

    Args:
        data: Input graph. It is donated (``donate_argnums=0``), so do not
            reuse it after the call.
        rng: JAX PRNG key used to sample the noise.
        noise_level: Factor the unit-normal noise is scaled by (its std).

    Returns:
        A copy of ``data`` whose ``x`` has the noise added.
    """
    # Fully qualified: the bare name `random` is never imported in this snippet.
    noise = jax.random.normal(rng, data.x.shape) * noise_level
    return data.replace(x=data.x + noise)
def drop_edges(data: Data, rng: jax.Array, drop_rate: float = 0.1) -> Data:
    """Randomly drop edges for augmentation.

    Each edge is kept independently with probability ``1 - drop_rate``.

    This function is deliberately not wrapped in ``jax.jit``: boolean-mask
    indexing produces an array whose length depends on the sampled mask, and
    jit requires static output shapes (it raises NonConcreteBooleanIndexError).

    Only ``edge_index`` is filtered; per-edge attributes on subclasses, if
    any, are left as-is.

    Args:
        data: Input graph.
        rng: JAX PRNG key used to sample the keep mask.
        drop_rate: Probability of dropping each edge.

    Returns:
        A copy of ``data`` with the surviving edges.
    """
    num_edges = data.edge_index.shape[1]
    keep = jax.random.bernoulli(rng, p=1 - drop_rate, shape=(num_edges,))
    return data.replace(edge_index=data.edge_index[:, keep])
Performance Considerations
Memory Efficiency
- **Immutability**: Modifications create new objects, but JAX's XLA compiler optimizes this. Consider using `donate_argnums`/`donate_argnames` with `jax.jit`/`nnx.jit` and related functions.
- **PyTree operations**: Very efficient for batch operations.
- **Subclassing**: No overhead — only stores defined attributes.
JIT Compilation
from flax import nnx


@nnx.jit
def efficient_forward(data: Data):
    """Run ``model`` on the graph's node features and edges under ``nnx.jit``.

    All operations on ``Data`` work under JIT, so the whole call is compiled.
    """
    # NOTE(review): `model` is read from the enclosing scope (defined elsewhere,
    # not in this snippet). NNX transforms generally expect modules to be passed
    # as arguments so their state is tracked; consider
    # `efficient_forward(model, data)` — confirm against the Flax NNX docs.
    return model(data.x, data.edge_index)
Large Graphs
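For very large graphs where processing all nodes in one pass would need too much memory, process the node features in fixed-size chunks with `jax.lax.scan` (the node features themselves must still fit on the device):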
def process_large_graph_in_chunks(
    data: Data,
    chunk_size: int = 1000,
    chunk_fn: "Callable[[jax.Array], jax.Array] | None" = None,
):
    """Process node features in fixed-size chunks using ``jax.lax.scan``.

    ``scan`` needs every step to see arrays of one static shape, and traced
    loop indices cannot drive Python ``min()`` or slicing. So ``data.x`` is
    zero-padded to a multiple of ``chunk_size`` and reshaped to
    ``[num_chunks, chunk_size, ...]``, and scan iterates over that leading
    axis. ``chunk_fn`` only ever sees one chunk at a time. The padded rows are
    removed from the result.

    Args:
        data: Graph whose node features ``data.x`` are processed.
        chunk_size: Number of nodes per chunk; must be positive.
        chunk_fn: Maps a ``[chunk_size, ...]`` block of node features to
            per-node outputs with the same leading dimension. It must treat
            rows independently, because zero-padded rows are processed
            alongside real ones. Defaults to the identity, which is a
            placeholder for real per-chunk processing.

    Returns:
        Per-node outputs with leading dimension ``num_nodes``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if chunk_fn is None:
        def chunk_fn(chunk):
            return chunk  # Process chunk...

    x = data.x
    # Use the row count of x itself, since that is exactly what gets chunked.
    num_nodes = x.shape[0]
    num_chunks = (num_nodes + chunk_size - 1) // chunk_size
    pad_rows = num_chunks * chunk_size - num_nodes
    padded = jnp.pad(x, [(0, pad_rows)] + [(0, 0)] * (x.ndim - 1))
    chunks = padded.reshape((num_chunks, chunk_size) + x.shape[1:])

    def process_chunk(carry, chunk):
        return carry, chunk_fn(chunk)

    _, outputs = jax.lax.scan(process_chunk, None, chunks)
    # Flatten [num_chunks, chunk_size, ...] back into rows and drop the padding.
    flat = outputs.reshape((num_chunks * chunk_size,) + outputs.shape[2:])
    return flat[:num_nodes]
Best Practices
Always subclass Data for domain-specific attributes rather than trying to modify instances
Use optional types (`X | None`, or `Optional[X]`) with a `None` default for attributes that may not always be present
Leverage immutability for reproducible and debuggable code
Use replace() method for creating modified instances
Take advantage of PyTree operations for efficient batch processing
Prefer JAX arrays over Python lists or NumPy arrays for all tensor data
Troubleshooting
Common Issues
AttributeError when setting attributes
# Wrong: Data instances are immutable, so new attributes cannot be attached.
data.custom_attr = value  # Raises AttributeError

# Right: declare the attribute as a field on a Data subclass.
@dataclass
class MyData(Data):
    custom_attr: Optional[jnp.ndarray] = None


data = MyData(x=x, edge_index=edges, custom_attr=value)
Type errors with JAX transforms
Ensure all attributes are JAX-compatible types, or mark non-JAX attributes as static with `struct.field(pytree_node=False)`:
from flax import struct


@struct.dataclass
class DataWithMetadata(Data):
    """Data subclass that mixes traced arrays with static, non-array metadata."""

    # Ordinary pytree leaf: traced by jit/vmap/grad like any other array.
    features: Optional[jnp.ndarray] = None
    # pytree_node=False keeps this field out of the pytree leaves, so JAX
    # never traces it; it is carried along as static metadata instead.
    name: str = struct.field(pytree_node=False, default="")
See Also
Practical Examples - Tutorials demonstrating Data usage
Introduction by Example - Introduction to JraphX concepts
Flax Struct Documentation - Details on flax.struct.dataclass