JraphX Documentation

JraphX is a Graph Neural Network (GNN) library for JAX/Flax NNX, designed as an unofficial successor to DeepMind’s archived Jraph library. It provides a PyTorch Geometric-inspired API while taking advantage of the JAX ecosystem’s strengths, such as JIT compilation and sharding.
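Because the API is PyTorch Geometric-inspired, it is reasonable to picture graph connectivity stored as a COO `edge_index` array of shape `[2, num_edges]`. That layout is an assumption borrowed from PyG's convention, not a confirmed JraphX data structure. The sketch below uses only plain JAX to show how such an array works with JIT compilation. `in_degree` is a hypothetical helper, not a JraphX function.

```python
from functools import partial

import jax
import jax.numpy as jnp

# A 4-node directed graph in PyG-style COO format (assumed layout):
# row 0 holds source nodes, row 1 holds target nodes.
edge_index = jnp.array([[0, 1, 1, 2, 3],
                        [1, 0, 2, 3, 2]])
num_nodes = 4


# Hypothetical helper. num_nodes must be static because it fixes the output shape.
@partial(jax.jit, static_argnames="num_nodes")
def in_degree(edge_index, num_nodes):
    # Scatter a 1 onto each edge's target node and sum per node.
    ones = jnp.ones(edge_index.shape[1], dtype=jnp.int32)
    return jax.ops.segment_sum(ones, edge_index[1], num_segments=num_nodes)


degrees = in_degree(edge_index, num_nodes=num_nodes)  # [1, 1, 2, 1]
```

Marking `num_nodes` as static lets XLA compile a fixed-shape scatter. This matters because JAX requires array shapes to be known at trace time.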

Note

Attribution Notice: JraphX builds upon and incorporates code from multiple open-source projects:

  • PyTorch Geometric (MIT License, Copyright (c) 2023 PyG Team): JraphX contains substantial portions of code and documentation derived from PyTorch Geometric.

  • Flax (Apache License 2.0): JraphX is built on the Flax NNX API, which made its implementation significantly easier.

  • Jraph (Apache License 2.0): DeepMind’s original JAX GNN library, which is now archived.

We are grateful to all development teams for creating these foundational libraries that make JraphX possible.

JraphX provides methods for deep learning on graphs and other irregular structures, implementing core GNN layers and utilities with JAX and Flax NNX. Its features include an efficient message-passing framework, vectorized operations using nnx.vmap, sequential processing with nnx.scan, and integration with the wider JAX ecosystem, including automatic differentiation and JIT compilation.
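The message-passing, vectorization, and sequential-processing ideas above can be illustrated in plain JAX. This sketch is not JraphX's API. `message_passing` and `stacked_layers` are hypothetical helpers, and edges are assumed to follow the PyTorch Geometric `edge_index` convention (row 0 sources, row 1 targets). `jax.vmap` and `jax.lax.scan` stand in for the `nnx.vmap` and `nnx.scan` transforms that the library applies to Flax NNX modules.

```python
import jax
import jax.numpy as jnp


def message_passing(x, edge_index, weight):
    """One sum-aggregation message-passing step (hypothetical, no degree normalization)."""
    src, dst = edge_index[0], edge_index[1]
    # Gather each edge's source-node features and apply a linear transform.
    messages = x[src] @ weight
    # Sum the messages arriving at each target node.
    return jax.ops.segment_sum(messages, dst, num_segments=x.shape[0])


def stacked_layers(x, edge_index, weights):
    """Apply one step per slice of `weights` (shape [L, F, F]) sequentially with lax.scan."""
    def step(h, w):
        return jax.nn.relu(message_passing(h, edge_index, w)), None
    h, _ = jax.lax.scan(step, x, weights)
    return h


edge_index = jnp.array([[0, 1, 2],
                        [1, 2, 0]])   # directed 3-cycle: 0->1->2->0
x = jnp.eye(3)                        # one-hot node features, shape [3, 3]
weight = jnp.ones((3, 2))

# JIT-compile the layer; x.shape[0] is a static Python int at trace time.
out = jax.jit(message_passing)(x, edge_index, weight)

# Differentiate a scalar loss with respect to the weight matrix.
loss = lambda w: message_passing(x, edge_index, w).sum()
grads = jax.grad(loss)(weight)

# Vectorize over a batch of feature matrices that share one graph.
batch_x = jnp.stack([x] * 4)          # shape [4, 3, 3]
batched_out = jax.vmap(message_passing, in_axes=(0, None, None))(batch_x, edge_index, weight)

# Two layers with identity weights applied in sequence via scan.
deep_out = stacked_layers(x, edge_index, jnp.stack([jnp.eye(3)] * 2))
```

With these inputs every node receives exactly one message, so `out` and `grads` are both all-ones arrays of shape `[3, 2]`. The identity-weight stack shifts each one-hot feature two hops around the cycle. `scan` keeps a single compiled layer body instead of unrolling one copy per layer, which is the usual motivation for scanning over stacked layers.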

Install JraphX

Package Reference

Project Info