Missing Features in JraphX
This document tracks PyTorch Geometric features that are not yet implemented in JraphX.
High Priority (Core GNN Functionality)
Convolution Layers
AGNNConv - Attention-based Graph Neural Network
APPNP - Approximate Personalized Propagation of Neural Predictions
ARMAConv - ARMA filters on graphs
CGConv - Crystal Graph Convolutional Networks
ChebConv - Chebyshev spectral graph convolution
ClusterGCNConv - Cluster-GCN
DNAConv - Dynamic Neighborhood Aggregation
FastRGCNConv - Fast Relational Graph Convolutional Networks
FeaStConv - Feature-Steered Convolution
FiLMConv - Feature-wise Linear Modulation
GCN2Conv - Simple and Deep Graph Convolutional Networks
GENConv - Generalized aggregation convolution (DeeperGCN)
GeneralConv - General GNN layer
GPSConv - General, Powerful, Scalable (GPS) graph transformer layer
GravNetConv - GravNet layer for point clouds
HGTConv - Heterogeneous Graph Transformer
HypergraphConv - Hypergraph Convolution
LEConv - Local Extremum Graph Neural Networks
LGConv - Light Graph Convolution
MFConv - Molecular Fingerprint Convolution
NNConv - Continuous kernel-based convolution
PANConv - Path integral based convolution (PAN)
PDNConv - Pathfinder Discovery Networks
PNAConv - Principal Neighbourhood Aggregation
PointNetConv (formerly PointConv) - PointNet/PointNet++ convolution for point clouds
PPFConv - Point Pair Feature Convolution
RGCNConv - Relational Graph Convolutional Networks
RGATConv - Relational Graph Attention Networks
ResGatedGraphConv - Residual Gated Graph ConvNets
SGConv - Simplifying Graph Convolutional Networks (see the sketch after this list)
SignedConv - Signed Graph Convolutional Networks
SplineConv - Spline-based convolution
SuperGATConv - Self-supervised graph attention (SuperGAT)
TAGConv - Topology Adaptive Graph Convolutional Networks
TWIRLS - GNN layers derived from iteratively reweighted least squares
XConv - PointNet++ XConv layer
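Most of these layers reduce to message passing over an edge list, which maps naturally onto JAX segment operations. As an illustration, here is a minimal sketch of SGConv's propagation rule, X' = (D^{-1/2}(A+I)D^{-1/2})^K X W. The function name and signature are illustrative only, not JraphX API:

```python
import jax.numpy as jnp
from jax.ops import segment_sum

def sg_conv(x, edge_index, weight, k=2):
    """Minimal SGConv sketch: K steps of symmetric-normalized propagation,
    followed by a single linear map (no nonlinearities in between)."""
    src, dst = edge_index  # [num_edges] each; assumed to already include self-loops
    num_nodes = x.shape[0]
    # Per-edge symmetric normalization coefficients of D^{-1/2} A D^{-1/2}.
    deg = segment_sum(jnp.ones_like(src, dtype=x.dtype), dst, num_segments=num_nodes)
    deg_inv_sqrt = jnp.where(deg > 0, deg ** -0.5, 0.0)
    norm = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
    for _ in range(k):  # propagate K hops
        x = segment_sum(norm[:, None] * x[src], dst, num_segments=num_nodes)
    return x @ weight
```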
Aggregation Functions
Aggregation - Base class for aggregation operators
MultiAggregation - Multiple aggregation combination
AttentionalAggregation - Attention-based aggregation
DeepSetsAggregation - DeepSets aggregation
DegreeScalerAggregation - Degree-based scaling
EquilibriumAggregation - Equilibrium-based aggregation
GraphMultisetTransformer - Graph Multiset Transformer
LSTMAggregation - LSTM-based aggregation
MLPAggregation - MLP aggregation
PowerMeanAggregation - Power mean aggregation
Set2Set - Set2Set aggregation
SoftmaxAggregation - Softmax aggregation (see the sketch after this list)
SortAggregation - Sort aggregation
VarAggregation - Variance aggregation
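Many of these aggregators are short segment-reduction recipes in JAX. Here is a hedged sketch of SoftmaxAggregation (referenced above), assuming messages are grouped by a destination-node index vector; in PyG the temperature t is learnable:

```python
import jax.numpy as jnp
from jax.ops import segment_max, segment_sum

def softmax_aggregate(msgs, index, num_nodes, t=1.0):
    """Softmax aggregation sketch: per-destination softmax weights over messages.

    msgs:  [num_edges, dim] incoming messages
    index: [num_edges] destination node of each message
    """
    logits = t * msgs
    # Subtract the per-segment max for numerical stability.
    seg_max = segment_max(logits, index, num_segments=num_nodes)
    logits = logits - seg_max[index]
    exp = jnp.exp(logits)
    denom = segment_sum(exp, index, num_segments=num_nodes)
    alpha = exp / (denom[index] + 1e-16)
    return segment_sum(alpha * msgs, index, num_segments=num_nodes)
```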
Medium Priority (Advanced Features)
Pooling Layers
ASAPooling - Adaptive Structure Aware Pooling
EdgePooling - Edge-based pooling
GCNPool - GCN-based pooling
GlobalAttention - Global attention pooling (superseded by AttentionalAggregation in PyG)
MaxPooling - Max pooling on graphs (see the global max-pool sketch after this list)
MemPooling - Memory-based pooling
PANPooling - Path integral based pooling (PAN)
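Graph-level max pooling (the MaxPooling entry above) is a one-line segment reduction in JAX. A minimal sketch, assuming a batch vector that maps each node to its graph id:

```python
from jax.ops import segment_max

def global_max_pool(x, batch, num_graphs):
    """Feature-wise max over the nodes of each graph in a batched disjoint union.

    x:     [num_nodes, dim] node features
    batch: [num_nodes] graph id of each node
    """
    return segment_max(x, batch, num_segments=num_graphs)
```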
Pre-built Models
AttentiveFP - Attentive Fingerprinting
BasicGNN - Base class behind the pre-built model variants
DeepGCN - Deep Graph Convolutional Networks
DeepGraphInfomax - Deep Graph Infomax
DiffPool - Differentiable Pooling
GAE - Graph Autoencoders (inner-product decoder sketched after this list)
VGAE - Variational Graph Autoencoders
GCN - Enhanced version of the pre-built multi-layer GCN model
GraphSAGE - Enhanced version of the pre-built multi-layer GraphSAGE model
GraphUNet - Graph U-Net
JK-Net - Enhanced Jumping Knowledge Networks
MetaPath2Vec - MetaPath2Vec for heterogeneous graphs
Node2Vec - Node2Vec embeddings
PNA - Principal Neighbourhood Aggregation networks
SchNet - SchNet for molecular property prediction
TGN - Temporal Graph Networks
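Some of these models are small compositions of pieces a GNN library already provides. For example, GAE's inner-product decoder (noted above) is a single line over the latent node embeddings; a hedged sketch:

```python
import jax

def gae_decode(z):
    """GAE sketch: reconstruct edge probabilities as sigmoid(Z Z^T),
    where z is the [num_nodes, latent_dim] encoder output."""
    return jax.nn.sigmoid(z @ z.T)
```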
Normalization Layers
DiffGroupNorm - Differentiable Group Normalization
InstanceNorm - Instance Normalization
MessageNorm - Message Normalization
PairNorm - Pair Normalization
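For reference, PairNorm is only a few lines in JAX: center the node features, then rescale them to a constant average row norm. A minimal sketch (parameter names are illustrative):

```python
import jax.numpy as jnp

def pair_norm(x, scale=1.0, eps=1e-5):
    """PairNorm sketch: subtract the feature mean across nodes, then divide
    by the root-mean-square row norm so total pairwise distance stays fixed."""
    x = x - x.mean(axis=0, keepdims=True)
    rms = jnp.sqrt(jnp.mean(jnp.sum(x ** 2, axis=-1)) + eps)
    return scale * x / rms
```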
JAX/JraphX Specific Limitations
k-NN Graph Construction
torch-cluster integration - PyTorch Geometric’s DynamicEdgeConv uses torch_cluster.knn() for automatic k-nearest neighbor computation from node features. JraphX’s DynamicEdgeConv is a simplified version that requires pre-computed k-NN indices as input.
Dynamic graph construction - Full dynamic graph construction would require a JAX-native k-NN implementation, which is not currently available (a brute-force version is sketched below).
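A JAX-native brute-force k-NN is feasible for small point sets, which would be enough to feed DynamicEdgeConv's pre-computed indices. A sketch (O(N²) memory, illustrative only, not JraphX API):

```python
import jax
import jax.numpy as jnp

def knn_graph(pos, k):
    """Brute-force k-NN sketch: O(N^2) pairwise distances plus top-k.

    pos: [num_nodes, dim] point coordinates/features
    Returns edge_index of shape [2, num_nodes * k] with edges src -> dst,
    where dst is the query node.
    """
    num_nodes = pos.shape[0]
    # Squared Euclidean distance between every pair of points.
    d2 = jnp.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1)
    # Push the diagonal to +inf so a point is never its own neighbor.
    d2 = d2 + jnp.diag(jnp.full(num_nodes, jnp.inf))
    # top_k of the negated distances selects the k nearest neighbors.
    _, src = jax.lax.top_k(-d2, k)              # [num_nodes, k]
    dst = jnp.repeat(jnp.arange(num_nodes), k)  # query node for each neighbor
    return jnp.stack([src.reshape(-1), dst])
```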
Lower Priority (Specialized Features)
Knowledge Graph Embeddings
ComplEx - Complex embeddings
DistMult - DistMult embeddings
HolE - Holographic embeddings
KGEModel - Base class for KG embeddings
PairRE - Paired relation embeddings
RotatE - Rotation-based embeddings
TransE - Translation embeddings
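These embedding models share a simple triple-scoring pattern. TransE, for instance, scores a triple (h, r, t) as the negative distance between h + r and t; a minimal sketch:

```python
import jax.numpy as jnp

def transe_score(head, rel, tail, p=1):
    """TransE sketch: plausibility of (h, r, t) is -||h + r - t||_p.

    head, rel, tail: [batch, dim] embeddings looked up for each triple
    """
    return -jnp.linalg.norm(head + rel - tail, ord=p, axis=-1)
```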
Dense Layers
DenseGCNConv - Dense GCN convolution (see the sketch after this list)
DenseGINConv - Dense GIN convolution
DenseGraphConv - Dense graph convolution
DenseSAGEConv - Dense SAGE convolution
Linear - Dense linear layers (PyG's torch_geometric.nn.dense.Linear)
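The dense variants trade the edge-list representation for full adjacency matrices, which suits fixed-size batched graphs under jit/vmap. A hedged sketch of DenseGCNConv's rule X' = D^{-1/2}(A+I)D^{-1/2} X W (illustrative, not JraphX API):

```python
import jax.numpy as jnp

def dense_gcn_conv(x, adj, weight):
    """Dense GCNConv sketch on a full adjacency matrix.

    x:   [num_nodes, in_dim] node features
    adj: [num_nodes, num_nodes] dense adjacency matrix
    """
    a_hat = adj + jnp.eye(adj.shape[0])          # add self-loops
    deg_inv_sqrt = jnp.sum(a_hat, axis=-1) ** -0.5
    a_norm = deg_inv_sqrt[:, None] * a_hat * deg_inv_sqrt[None, :]
    return a_norm @ x @ weight
```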
Functional Operations
dropout_node / dropout_edge - Graph-aware dropout
gumbel_softmax - Gumbel softmax for graphs
local_graph_clustering - Local clustering
pagerank - PageRank algorithm (see the sketch after this list)
subgraph - Induced subgraph extraction
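PageRank (listed above) illustrates how such functional ops would look in JAX: a power iteration expressed with lax.scan. A dense-adjacency sketch that clamps dangling-node degrees for simplicity:

```python
import jax
import jax.numpy as jnp

def pagerank(adj, alpha=0.85, num_iters=50):
    """PageRank sketch: power iteration on a dense adjacency matrix."""
    n = adj.shape[0]
    out_deg = jnp.maximum(adj.sum(axis=1), 1.0)  # clamp dangling nodes
    p = adj / out_deg[:, None]                   # row-stochastic transitions

    def step(r, _):
        r = alpha * (p.T @ r) + (1.0 - alpha) / n
        return r, None

    r, _ = jax.lax.scan(step, jnp.full(n, 1.0 / n), None, length=num_iters)
    return r
```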
Transforms (Not Core but Useful)
AddSelfLoops - Add self-loops transform
Compose - Transform composition (see the sketch after this list)
NormalizeFeatures - Feature normalization
RandomNodeSplit - Random node splitting
RemoveIsolatedNodes - Remove isolated nodes
ToDevice - Device placement transform
ToSparseTensor - Sparse tensor conversion
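In a JAX setting these transforms are just pure functions, so Compose (flagged above) reduces to ordinary function chaining. A minimal sketch with NormalizeFeatures, assuming features arrive as a plain array:

```python
import jax.numpy as jnp

def normalize_features(x):
    """NormalizeFeatures sketch: scale each node's features to sum to one."""
    s = x.sum(axis=-1, keepdims=True)
    return x / jnp.where(s == 0, 1.0, s)

def compose(*transforms):
    """Compose sketch: apply transforms left to right."""
    def apply(x):
        for t in transforms:
            x = t(x)
        return x
    return apply

transform = compose(normalize_features)  # extend with more transforms as needed
```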
Data Loading & Sampling
DataLoader - Graph data loading
GraphSAINTSampler - GraphSAINT subgraph sampling
GraphSAINTNodeSampler - Node-based GraphSAINT sampling
HitAndRun - Hit-and-run sampling
NeighborSampler - Neighborhood sampling
RandomWalkSampler - Random walk sampling
ShaDowKHopSampler - ShaDow k-hop sampling
Datasets (Not Applicable - outside JraphX's scope)
❌ All dataset classes (TUDataset, Planetoid, etc.) - Not relevant for JAX-only library
Distributed Training (Future Consideration)
DistributedSampler - For future JAX distributed training
GraphSAINT - Distributed sampling strategies
Features Deliberately Omitted
PyTorch-Specific
❌ DataParallel - JAX parallelizes differently (pmap and array sharding)
❌ torch.compile integration - JAX uses jax.jit instead
❌ SparseTensor support - JAX has its own (experimental) sparse support
❌ CUDA-specific operations - JAX lowers to devices through XLA, so hand-written CUDA kernels are out of scope
Framework-Specific
❌ Heterogeneous graphs - Complex feature that may not fit JAX's static-shape patterns
❌ Explainability modules - Separate concern
❌ NLP modules - Out of scope
❌ Remote backend - PyG-specific
Implementation Status Legend
✅ Implemented - Available in JraphX
Planned - Should be implemented
❌ Omitted - Deliberately not implementing
Removed Documentation Files
The following PyTorch Geometric documentation files have been removed from JraphX as they are not applicable to a JAX-based GNN library:
Advanced Concepts (Removed)
cpu_affinity.rst - PyTorch-specific CPU affinity settings
graphgym.rst - GraphGym framework (PyTorch ecosystem)
hgam.rst - Heterogeneous Graph Attention Memory (not implemented)
remote.rst - Remote backend for PyTorch Geometric
sparse_tensor.rst - PyTorch sparse tensor integration
Module Documentation (Removed)
contrib.rst - Community contributions (PyTorch-specific)
datasets.rst - Dataset loading (JraphX uses external datasets)
distributed.rst - Distributed training (PyTorch-specific)
explain.rst - Model explainability (separate concern)
graphgym.rst - GraphGym configuration system
loader.rst - Data loading utilities (not needed for JAX)
metrics.rst - Evaluation metrics (use external libraries)
profile.rst - Performance profiling (JAX has its own tools)
sampler.rst - Graph sampling utilities (not implemented)
transforms.rst - Data transforms (JAX uses functional preprocessing)
Tutorial Documentation (Removed)
application.rst - Application-specific tutorials
compile.rst - torch.compile integration (JAX uses jit)
create_dataset.rst - Dataset creation (not JraphX’s scope)
dataset_splitting.rst - Dataset splitting utilities
dataset.rst - Dataset handling
distributed_pyg.rst - Distributed PyTorch Geometric
distributed.rst - Distributed training
explain.rst - Model explainability
graph_transformer.rst - Advanced transformer architectures (not implemented)
heterogeneous.rst - Heterogeneous graph processing (not implemented)
load_csv.rst - CSV loading utilities
multi_gpu_vanilla.rst - Multi-GPU training (PyTorch-specific)
multi_node_multi_gpu_vanilla.rst - Multi-node training (PyTorch-specific)
neighbor_loader.rst - Neighborhood sampling (not implemented)
point_cloud.rst - Point cloud processing (limited support)
shallow_node_embeddings.rst - Node embedding methods (not implemented)
Rationale for Removal
Framework Mismatch: PyTorch-specific features that don’t apply to JAX
Scope Limitation: JraphX focuses on core GNN layers, not entire ML pipelines
Unimplemented Features: Advanced features not yet available in JraphX
External Dependencies: Features that rely on PyTorch ecosystem
Removed Figure Files
The following figure files from docs/source/_figures/ have been removed as they were not referenced in the JraphX documentation:
architecture.pdf / architecture.svg - PyTorch Geometric architecture diagrams
dist_part.png / dist_proc.png / dist_sampling.png - Distributed training figures (PyTorch-specific)
graphgps_layer.png - GraphGPS layer architecture (not implemented)
graphgym_design_space.png / graphgym_evaluation.png / graphgym_results.png - GraphGym framework figures
hg_example.svg / hg_example.tex - Heterogeneous graph examples (not implemented)
intel_kumo.png - Intel optimization figures (not applicable)
meshcnn_edge_adjacency.svg - MeshCNN figures (not implemented)
point_cloud1.png through point_cloud4.png - Point cloud examples (limited support)
remote_1.png through remote_3.png - Remote backend figures (not applicable)
shallow_node_embeddings.png - Node embedding figures (not implemented)
to_hetero.svg / to_hetero.tex / to_hetero_with_bases.svg / to_hetero_with_bases.tex - Heterogeneous graph conversion (not implemented)
training_affinity.png - CPU affinity training (PyTorch-specific)
Kept Figure Files
Only the essential figures were retained:
graph.svg / graph.tex - Basic graph visualization used in introduction tutorial
build.sh - Figure generation script
Kept Documentation Files
The following files were retained and translated to JraphX:
Core tutorials: create_gnn.rst, gnn_design.rst (JAX integration)
Essential concepts: batching.rst, jit.rst, compile.rst
API reference: nn.rst, data.rst, utils.rst, root.rst
Cheatsheets: gnn_cheatsheet.rst, data_cheatsheet.rst
Getting started: introduction.rst, installation.rst
Notes
Priority is based on common usage patterns and core GNN functionality
JAX-specific optimizations should be added where applicable (jit, vmap, scan); see the sketch below
Some features may need significant adaptation for JAX/NNX paradigms
Documentation cleanup focused on maintaining only relevant, translated content
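As a concrete example of the scan note above: when all hidden layers share a shape, their weights can be stacked and the depth loop expressed with lax.scan, keeping the traced program small under jit. A hedged sketch (names are illustrative, not JraphX API):

```python
import jax

def deep_gnn(x, a_norm, stacked_weights):
    """Depth loop as lax.scan: stacked_weights has shape [num_layers, dim, dim]
    and a_norm is a pre-normalized dense adjacency matrix."""
    def layer(h, w):
        return jax.nn.relu(a_norm @ h @ w), None

    h, _ = jax.lax.scan(layer, x, stacked_weights)
    return h
```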