jraphx
Core JraphX functionality and version information.
Core Classes
Data
    A data object representing a single graph.
Batch
    A batch of graphs represented as a single large disconnected graph.
Version Information
JraphX: Graph Neural Networks with JAX/NNX.
JraphX provides graph neural network layers and utilities for JAX, serving as an unofficial successor to DeepMind’s archived jraph library. It is derived from PyTorch Geometric code and documentation.
Quick Start
JraphX provides a simple, JAX-based interface for graph neural networks:
import jax.numpy as jnp
from jraphx.data import Data
from jraphx.nn.conv import GCNConv
from flax import nnx
# Create a graph
x = jnp.ones((10, 16))
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
data = Data(x=x, edge_index=edge_index)
# Create and use a GNN layer
layer = GCNConv(16, 32, rngs=nnx.Rngs(42))
output = layer(data.x, data.edge_index)
print(f"Output shape: {output.shape}")
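The Batch class listed under Core Classes packs several graphs into one large disconnected graph so that a layer can process them in a single call. A minimal sketch, assuming Batch is importable from jraphx.data and provides a from_data_list constructor analogous to PyTorch Geometric's:
from jraphx.data import Batch  # assumed import location, mirroring the Data import above
# Two small graphs with the same 16-dimensional node features
g1 = Data(x=jnp.ones((3, 16)), edge_index=jnp.array([[0, 1], [1, 2]]))
g2 = Data(x=jnp.ones((4, 16)), edge_index=jnp.array([[0, 1, 2], [1, 2, 3]]))
# Combine them into one disconnected graph (assumed constructor name)
batch = Batch.from_data_list([g1, g2])
batched_output = layer(batch.x, batch.edge_index)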
Submodules
jraphx.data
    Graph data structures for JraphX.
jraphx.nn
    Neural network modules for JraphX.
jraphx.utils
    Graph utility functions for JraphX.
JAX Integration
JraphX is designed from the ground up for JAX; a short training-step sketch follows the list below:
* All operations are pure functions
* Full support for @jax.jit compilation
* Compatible with jax.vmap and nnx.vmap for batching
* Integrates with jax.grad for automatic differentiation
* Works seamlessly with Optax optimizers
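A minimal training-step sketch that ties these pieces together, reusing the layer class and the data object from the Quick Start example above. Everything beyond GCNConv and Data comes from Flax NNX and Optax rather than JraphX, and the nnx.Optimizer constructor and update signatures may differ between Flax versions:
import optax
model = GCNConv(16, 32, rngs=nnx.Rngs(0))
# Wrap an Optax transform around the model's parameters (Flax NNX API)
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
targets = jnp.zeros((10, 32))  # placeholder regression targets
@nnx.jit
def train_step(model, optimizer, x, edge_index, targets):
    def loss_fn(model):
        preds = model(x, edge_index)
        return jnp.mean((preds - targets) ** 2)
    # Differentiate the loss with respect to the model's parameters
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss
loss = train_step(model, optimizer, data.x, data.edge_index, targets)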