Missing Features in JraphX

This document tracks PyTorch Geometric features that are not yet implemented in JraphX.

High Priority (Core GNN functionality)

Convolution Layers

  • AGNNConv - Attention-based Graph Neural Network

  • APPNP - Approximate Personalized Propagation of Neural Predictions

  • ARMAConv - ARMA filters on graphs

  • CGConv - Crystal Graph Convolutional Networks

  • ChebConv - Chebyshev spectral graph convolution

  • ClusterGCNConv - Cluster-GCN

  • DNAConv - Dynamic Neighborhood Aggregation

  • FastRGCNConv - Fast Relational Graph Convolutional Networks

  • FeaStConv - Feature-Steered Convolution

  • FiLMConv - Feature-wise Linear Modulation

  • GCN2Conv - Simple and Deep Graph Convolutional Networks

  • GENConv - Generalized Graph Convolutional Networks

  • GeneralConv - General GNN layer

  • GPSConv - General, Powerful, Scalable (GPS) graph transformer layer

  • GravNetConv - GravNet layer for point clouds

  • HGTConv - Heterogeneous Graph Transformer

  • HypergraphConv - Hypergraph Convolution

  • LEConv - Local Extremum Graph Neural Networks

  • LGConv - Light Graph Convolution

  • MFConv - Molecular Fingerprint Convolution

  • NNConv - Continuous kernel-based convolution

  • PANConv - Path integral based convolution

  • PDNConv - Pathfinder Discovery Networks

  • PNAConv - Principal Neighbourhood Aggregation

  • PointConv - PointNet++ set-abstraction convolution for 3D point clouds (alias of PointNetConv)

  • PPFConv - Point Pair Feature Convolution

  • RGCNConv - Relational Graph Convolutional Networks

  • RGATConv - Relational Graph Attention Networks

  • ResGatedGraphConv - Residual Gated Graph ConvNets

  • SGConv - Simplifying Graph Convolutional Networks

  • SignedConv - Signed Graph Convolutional Networks

  • SplineConv - Spline-based convolution

  • SuperGATConv - SuperGAT

  • TAGConv - Topology Adaptive Graph Convolutional Networks

  • TWirls - Trainable Wishart Relational Networks

  • XConv - X-transformed convolution from PointCNN

Aggregation Functions

  • Aggregation Module - Advanced aggregation functions

  • MultiAggregation - Multiple aggregation combination

  • AttentionalAggregation - Attention-based aggregation

  • DeepSetsAggregation - DeepSets aggregation

  • DegreeScalerAggregation - Degree-based scaling

  • EquilibriumAggregation - Equilibrium-based aggregation

  • GraphMultisetTransformer - Graph Multiset Transformer

  • LSTMAggregation - LSTM-based aggregation

  • MLPAggregation - MLP aggregation

  • PowerMeanAggregation - Power mean aggregation

  • Set2Set - Set2Set aggregation

  • SoftmaxAggregation - Softmax aggregation

  • SortAggregation - Sort aggregation

  • VarAggregation - Variance aggregation
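
As a rough illustration of how two of the simpler entries above could be expressed in JAX, the sketch below builds variance and softmax aggregation on top of jax.ops.segment_sum and jax.ops.segment_max. The function names, the fixed temperature t, and the (x, segment_ids, num_segments) signature are assumptions made for this example, not JraphX API, and the sketch assumes every segment has at least one element.

```python
import jax
import jax.numpy as jnp


def var_aggregation(x, segment_ids, num_segments):
    """Per-segment variance computed as E[x^2] - E[x]^2 (hypothetical helper)."""
    ones = jnp.ones((x.shape[0], 1), dtype=x.dtype)
    count = jax.ops.segment_sum(ones, segment_ids, num_segments)
    count = jnp.maximum(count, 1.0)  # guard against division by zero
    mean = jax.ops.segment_sum(x, segment_ids, num_segments) / count
    mean_sq = jax.ops.segment_sum(x * x, segment_ids, num_segments) / count
    # Clamp tiny negative values caused by floating-point cancellation.
    return jnp.maximum(mean_sq - mean**2, 0.0)


def softmax_aggregation(x, segment_ids, num_segments, t=1.0):
    """Softmax-weighted sum per segment with a fixed temperature t (hypothetical helper)."""
    logits = t * x
    seg_max = jax.ops.segment_max(logits, segment_ids, num_segments)
    # Subtract the per-segment maximum for numerical stability.
    exp = jnp.exp(logits - seg_max[segment_ids])
    denom = jax.ops.segment_sum(exp, segment_ids, num_segments)
    alpha = exp / denom[segment_ids]
    return jax.ops.segment_sum(alpha * x, segment_ids, num_segments)


# Four elements split into two segments: {1, 3} and {2, 2}.
x = jnp.array([[1.0], [3.0], [2.0], [2.0]])
segment_ids = jnp.array([0, 0, 1, 1])
var_out = var_aggregation(x, segment_ids, num_segments=2)
soft_out = softmax_aggregation(x, segment_ids, num_segments=2)
```

Passing num_segments explicitly keeps output shapes static, which is what allows functions like these to be wrapped in jax.jit.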

Medium Priority (Advanced Features)

Pooling Layers

  • ASAPooling - Adaptive Structure Aware Pooling

  • EdgePooling - Edge-based pooling

  • GCNPool - GCN-based pooling

  • GlobalAttention - Global attention pooling

  • GraphSAINTSampler - GraphSAINT sampling

  • HitAndRun - Hit and Run sampling

  • MaxPooling - Max pooling on graphs

  • MemPooling - Memory-based pooling

  • NodeSAINTSampler - Node-based GraphSAINT sampling

  • PANPooling - Path integral based pooling

Pre-built Models

  • AttentiveFP - Attentive Fingerprinting

  • BASIC_GNN - Enhanced basic model variations

  • DeepGCN - Deep Graph Convolutional Networks

  • DeepGraphInfomax - Deep Graph Infomax

  • DiffPool - Differentiable Pooling

  • GAE - Graph Autoencoders

  • VGAE - Variational Graph Autoencoders

  • GCN - Enhanced versions

  • GraphSAGE - Enhanced versions

  • GraphUNet - Graph U-Net

  • JK-Net - Enhanced Jumping Knowledge Networks

  • MetaPath2Vec - MetaPath2Vec for heterogeneous graphs

  • Node2Vec - Node2Vec embeddings

  • PNA - Principal Neighbourhood Aggregation networks

  • SchNet - SchNet for molecular property prediction

  • TGN - Temporal Graph Networks

Normalization Layers

  • DiffGroupNorm - Differentiable Group Normalization

  • InstanceNorm - Instance Normalization

  • MessageNorm - Message Normalization

  • PairNorm - Pair Normalization

JAX/JraphX Specific Limitations

k-NN Graph Construction

  • torch-cluster integration - PyTorch Geometric’s DynamicEdgeConv uses torch_cluster.knn() for automatic k-nearest neighbor computation from node features. JraphX’s DynamicEdgeConv is a simplified version that requires pre-computed k-NN indices as input.

  • Dynamic graph construction - Full dynamic graph construction would require a JAX-native k-NN implementation, which is not currently available.
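
Until a JAX-native k-NN is available, the indices can be pre-computed with a brute-force search. The sketch below is one possible way to do that with pairwise distances and jax.lax.top_k. The name knn_indices and the [num_nodes, k] output layout are assumptions for illustration, and the layout that JraphX's DynamicEdgeConv expects should be checked against its documentation.

```python
import jax
import jax.numpy as jnp


def knn_indices(x, k):
    """Brute-force k-NN over the rows of x; returns [num_nodes, k] neighbor indices."""
    # Pairwise squared Euclidean distances, shape [N, N].
    sq_norm = jnp.sum(x * x, axis=-1)
    dist = sq_norm[:, None] + sq_norm[None, :] - 2.0 * (x @ x.T)
    # Exclude each node from its own neighbor list.
    dist = jnp.where(jnp.eye(x.shape[0], dtype=bool), jnp.inf, dist)
    # top_k selects the largest values, so negate to get the smallest distances.
    _, idx = jax.lax.top_k(-dist, k)
    return idx


# k determines the output shape, so it has to be a static argument under jit.
knn_jit = jax.jit(knn_indices, static_argnames="k")

points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 5.0]])
neighbors = knn_jit(points, k=2)
```

This approach materializes an N x N distance matrix, so memory grows quadratically with the number of nodes. It is a reasonable fit for small point clouds but not a replacement for the tree- or grid-based search in torch_cluster.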

Lower Priority (Specialized Features)

Knowledge Graph Embeddings

  • ComplEx - Complex embeddings

  • DistMult - DistMult embeddings

  • HolE - Holographic embeddings

  • KGEModel - Base class for KG embeddings

  • PairRE - Paired relation embeddings

  • RotatE - Rotation-based embeddings

  • TransE - Translation embeddings

Dense Layers

  • DenseGCNConv - Dense GCN convolution

  • DenseGINConv - Dense GIN convolution

  • DenseGraphConv - Dense graph convolution

  • DenseSAGEConv - Dense SAGE convolution

  • LinearTransformation - Dense linear layers

Functional Operations

  • dropout - Graph-aware dropout

  • gumbel_softmax - Gumbel softmax for graphs

  • local_graph_clustering - Local clustering

  • pagerank - PageRank algorithm

  • subgraph - Subgraph sampling

Transforms (Not Core but Useful)

  • AddSelfLoops - Add self-loops transform

  • Compose - Transform composition

  • NormalizeFeatures - Feature normalization

  • RandomNodeSplit - Random node splitting

  • RemoveIsolatedNodes - Remove isolated nodes

  • ToDevice - Device placement transform

  • ToSparseTensor - Sparse tensor conversion
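
Several of these transforms reduce to small pure functions in JAX. The sketch below shows one way AddSelfLoops, NormalizeFeatures, and Compose might look as functional preprocessing. The dict-based graph container, the function names, and the plain row-sum normalization (which assumes non-negative features) are all assumptions for this example rather than JraphX or PyTorch Geometric API.

```python
import jax.numpy as jnp


def add_self_loops(graph):
    """Append one (i, i) edge per node to a [2, num_edges] edge index."""
    num_nodes = graph["x"].shape[0]
    loops = jnp.arange(num_nodes, dtype=graph["edge_index"].dtype)
    edge_index = jnp.concatenate(
        [graph["edge_index"], jnp.stack([loops, loops])], axis=1
    )
    return {**graph, "edge_index": edge_index}


def normalize_features(graph, eps=1e-12):
    """Scale each feature row to sum to one (assumes non-negative features)."""
    x = graph["x"]
    row_sum = jnp.sum(x, axis=-1, keepdims=True)
    return {**graph, "x": x / jnp.maximum(row_sum, eps)}


def compose(*fns):
    """Chain transforms left to right, returning a single callable."""
    def apply(graph):
        for fn in fns:
            graph = fn(graph)
        return graph
    return apply


# Hypothetical graph container: a plain dict of arrays.
graph = {
    "x": jnp.array([[1.0, 3.0], [2.0, 2.0]]),
    "edge_index": jnp.array([[0], [1]]),
}
transform = compose(add_self_loops, normalize_features)
out = transform(graph)
```

Because each function returns a new dict instead of mutating its input, the original graph is left unchanged.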

Data Loading & Sampling

  • DataLoader - Graph data loading

  • NeighborSampler - Neighborhood sampling

  • RandomWalkSampler - Random walk sampling

  • ShaDowKHopSampler - ShaDow k-hop sampling

Datasets (Not Applicable - JAX doesn’t need this)

❌ All dataset classes (TUDataset, Planetoid, etc.) - Not relevant for a JAX-only library; datasets are loaded with external tools

Distributed Training (Future Consideration)

  • DistributedSampler - For future JAX distributed training

  • GraphSAINT - Distributed sampling strategies

Features Deliberately Omitted

PyTorch-Specific

DataParallel - JAX uses different parallelization

torch.compile integration - JAX uses jit instead

SparseTensor support - JAX has different sparse support

CUDA-specific operations - Custom CUDA kernels would be a heavy lift to port to JAX

Framework-Specific

Heterogeneous graphs - Complex feature, may not fit JAX patterns

Explainability modules - Separate concern

NLP modules - Out of scope

Remote backend - PyG-specific

Implementation Status Legend

  • Implemented - Available in JraphX

  • Planned - Should be implemented

  • Omitted - Deliberately not implementing

Removed Documentation Files

The following PyTorch Geometric documentation files have been removed from JraphX as they are not applicable to a JAX-based GNN library:

Advanced Concepts (Removed)

  • cpu_affinity.rst - PyTorch-specific CPU affinity settings

  • graphgym.rst - GraphGym framework (PyTorch ecosystem)

  • hgam.rst - Hierarchical neighborhood sampling (not implemented)

  • remote.rst - Remote backend for PyTorch Geometric

  • sparse_tensor.rst - PyTorch sparse tensor integration

Module Documentation (Removed)

  • contrib.rst - Community contributions (PyTorch-specific)

  • datasets.rst - Dataset loading (JraphX uses external datasets)

  • distributed.rst - Distributed training (PyTorch-specific)

  • explain.rst - Model explainability (separate concern)

  • graphgym.rst - GraphGym configuration system

  • loader.rst - Data loading utilities (not needed for JAX)

  • metrics.rst - Evaluation metrics (use external libraries)

  • profile.rst - Performance profiling (JAX has its own tools)

  • sampler.rst - Graph sampling utilities (not implemented)

  • transforms.rst - Data transforms (JAX uses functional preprocessing)

Tutorial Documentation (Removed)

  • application.rst - Application-specific tutorials

  • compile.rst - torch.compile integration (JAX uses jit)

  • create_dataset.rst - Dataset creation (not JraphX’s scope)

  • dataset_splitting.rst - Dataset splitting utilities

  • dataset.rst - Dataset handling

  • distributed_pyg.rst - Distributed PyTorch Geometric

  • distributed.rst - Distributed training

  • explain.rst - Model explainability

  • graph_transformer.rst - Advanced transformer architectures (not implemented)

  • heterogeneous.rst - Heterogeneous graph processing (not implemented)

  • load_csv.rst - CSV loading utilities

  • multi_gpu_vanilla.rst - Multi-GPU training (PyTorch-specific)

  • multi_node_multi_gpu_vanilla.rst - Multi-node training (PyTorch-specific)

  • neighbor_loader.rst - Neighborhood sampling (not implemented)

  • point_cloud.rst - Point cloud processing (limited support)

  • shallow_node_embeddings.rst - Node embedding methods (not implemented)

Rationale for Removal

  • Framework Mismatch: PyTorch-specific features that don’t apply to JAX

  • Scope Limitation: JraphX focuses on core GNN layers, not entire ML pipelines

  • Unimplemented Features: Advanced features not yet available in JraphX

  • External Dependencies: Features that rely on PyTorch ecosystem

Removed Figure Files

The following figure files from docs/source/_figures/ have been removed as they were not referenced in the JraphX documentation:

  • architecture.pdf / architecture.svg - PyTorch Geometric architecture diagrams

  • dist_part.png / dist_proc.png / dist_sampling.png - Distributed training figures (PyTorch-specific)

  • graphgps_layer.png - GraphGPS layer architecture (not implemented)

  • graphgym_design_space.png / graphgym_evaluation.png / graphgym_results.png - GraphGym framework figures

  • hg_example.svg / hg_example.tex - Heterogeneous graph examples (not implemented)

  • intel_kumo.png - Intel optimization figures (not applicable)

  • meshcnn_edge_adjacency.svg - MeshCNN figures (not implemented)

  • point_cloud1.png through point_cloud4.png - Point cloud examples (limited support)

  • remote_1.png through remote_3.png - Remote backend figures (not applicable)

  • shallow_node_embeddings.png - Node embedding figures (not implemented)

  • to_hetero.svg / to_hetero.tex / to_hetero_with_bases.svg / to_hetero_with_bases.tex - Heterogeneous graph conversion (not implemented)

  • training_affinity.png - CPU affinity training (PyTorch-specific)

Kept Figure Files

Only the essential figures were retained:

  • graph.svg / graph.tex - Basic graph visualization used in introduction tutorial

  • build.sh - Figure generation script

Kept Documentation Files

The following files were retained and translated to JraphX:

  • Core tutorials: create_gnn.rst, gnn_design.rst (JAX integration)

  • Essential concepts: batching.rst, jit.rst, compile.rst

  • API reference: nn.rst, data.rst, utils.rst, root.rst

  • Cheatsheets: gnn_cheatsheet.rst, data_cheatsheet.rst

  • Getting started: introduction.rst, installation.rst

Notes

  • Priority is based on common usage patterns and core GNN functionality

  • JAX-specific optimizations should be added where applicable (jit, vmap, scan)

  • Some features may need significant adaptation for JAX/NNX paradigms

  • Documentation cleanup focused on maintaining only relevant, translated content
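
To make the note about jit, vmap, and scan more concrete, the sketch below combines all three around a plain sum-aggregation step. It is only an illustration under stated assumptions: propagate, propagate_k, and batched_propagate are hypothetical names, the edge index is assumed to be a [2, num_edges] array of (source, target) rows, and none of this reflects how JraphX layers are actually implemented.

```python
from functools import partial

import jax
import jax.numpy as jnp


def propagate(x, edge_index, num_nodes):
    """One round of sum aggregation: each target node sums its source features."""
    src, dst = edge_index
    return jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)


def propagate_k(x, edge_index, num_steps):
    """Apply `propagate` num_steps times with lax.scan instead of a Python loop."""
    num_nodes = x.shape[0]

    def step(h, _):
        return propagate(h, edge_index, num_nodes), None

    out, _ = jax.lax.scan(step, x, None, length=num_steps)
    return out


@partial(jax.jit, static_argnames="num_steps")
def batched_propagate(x_batch, edge_index, num_steps):
    """vmap over a leading batch axis of features that share one edge index."""
    return jax.vmap(lambda x: propagate_k(x, edge_index, num_steps))(x_batch)


edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # directed cycle 0 -> 1 -> 2 -> 0
x = jnp.arange(6.0).reshape(2, 3, 1)            # 2 feature sets, 3 nodes, 1 channel
one_step = batched_propagate(x, edge_index, num_steps=1)
three_steps = batched_propagate(x, edge_index, num_steps=3)
```

On a directed 3-cycle, one step shifts features around the ring and three steps return them to their starting nodes, which gives a quick sanity check for the scan loop.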