Changelog
=========

Version 0.0.3
-------------

Initial release of JraphX.

Features
~~~~~~~~

**Core Data Structures**

* ``Data`` class: Single-graph representation holding node features, edge indices, edge attributes, and graph-level properties
* ``Batch`` class: Batches multiple graphs into one disconnected graph, with index offsets managed automatically

**Message Passing Framework**

* Unified ``MessagePassing`` base class providing a standard interface for all graph neural network layers
* Customizable message computation, aggregation (sum, mean, max, min), and node update functions
* Support for message passing over node features alone or combined with edge features

**Graph Convolution Layers**

* ``GCNConv``: Graph Convolutional Network layer based on a first-order approximation of spectral graph convolution, with optional edge weights
* ``GATConv``: Graph Attention Network layer with multi-head attention and learnable attention weights
* ``GATv2Conv``: GATv2 layer, which computes dynamic attention and is more expressive than the static attention of ``GATConv``
* ``GraphSAGE`` (``SAGEConv``): GraphSAGE layer for inductive learning, with multiple aggregation functions (mean, max, LSTM)
* ``GINConv``: Graph Isomorphism Network layer, provably as expressive as the Weisfeiler-Lehman graph isomorphism test
* ``EdgeConv``: Dynamic edge convolution for learning on point clouds and dynamically constructed graphs
* ``DynamicEdgeConv``: ``EdgeConv`` variant that constructs a k-nearest-neighbor graph
* ``TransformerConv``: Graph Transformer layer with optimized query-key-value projections and positional encodings

**Pooling Operations**

* **Global pooling**: ``global_add_pool``, ``global_mean_pool``, ``global_max_pool``, and ``global_min_pool`` for graph-level representations
* **Advanced pooling**: ``global_softmax_pool`` and ``global_sort_pool`` for differentiable and sorted aggregation
* **Hierarchical pooling**: ``TopKPooling`` and ``SAGPooling`` for coarsening graphs via learnable node selection
* **Batched operations**: Optimized variants (``batched_global_*_pool``) for processing graph batches in parallel

**Utility Functions**

* **Scatter operations**: ``scatter_add``, ``scatter_mean``, ``scatter_max``, ``scatter_min``, ``scatter_std``, and ``scatter_logsumexp`` for flexible aggregation
* **Scatter softmax**: ``scatter_softmax``, ``scatter_log_softmax``, and ``masked_scatter_softmax`` for attention mechanisms
* **Graph utilities**: Degree computation (``degree``, ``in_degree``, ``out_degree``) and self-loop management (``add_self_loops``, ``remove_self_loops``)
* **Conversion functions**: ``to_dense_adj``, ``to_edge_index``, and ``to_undirected`` for converting between graph representations
* **Graph preprocessing**: ``coalesce`` for edge deduplication and ``maybe_num_nodes`` for inferring the node count

**Pre-built Models**

* ``GCN``, ``GAT``, ``GraphSAGE``, ``GIN``: Complete models with configurable depth, hidden dimensions, and activation functions
* ``JumpingKnowledge``: Combines representations from multiple layers via concatenation, max, or LSTM-based aggregation
* ``MLP``: Multi-layer perceptron with dropout, batch normalization, and configurable activation functions
* ``BasicGNN``: Abstract base class for implementing custom GNN architectures with a standard interface

**Normalization Layers**

* ``BatchNorm``: Batch normalization with running statistics for stable training across graph batches
* ``LayerNorm``: Layer normalization supporting both node-wise and graph-wise schemes
* ``GraphNorm``: Normalization designed specifically for graph neural networks

**JAX Integration & Performance**

* Extensive use of ``jax.vmap`` and ``nnx.vmap`` for parallel processing of graph batches
* Memory-efficient training patterns using ``jax.lax.scan`` and ``nnx.scan`` for sequential operations
* JIT compilation support for all operations, built on standard JAX primitives
* Scatter operations built on JAX's indexed updates (``.at[].add``, ``.at[].max``, ``.at[].min``) for fast aggregation
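
The scatter-based aggregation described above can be illustrated in plain JAX. The sketch below is not JraphX's implementation. The helpers ``scatter_sum``, ``scatter_max``, ``propagate_sum``, and ``global_mean_pool_sketch`` are hypothetical names used only for illustration. The sketch assumes floating-point node features and a ``batch`` vector that maps each node to its graph index, as in a disconnected-graph batch.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    # Hypothetical helpers for illustration only; not the JraphX API.
    def scatter_sum(src, index, num_segments):
        # Rows of ``src`` that share an index are summed; duplicate indices accumulate.
        out = jnp.zeros((num_segments,) + src.shape[1:], dtype=src.dtype)
        return out.at[index].add(src)


    def scatter_max(src, index, num_segments):
        # Start from -inf so that empty segments stay -inf (assumes float inputs).
        out = jnp.full((num_segments,) + src.shape[1:], -jnp.inf, dtype=src.dtype)
        return out.at[index].max(src)


    def propagate_sum(x, edge_index):
        # One message-passing step: each edge sends x[src] to dst, summed per node.
        src, dst = edge_index
        return scatter_sum(x[src], dst, x.shape[0])


    def global_mean_pool_sketch(x, batch, num_graphs):
        # Per-graph mean: sum node features per graph, then divide by node counts.
        sums = scatter_sum(x, batch, num_graphs)
        ones = jnp.ones((x.shape[0], 1), dtype=x.dtype)
        counts = scatter_sum(ones, batch, num_graphs)
        return sums / jnp.maximum(counts, 1.0)  # guard against empty graphs


    # num_graphs fixes the output shape, so it must be static under jit.
    pool = jax.jit(global_mean_pool_sketch, static_argnums=2)

    x = jnp.array([[1.0], [2.0], [3.0], [4.0]])     # 4 nodes, 1 feature each
    batch = jnp.array([0, 0, 1, 1])                 # nodes 0-1 in graph 0, 2-3 in graph 1
    edge_index = jnp.array([[0, 1, 2], [1, 2, 3]])  # edges 0->1, 1->2, 2->3

    print(pool(x, batch, 2))             # per-graph means 1.5 and 3.5
    print(scatter_max(x, batch, 2))      # per-graph maxima 2.0 and 4.0
    print(propagate_sum(x, edge_index))  # node 0 has no incoming edge, so it gets 0.0

Because the number of segments determines the output shape, it is passed as a static argument so that the pooled result has a fixed shape under ``jax.jit``.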