Installation
============

**JraphX** is a Graph Neural Network library for JAX/Flax NNX. It requires Python 3.11 or later.

Quick Start
-----------

**Option 1: Install from PyPI (Recommended)**

Install JraphX directly from PyPI:

.. code-block:: bash

   pip install jraphx

This automatically installs the required dependencies: JAX, Flax, and NumPy.

**Option 2: JAX AI Stack + JraphX**

The `JAX AI Stack `__ provides a curated collection of JAX, Flax, Optax, and other ML libraries. After installing it, you can add JraphX:

.. code-block:: bash

   pip install jax-ai-stack
   pip install jraphx

.. note::

   **Current Compatibility Issue:** This approach does not work yet. JraphX requires Flax 0.11.2 or higher, but ``jax-ai-stack`` 2025.9.3 pins Flax to exactly 0.11.1. We're waiting for JAX AI Stack to update its Flax requirement; in the meantime, use Option 1.

Development Installation
------------------------

For development, or to get the latest unreleased features, clone the repository and install it in editable (development) mode:

.. code-block:: bash

   git clone https://github.com/DBraun/jraphx.git
   cd jraphx
   pip install -e .

You can also use this approach after installing JAX AI Stack (once the Flax compatibility issue is resolved):

.. code-block:: bash

   pip install jax-ai-stack  # once Flax 0.11.2+ is supported
   git clone https://github.com/DBraun/jraphx.git
   cd jraphx
   pip install -e .

Verification
------------

To verify that your installation works:

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from flax import nnx

   import jraphx

   print(f"JAX version: {jax.__version__}")
   print(f"JAX backend: {jax.default_backend()}")
   print(f"JraphX version: {jraphx.__version__}")

   # Test basic functionality
   from jraphx.data import Data
   from jraphx.nn.conv import GCNConv

   # Create a simple 3-node graph: 4 features per node, edges 0->1, 1->2, 2->0
   data = Data(
       x=jnp.ones((3, 4)),
       edge_index=jnp.array([[0, 1, 2], [1, 2, 0]])
   )

   # Create a GCN layer (4 input features -> 8 output features) and apply it
   layer = GCNConv(4, 8, rngs=nnx.Rngs(42))
   output = layer(data.x, data.edge_index)
   print(f"Successfully processed graph: {output.shape}")

Troubleshooting
---------------

**Import Error:** If you get "No module named 'jraphx'", make sure JraphX is installed in the Python environment you are running: use ``pip install jraphx``, or, for a development install, run ``pip install -e .`` from the root of the cloned ``jraphx`` directory.

**JAX Issues:** Refer to the `JAX installation guide `__ for platform-specific troubleshooting.

For other issues, please open an issue on the `JraphX GitHub repository <https://github.com/DBraun/jraphx>`__.
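
**Checking Requirements:** Many installation problems come down to version mismatches. The sketch below checks the current environment against the requirements stated on this page (Python 3.11+, Flax 0.11.2+) and reports where ``import jraphx`` would load from, which helps diagnose the import error above. It uses only the standard library. The helper names ``parse_release``, ``check_requirements`` and ``jraphx_location`` are hypothetical and not part of JraphX, and the version parsing is deliberately simplified (pre-release suffixes such as ``rc1`` are ignored).

.. code-block:: python

   from __future__ import annotations

   import importlib.util
   import sys
   from importlib.metadata import PackageNotFoundError, version


   def parse_release(ver: str) -> tuple[int, ...]:
       """Turn '0.11.2' (or '0.11.2.dev0') into (0, 11, 2).

       Simplified on purpose: only the leading digits of each dot-separated
       component are kept, so pre-release tags such as 'rc1' are ignored.
       """
       parts = []
       for piece in ver.split("."):
           digits = ""
           for ch in piece:
               if not ch.isdigit():
                   break
               digits += ch
           if not digits:
               break  # stop at the first non-numeric component, e.g. 'dev0'
           parts.append(int(digits))
       return tuple(parts)


   def check_requirements() -> list[str]:
       """Return a list of problems; an empty list means all checks passed."""
       problems = []
       if sys.version_info < (3, 11):
           found = sys.version.split()[0]
           problems.append(f"Python 3.11+ is required, found {found}")
       try:
           flax_version = version("flax")
       except PackageNotFoundError:
           problems.append("Flax is not installed")
       else:
           if parse_release(flax_version) < (0, 11, 2):
               problems.append(f"Flax 0.11.2+ is required, found {flax_version}")
       return problems


   def jraphx_location() -> str | None:
       """Path that 'import jraphx' would load, or None if it is not importable."""
       spec = importlib.util.find_spec("jraphx")
       return spec.origin if spec is not None else None


   if __name__ == "__main__":
       for problem in check_requirements():
           print("Problem:", problem)
       print("jraphx found at:", jraphx_location())

If ``jraphx found at:`` prints ``None``, the active interpreter cannot see the package, which usually means it was installed into a different environment.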