Changelog
Version 0.0.3
Initial release of JraphX.
Features
Core Data Structures
- `Data` class: Single graph representation with node features, edge indices, edge attributes, and graph-level properties
- `Batch` class: Efficient batching of multiple graphs into one disconnected graph batch with automatic index management
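As a rough illustration of what this style of batching does, the sketch below merges two toy graphs into one disconnected graph with plain JAX arrays. The field layout (`x`, `edge_index`, `batch`) is assumed to follow the usual PyTorch-Geometric-style convention and may not match JraphX's exact attribute names.

```python
import jax.numpy as jnp

# Two toy graphs: node features [num_nodes, feat] and edge indices [2, num_edges].
x1 = jnp.ones((3, 4)); edge_index1 = jnp.array([[0, 1, 2], [1, 2, 0]])
x2 = jnp.ones((2, 4)); edge_index2 = jnp.array([[0, 1], [1, 0]])

# Batching = concatenate node features, offset the second graph's edge indices,
# and record which graph each node belongs to.
x = jnp.concatenate([x1, x2], axis=0)                        # [5, 4]
edge_index = jnp.concatenate(
    [edge_index1, edge_index2 + x1.shape[0]], axis=1)        # [2, 5]
batch = jnp.concatenate(
    [jnp.zeros(x1.shape[0], dtype=jnp.int32),
     jnp.ones(x2.shape[0], dtype=jnp.int32)])                # [5]
```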
Message Passing Framework
- Unified `MessagePassing` base class providing a standardized interface for all graph neural network layers
- Flexible message computation, aggregation (sum, mean, max, min), and node update functions
- Support for both node-to-node and edge-enhanced message passing paradigms
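Conceptually, one message-passing step gathers features from source nodes, turns them into per-edge messages, and aggregates them at destination nodes. The minimal sketch below implements sum aggregation with `jax.ops.segment_sum`; it is a from-scratch illustration of the pattern, not the `MessagePassing` API itself.

```python
import jax
import jax.numpy as jnp

def propagate(x, edge_index, num_nodes):
    """One sum-aggregation message-passing step: message = x[src], summed per dst."""
    src, dst = edge_index            # edge_index has shape [2, num_edges]
    messages = x[src]                # gather source-node features, one row per edge
    return jax.ops.segment_sum(messages, dst, num_segments=num_nodes)

x = jnp.arange(12.0).reshape(4, 3)
edge_index = jnp.array([[0, 1, 2, 3], [1, 2, 3, 0]])
out = propagate(x, edge_index, num_nodes=4)   # [4, 3] aggregated neighbor features
```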
Graph Convolution Layers
- `GCNConv`: Graph Convolutional Network with spectral-based convolution and optional edge weights
- `GATConv`: Graph Attention Network with multi-head attention and learnable attention weights
- `GATv2Conv`: Improved Graph Attention Network with enhanced attention computation for better expressivity
- `SAGEConv`: GraphSAGE with multiple aggregation functions (mean, max, LSTM) for inductive learning
- `GINConv`: Graph Isomorphism Network with theoretical guarantees on graph representation power
- `EdgeConv`: Dynamic edge convolution for learning on point clouds and dynamic graph construction
- `DynamicEdgeConv`: EdgeConv variant with built-in k-nearest-neighbor graph construction
- `TransformerConv`: Graph Transformer layer with optimized query-key-value projections and positional encodings
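For intuition about the first layer in this list, the sketch below shows GCN-style propagation (symmetric degree normalization followed by neighbor summation) written directly in JAX, assuming edges are stored in both directions for an undirected graph. The actual `GCNConv` layer additionally applies a learned linear transform; its constructor arguments are not shown here.

```python
import jax
import jax.numpy as jnp

def gcn_propagate(x, edge_index, num_nodes):
    """GCN-style aggregation: sum over edges j -> i of x_j / sqrt(deg_i * deg_j)."""
    src, dst = edge_index
    deg = jax.ops.segment_sum(jnp.ones_like(src, dtype=x.dtype), dst,
                              num_segments=num_nodes)
    deg_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)
    norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]      # one normalization weight per edge
    messages = norm[:, None] * x[src]
    return jax.ops.segment_sum(messages, dst, num_segments=num_nodes)
```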
Pooling Operations
- Global pooling: `global_add_pool`, `global_mean_pool`, `global_max_pool`, `global_min_pool` for graph-level representations
- Advanced pooling: `global_softmax_pool`, `global_sort_pool` for differentiable and sorted aggregations
- Hierarchical pooling: `TopKPooling` and `SAGPooling` for coarsening graph structures with learnable node selection
- Batched operations: Optimized variants (`batched_global_*_pool`) for efficient parallel processing of graph batches
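Global pooling reduces node features to one vector per graph using the batch assignment vector. A minimal mean-pool sketch built on `jax.ops.segment_sum`, shown here only to illustrate the idea behind `global_mean_pool`:

```python
import jax
import jax.numpy as jnp

def mean_pool(x, batch, num_graphs):
    """Average node features within each graph of a batch."""
    totals = jax.ops.segment_sum(x, batch, num_segments=num_graphs)        # [G, F]
    counts = jax.ops.segment_sum(jnp.ones((x.shape[0], 1)), batch,
                                 num_segments=num_graphs)                  # [G, 1]
    return totals / jnp.maximum(counts, 1.0)

x = jnp.arange(10.0).reshape(5, 2)
batch = jnp.array([0, 0, 0, 1, 1])
graph_repr = mean_pool(x, batch, num_graphs=2)   # shape [2, 2]
```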
Utility Functions
- Scatter operations: A comprehensive set including `scatter_add`, `scatter_mean`, `scatter_max`, `scatter_min`, `scatter_std`, `scatter_logsumexp` for flexible aggregation
- Scatter softmax: `scatter_softmax`, `scatter_log_softmax`, `masked_scatter_softmax` for attention-like mechanisms
- Graph utilities: Degree computation (`degree`, `in_degree`, `out_degree`) and self-loop management (`add_self_loops`, `remove_self_loops`)
- Conversion functions: `to_dense_adj`, `to_edge_index`, `to_undirected` for switching between graph representations
- Graph preprocessing: `coalesce` for edge deduplication, `maybe_num_nodes` for automatic node-count inference
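To illustrate what a segment-wise softmax computes (the idea behind `scatter_softmax`, not its actual signature), the sketch below normalizes scores over entries sharing a segment id, such as edges grouped by destination node, with the usual max-subtraction for numerical stability.

```python
import jax
import jax.numpy as jnp

def segment_softmax(scores, segment_ids, num_segments):
    """Softmax over entries sharing the same segment id (e.g. edges per node)."""
    seg_max = jax.ops.segment_max(scores, segment_ids, num_segments=num_segments)
    shifted = scores - seg_max[segment_ids]      # subtract per-segment max for stability
    exp = jnp.exp(shifted)
    seg_sum = jax.ops.segment_sum(exp, segment_ids, num_segments=num_segments)
    return exp / seg_sum[segment_ids]

scores = jnp.array([0.5, 1.0, -0.3, 2.0])
dst = jnp.array([0, 0, 1, 1])
attn = segment_softmax(scores, dst, num_segments=2)   # sums to 1 within each segment
```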
Pre-built Models
- `GCN`, `GAT`, `GraphSAGE`, `GIN`: Complete model implementations with configurable depth, hidden dimensions, and activation functions
- `JumpingKnowledge`: Multi-layer aggregation with concatenation, max, and LSTM-based combination strategies
- `MLP`: Multi-layer perceptron with dropout, batch normalization, and flexible activation functions
- `BasicGNN`: Abstract base class for implementing custom GNN architectures with a standardized interface
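The jumping-knowledge idea, combining node representations from every GNN layer rather than only the last, can be sketched in a few lines. The concatenation and max variants below are written in plain JAX and only illustrate the combination strategies; they are not the `JumpingKnowledge` module itself.

```python
import jax.numpy as jnp

def jumping_knowledge_cat(layer_outputs):
    """Concatenate per-layer node embeddings along the feature axis."""
    return jnp.concatenate(layer_outputs, axis=-1)

def jumping_knowledge_max(layer_outputs):
    """Element-wise max over the stacked per-layer embeddings."""
    return jnp.max(jnp.stack(layer_outputs, axis=0), axis=0)

h1 = jnp.ones((5, 8)); h2 = jnp.zeros((5, 8)); h3 = 2.0 * jnp.ones((5, 8))
out_cat = jumping_knowledge_cat([h1, h2, h3])   # [5, 24]
out_max = jumping_knowledge_max([h1, h2, h3])   # [5, 8]
```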
Normalization Layers
- `BatchNorm`: Batch normalization with running statistics for stable training across graph batches
- `LayerNorm`: Layer normalization supporting both node-wise and graph-wise normalization schemes
- `GraphNorm`: Graph-specific normalization designed for graph neural network architectures
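Graph-wise normalization standardizes node features within each graph rather than across the whole mini-batch. The sketch below computes per-graph mean and variance with segment sums; the real normalization layers additionally carry learnable scale and shift parameters, which are omitted here.

```python
import jax
import jax.numpy as jnp

def graph_wise_normalize(x, batch, num_graphs, eps=1e-5):
    """Standardize node features within each graph of a batch (no learnable params)."""
    ones = jnp.ones((x.shape[0], 1))
    counts = jax.ops.segment_sum(ones, batch, num_segments=num_graphs)       # [G, 1]
    mean = jax.ops.segment_sum(x, batch, num_segments=num_graphs) / counts   # [G, F]
    centered = x - mean[batch]
    var = jax.ops.segment_sum(centered ** 2, batch, num_segments=num_graphs) / counts
    return centered / jnp.sqrt(var[batch] + eps)
```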
JAX Integration & Performance
- Extensive use of `jax.vmap` and `nnx.vmap` for efficient parallel processing of graph batches
- Memory-efficient training patterns using `jax.lax.scan` and `nnx.scan` for sequential operations
- JIT compilation support for all operations, built on optimized JAX primitives
- Efficient scatter operations using JAX's indexed-update syntax (`.at[].add`, `.at[].max`, `.at[].min`) for high-performance aggregation
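The indexed-update primitives mentioned above are what make scatter-style aggregation JIT-friendly. A small sketch using only core JAX (the function name is illustrative, not part of the library):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num_nodes")
def scatter_add_at(src, index, num_nodes):
    """Sum rows of `src` into a zero array at row positions given by `index`."""
    out = jnp.zeros((num_nodes, src.shape[1]), dtype=src.dtype)
    return out.at[index].add(src)   # duplicate indices accumulate instead of overwriting

src = jnp.ones((4, 3))
index = jnp.array([0, 2, 2, 1])
result = scatter_add_at(src, index, num_nodes=3)   # row 2 receives two contributions
```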