Changelog
Version 0.0.4
Breaking Changes
Updated minimum Flax requirement to 0.12.0 for improved pytree handling:
- Now uses `nnx.List` for module lists (see the migration sketch below)
- Now uses `nnx.data(None)` for optional module attributes
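A minimal migration sketch of the pattern these changes imply for user-defined modules. It assumes that Flax 0.12's `nnx.List` accepts an iterable of submodules and that `nnx.data(None)` marks an optional, initially-empty attribute as pytree data; it is not taken from the JraphX source.

```python
# Hypothetical sketch (not JraphX source): declaring module lists and optional
# attributes under Flax >= 0.12 pytree handling.
from flax import nnx

class StackedLayers(nnx.Module):
    def __init__(self, num_layers: int, features: int, *, rngs: nnx.Rngs):
        # nnx.List keeps the submodules visible to NNX's pytree traversal.
        self.layers = nnx.List(
            [nnx.Linear(features, features, rngs=rngs) for _ in range(num_layers)]
        )
        # nnx.data(None) marks an optional attribute that may later hold a module.
        self.residual_proj = nnx.data(None)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```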
Version 0.0.3
Initial release of JraphX.
Features
Core Data Structures
- `Data` class: Single graph representation with node features, edge indices, edge attributes, and graph-level properties (see the example below)
- `Batch` class: Efficient batching of multiple graphs into disconnected graph batches with automatic indexing management
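An illustrative construction sketch. The `jraphx.data` import path, the keyword arguments (`x`, `edge_index`, `edge_attr`), and the `Batch.from_data_list` helper are assumptions based on the PyTorch Geometric-style API described above, not verified signatures.

```python
# Illustrative only: constructor and helper names are assumed, check the
# JraphX API reference for the exact signatures.
import jax.numpy as jnp
from jraphx.data import Data, Batch  # import path assumed

# A 3-node graph with 2 directed edges and 4-dimensional node features.
graph = Data(
    x=jnp.ones((3, 4)),                      # node features [num_nodes, num_features]
    edge_index=jnp.array([[0, 1], [1, 2]]),  # edges as [2, num_edges] (source row, target row)
    edge_attr=jnp.ones((2, 1)),              # per-edge attributes
)

# Batch several graphs into one disconnected graph; node indices are offset
# automatically and a batch vector records which graph each node belongs to.
batch = Batch.from_data_list([graph, graph])
```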
Message Passing Framework
- Unified `MessagePassing` base class providing a standardized interface for all graph neural network layers
- Flexible message computation, aggregation (sum, mean, max, min), and node update functions (see the plain-JAX sketch below)
- Support for both node-to-node and edge-enhanced message passing paradigms
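A conceptual sketch, in plain JAX, of the gather/aggregate step that the `MessagePassing` base class standardizes; it is not the JraphX API itself.

```python
# Sum-aggregated message passing in plain JAX: gather source features along
# the edges, then aggregate them at each target node.
import jax
import jax.numpy as jnp

def sum_message_passing(x, edge_index):
    """x: [num_nodes, features]; edge_index: [2, num_edges] (row 0 = source, row 1 = target)."""
    src, dst = edge_index[0], edge_index[1]
    messages = x[src]  # identity message function: message = source-node features
    out = jax.ops.segment_sum(messages, dst, num_segments=x.shape[0])  # aggregate per target
    return out         # an update function would then combine `out` with `x`

x = jnp.arange(12.0).reshape(4, 3)
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])
print(sum_message_passing(x, edge_index))
```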
Graph Convolution Layers
- `GCNConv`: Graph Convolutional Network with spectral-based convolution and optional edge weights (see the dense-matrix sketch below)
- `GATConv`: Graph Attention Network with multi-head attention mechanism and learnable attention weights
- `GATv2Conv`: Improved Graph Attention Network with enhanced attention computation for better expressivity
- GraphSAGE (`SAGEConv`): GraphSAGE with multiple aggregation functions (mean, max, LSTM) for inductive learning
- `GINConv`: Graph Isomorphism Network with theoretical guarantees for graph representation power
- `EdgeConv`: Dynamic edge convolution for learning on point clouds and dynamic graph construction
- `DynamicEdgeConv`: Enhanced EdgeConv with k-nearest neighbor graph construction
- `TransformerConv`: Graph Transformer layer with optimized query-key-value projections and positional encodings
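For reference, a dense-matrix sketch of the spectral propagation rule behind `GCNConv`; the actual layer operates on a sparse `edge_index` with scatter operations rather than a dense adjacency matrix.

```python
# GCN propagation rule written densely for clarity: H' = D^{-1/2} (A + I) D^{-1/2} X W.
import jax.numpy as jnp

def gcn_propagate(x, adj, weight):
    adj_hat = adj + jnp.eye(adj.shape[0])              # add self-loops
    deg = adj_hat.sum(axis=1)                          # degrees of A + I
    d_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)  # D^{-1/2}, guarding zero-degree nodes
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight                       # normalized aggregation + linear transform
```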
Pooling Operations
- Global pooling: `global_add_pool`, `global_mean_pool`, `global_max_pool`, `global_min_pool` for graph-level representations (see the segment-based sketch below)
- Advanced pooling: `global_softmax_pool`, `global_sort_pool` for differentiable and sorted aggregations
- Hierarchical pooling: `TopKPooling` and `SAGPooling` for coarsening graph structures with learnable node selection
- Batched operations: Optimized versions (`batched_global_*_pool`) for efficient parallel processing of graph batches
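A plain-JAX sketch of what global mean pooling computes on a batched (disconnected) graph, using the per-node batch assignment vector; the library functions wrap this kind of segment arithmetic.

```python
# Average node features per graph, selected by the batch assignment vector.
import jax
import jax.numpy as jnp

def global_mean_pool_sketch(x, batch, num_graphs):
    sums = jax.ops.segment_sum(x, batch, num_segments=num_graphs)
    counts = jax.ops.segment_sum(jnp.ones((x.shape[0], 1)), batch, num_segments=num_graphs)
    return sums / counts

x = jnp.arange(10.0).reshape(5, 2)   # 5 nodes, 2 features
batch = jnp.array([0, 0, 0, 1, 1])   # first 3 nodes in graph 0, last 2 in graph 1
print(global_mean_pool_sketch(x, batch, num_graphs=2))  # shape [2, 2]
```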
Utility Functions
- Scatter operations: Comprehensive set including `scatter_add`, `scatter_mean`, `scatter_max`, `scatter_min`, `scatter_std`, `scatter_logsumexp` for flexible aggregation (see the sketch below)
- Scatter softmax: `scatter_softmax`, `scatter_log_softmax`, `masked_scatter_softmax` for attention-like mechanisms
- Graph utilities: Degree computation (`degree`, `in_degree`, `out_degree`), self-loop management (`add_self_loops`, `remove_self_loops`)
- Conversion functions: `to_dense_adj`, `to_edge_index`, `to_undirected` for different graph representations
- Graph preprocessing: `coalesce` for edge deduplication, `maybe_num_nodes` for automatic node count inference
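A plain-JAX sketch of the two aggregation patterns these utilities wrap: scatter-add via advanced indexing and a per-segment softmax. The helper names here are illustrative, not the library functions.

```python
import jax
import jax.numpy as jnp

def scatter_add_sketch(src, index, num_segments):
    # Accumulate rows of src into the output rows selected by index.
    out = jnp.zeros((num_segments,) + src.shape[1:], dtype=src.dtype)
    return out.at[index].add(src)

def scatter_softmax_sketch(logits, index, num_segments):
    # Softmax computed independently within each segment (e.g. per target node).
    seg_max = jax.ops.segment_max(logits, index, num_segments=num_segments)
    shifted = logits - seg_max[index]   # subtract per-segment max for numerical stability
    exp = jnp.exp(shifted)
    denom = jax.ops.segment_sum(exp, index, num_segments=num_segments)
    return exp / denom[index]

index = jnp.array([0, 0, 1, 1, 1])
print(scatter_add_sketch(jnp.ones((5, 2)), index, num_segments=2))
print(scatter_softmax_sketch(jnp.array([1.0, 2.0, 0.5, 0.5, 0.5]), index, num_segments=2))
```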
Pre-built Models
- `GCN`, `GAT`, `GraphSAGE`, `GIN`: Complete model implementations with configurable depth, hidden dimensions, and activation functions (see the usage sketch below)
- `JumpingKnowledge`: Multi-layer aggregation with concatenation, max, and LSTM-based combination strategies
- `MLP`: Multi-layer perceptron with dropout, batch normalization, and flexible activation functions
- `BasicGNN`: Abstract base class for implementing custom GNN architectures with standardized interfaces
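A hypothetical usage sketch for a pre-built model. The import path, the constructor keyword names (`in_features`, `hidden_features`, `out_features`, `num_layers`), and the call signature are assumptions rather than verified JraphX signatures.

```python
# Hypothetical sketch: argument names and import path assumed, not verified.
import jax.numpy as jnp
from flax import nnx
from jraphx.nn.models import GCN  # import path assumed

model = GCN(
    in_features=16,
    hidden_features=32,
    out_features=7,
    num_layers=2,
    rngs=nnx.Rngs(0),
)
x = jnp.ones((10, 16))                          # 10 nodes, 16 features
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # 3 directed edges
logits = model(x, edge_index)                   # call signature assumed; per-node outputs [10, 7]
```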
Normalization Layers
- `BatchNorm`: Batch normalization with running statistics for stable training across graph batches
- `LayerNorm`: Layer normalization supporting both node-wise and graph-wise normalization schemes
- `GraphNorm`: Graph-specific normalization designed for graph neural network architectures (see the sketch below)
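A conceptual sketch of graph-wise normalization, i.e. the per-graph statistics a `GraphNorm`-style layer computes before applying its learnable scale and shift; it is not the library implementation.

```python
# Normalize node features using mean and variance computed per graph,
# selected by the batch assignment vector.
import jax
import jax.numpy as jnp

def graphwise_normalize(x, batch, num_graphs, eps=1e-5):
    count = jax.ops.segment_sum(jnp.ones((x.shape[0], 1)), batch, num_segments=num_graphs)
    mean = jax.ops.segment_sum(x, batch, num_segments=num_graphs) / count
    centered = x - mean[batch]
    var = jax.ops.segment_sum(centered**2, batch, num_segments=num_graphs) / count
    return centered / jnp.sqrt(var[batch] + eps)
```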
JAX Integration & Performance
- Extensive use of `jax.vmap` and `nnx.vmap` for efficient parallel processing of graph batches
- Memory-efficient training patterns using `jax.lax.scan` and `nnx.scan` for sequential operations
- JIT compilation support for all operations with optimized JAX primitives
- Efficient scatter operations using JAX's advanced indexing (`at[].add/max/min`) for high-performance aggregation (see the sketch below)
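A small jit-compiled sketch of the `.at[].add` / `.at[].max` scatter pattern referenced above, written in plain JAX.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_nodes")
def aggregate(messages, dst, num_nodes):
    zeros = jnp.zeros((num_nodes, messages.shape[1]), dtype=messages.dtype)
    summed = zeros.at[dst].add(messages)                           # sum aggregation
    maxed = jnp.full_like(zeros, -jnp.inf).at[dst].max(messages)   # max aggregation; nodes with
    return summed, maxed                                           # no incoming edges stay -inf

messages = jnp.ones((4, 3))
dst = jnp.array([0, 0, 1, 2])
summed, maxed = aggregate(messages, dst, num_nodes=5)
```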