jraphx

Core JraphX functionality and version information.

Core Classes

Data

A data object representing a single graph.

Batch

A batch of graphs represented as a single large disconnected graph.
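To illustrate what "a single large disconnected graph" means, here is a minimal pure-JAX sketch. The helper `batch_graphs` is hypothetical and is not JraphX's `Batch` API; it assumes the PyTorch Geometric convention (from which JraphX is derived) of a `[2, num_edges]` COO `edge_index` and a per-node `batch` vector holding each node's graph id.

```python
import jax.numpy as jnp


def batch_graphs(xs, edge_indices):
    """Hypothetical helper: merge several graphs into one disconnected graph.

    xs: list of [num_nodes_i, num_features] node-feature arrays.
    edge_indices: list of [2, num_edges_i] COO edge arrays (local node ids).
    """
    # Offset of graph i = total number of nodes in graphs 0..i-1
    offsets, total = [], 0
    for x_i in xs:
        offsets.append(total)
        total += x_i.shape[0]

    # Stack all node features along the node axis
    x = jnp.concatenate(xs, axis=0)
    # Shift each graph's node ids so they index into the stacked feature array;
    # no edges connect different graphs, so the result is disconnected
    edge_index = jnp.concatenate(
        [ei + off for ei, off in zip(edge_indices, offsets)], axis=1
    )
    # batch[n] = index of the graph that node n came from
    batch = jnp.concatenate(
        [jnp.full((x_i.shape[0],), i, dtype=jnp.int32) for i, x_i in enumerate(xs)]
    )
    return x, edge_index, batch


# Usage: a 3-node path graph and a 2-node graph with a single edge
x1, e1 = jnp.ones((3, 4)), jnp.array([[0, 1], [1, 2]])
x2, e2 = jnp.zeros((2, 4)), jnp.array([[0], [1]])
x, edge_index, batch = batch_graphs([x1, x2], [e1, e2])
```

The second graph's edge `0 -> 1` becomes `3 -> 4` after batching, and `batch` lets per-graph operations (for example pooling) recover which nodes belong together.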

Version Information

JraphX: Graph Neural Networks with JAX/NNX.

JraphX provides graph neural network layers and utilities for JAX, serving as an unofficial successor to DeepMind’s archived jraph library. It is derived from PyTorch Geometric code and documentation.

Quick Start

JraphX provides a simple, JAX-based interface for graph neural networks:

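import jax.numpy as jnp
from jraphx.data import Data
from jraphx.nn.conv import GCNConv
from flax import nnx

# Create a graph: 10 nodes with 16 features each
x = jnp.ones((10, 16))
# edge_index stores edges in COO format: row 0 = source nodes, row 1 = target nodes
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
data = Data(x=x, edge_index=edge_index)

# Create a GCN layer mapping 16 input features to 32 output features, then apply it
layer = GCNConv(16, 32, rngs=nnx.Rngs(42))
output = layer(data.x, data.edge_index)

# One 32-dimensional embedding per node, i.e. shape (10, 32)
print(f"Output shape: {output.shape}")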
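For intuition about what a GCN layer computes, the sketch below implements the standard GCN propagation rule of Kipf and Welling, X' = D^{-1/2} (A + I) D^{-1/2} X W, in plain JAX. The function `gcn_propagate` is hypothetical: it omits the bias term and any caching, and JraphX's `GCNConv` may differ in normalization details and options.

```python
import jax
import jax.numpy as jnp


def gcn_propagate(x, edge_index, weight):
    """Hypothetical GCN sketch: D^{-1/2} (A + I) D^{-1/2} X W (no bias)."""
    num_nodes = x.shape[0]
    # Add a self-loop to every node (the "+ I" term)
    loops = jnp.arange(num_nodes)
    src = jnp.concatenate([edge_index[0], loops])
    dst = jnp.concatenate([edge_index[1], loops])
    # Degree of each target node, counting its self-loop, so deg >= 1
    deg = jax.ops.segment_sum(
        jnp.ones_like(dst, dtype=x.dtype), dst, num_segments=num_nodes
    )
    inv_sqrt = jax.lax.rsqrt(deg)
    # Symmetric normalization coefficient for each edge j -> i
    norm = inv_sqrt[src] * inv_sqrt[dst]
    # Linear transform first, then gather, scale, and sum messages per target
    h = x @ weight
    messages = h[src] * norm[:, None]
    return jax.ops.segment_sum(messages, dst, num_segments=num_nodes)


# Same toy graph as the Quick Start: 10 nodes, 3 directed edges
x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
weight = jnp.full((16, 32), 0.5)
out = gcn_propagate(x, edge_index, weight)
```

Applying the linear transform before aggregation keeps the gathered messages at the output width, and the symmetric normalization keeps high-degree nodes from dominating the sum.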

Submodules

data

Graph data structures for JraphX.

nn

Neural network modules for JraphX.

utils

Graph utility functions for JraphX.

JAX Integration

JraphX is designed from the ground up for JAX:

  • All operations are pure functions

  • Full support for @jax.jit compilation

  • Compatible with jax.vmap and nnx.vmap for batching

  • Integrates with jax.grad for automatic differentiation

  • Works with Optax optimizers for training
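The properties listed above can be sketched with a hypothetical pure-function layer written directly in JAX (not a JraphX API): `jax.jit` compiles the gradient computation, `jax.grad` differentiates through the gather and scatter of message passing, and `jax.vmap` maps the layer over several feature sets that share one graph structure. A plain SGD update via `jax.tree_util.tree_map` stands in for an Optax optimizer so the sketch has no extra dependencies.

```python
import jax
import jax.numpy as jnp


def sum_aggregate_layer(params, x, edge_index):
    """Hypothetical layer: h_i = sum over edges j -> i of (x_j W) + b."""
    num_nodes = x.shape[0]  # static under jit, so it is a valid num_segments
    messages = x[edge_index[0]] @ params["w"]
    agg = jax.ops.segment_sum(messages, edge_index[1], num_segments=num_nodes)
    return agg + params["b"]


def loss_fn(params, x, edge_index):
    # Toy objective: drive all node outputs toward zero
    return jnp.mean(sum_aggregate_layer(params, x, edge_index) ** 2)


key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (16, 32)) * 0.1, "b": jnp.zeros(32)}
x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])

# jit-compiled gradient of the loss with respect to the parameter pytree
grad_fn = jax.jit(jax.grad(loss_fn))
grads = grad_fn(params, x, edge_index)

# One plain SGD step (stand-in for an Optax optimizer update)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# vmap over 4 feature sets that share the same parameters and edge_index
xs = jnp.ones((4, 10, 16))
batched = jax.vmap(sum_aggregate_layer, in_axes=(None, 0, None))(
    params, xs, edge_index
)
```

Passing `in_axes=(None, 0, None)` broadcasts the parameters and the edge index across the batch while mapping only over the node features.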