Missing Tests in JraphX
This document tracks PyTorch Geometric tests that could not be ported to JraphX due to missing functionality, architectural differences, or deliberate design choices.
Test Conversion Progress
The JraphX library aims to provide core GNN functionality with a JAX-first design. Not all PyTorch Geometric features are implemented, as some are outside the scope of a lightweight JAX library.
Successfully Converted Tests
All tests for functionality that JraphX implements have been converted. The converted test files and their test counts are listed below.
Core Data Structures
test_data_jraphx.py ✓ - Data class functionality (9 tests)
test_batch_jraphx.py ✓ - Batch class functionality (7 tests)
Convolution Layers
test_gcn_conv_jraphx.py ✓ - Graph Convolution Network (10 tests)
test_gat_conv_jraphx.py ✓ - Graph Attention Network (12 tests)
test_gatv2_conv_jraphx.py ✓ - Improved Graph Attention Network (referenced via GAT)
test_sage_conv_jraphx.py ✓ - GraphSAGE layer (16 tests)
test_gin_conv_jraphx.py ✓ - Graph Isomorphism Network (9 tests)
test_edge_conv_jraphx.py ✓ - EdgeConv and DynamicEdgeConv layers (12 tests)
test_transformer_conv_jraphx.py ✓ - Graph Transformer layer (13 tests)
Utility Functions
test_scatter_jraphx.py ✓ - Scatter operations (add, mean, max, min, std, logsumexp) (9 tests)
test_degree_jraphx.py ✓ - Degree computation functions (3 tests)
test_coalesce_jraphx.py ✓ - Edge coalescing functionality (5 tests)
test_loop_jraphx.py ✓ - Self-loop addition/removal (6 tests)
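For orientation, the kind of reference computation these utility tests compare against can be sketched in plain JAX. The sketch below deliberately avoids the JraphX API (its function names and signatures are not shown in this document) and uses only `jax.ops.segment_sum` and `jax.ops.segment_max`; the toy graph and variable names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Toy graph (assumed for illustration): 3 nodes, 4 directed edges.
# Row 0 holds source nodes, row 1 holds target nodes.
edge_index = jnp.array([[0, 1, 1, 2],
                        [1, 0, 2, 1]])
dst = edge_index[1]
num_nodes = 3

# One scalar message per edge, scattered onto its target node.
messages = jnp.array([1.0, 2.0, 3.0, 4.0])
scatter_add = jax.ops.segment_sum(messages, dst, num_segments=num_nodes)
scatter_max = jax.ops.segment_max(messages, dst, num_segments=num_nodes)

# In-degree is a scatter-add of ones over the target indices.
in_degree = jax.ops.segment_sum(jnp.ones_like(messages), dst, num_segments=num_nodes)

# Scatter-mean divides the sum by the count, guarding against isolated nodes.
scatter_mean = scatter_add / jnp.maximum(in_degree, 1.0)
```

Node 1 receives messages from edges 0 and 3, so its sum is 5.0, its max is 4.0, and its in-degree is 2.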
Models
test_basic_gnn_jraphx.py ✓ - Pre-built GNN models (GCN, GAT, SAGE, GIN) (769 tests)
test_jumping_knowledge_jraphx.py ✓ - JumpingKnowledge aggregation (4 tests)
test_mlp_jraphx.py ✓ - Multi-layer perceptron (12 tests)
Pooling Operations
test_glob_jraphx.py ✓ - Global pooling operations (add, mean, max) (4 tests)
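As a rough illustration of what the global pooling tests exercise, the three reductions can be written in plain JAX with a batch assignment vector. The helper names below mirror the PyG naming but are local, hypothetical functions defined for this sketch, not the JraphX implementations.

```python
import jax
import jax.numpy as jnp

def global_add_pool(x, batch, num_graphs):
    """Sum node features per graph (local helper for this sketch)."""
    return jax.ops.segment_sum(x, batch, num_segments=num_graphs)

def global_mean_pool(x, batch, num_graphs):
    """Average node features per graph, guarding against empty graphs."""
    counts = jax.ops.segment_sum(jnp.ones(x.shape[0]), batch, num_segments=num_graphs)
    return global_add_pool(x, batch, num_graphs) / jnp.maximum(counts, 1.0)[:, None]

def global_max_pool(x, batch, num_graphs):
    """Feature-wise maximum per graph."""
    return jax.ops.segment_max(x, batch, num_segments=num_graphs)

# Two graphs in one batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
x = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0],
               [7.0, 8.0],
               [9.0, 10.0]])
batch = jnp.array([0, 0, 1, 1, 1])

pooled_sum = global_add_pool(x, batch, 2)
pooled_mean = global_mean_pool(x, batch, 2)
pooled_max = global_max_pool(x, batch, 2)
```

Passing the number of graphs explicitly keeps the output shape static, which matters if the function is later wrapped in `jax.jit`.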
Normalization Layers
test_batch_norm_jraphx.py ✓ - Batch normalization (5 tests)
test_layer_norm_jraphx.py ✓ - Layer normalization (9 tests)
test_graph_norm_jraphx.py ✓ - Graph normalization (6 tests)
Total: 920+ test cases covering all implemented JraphX functionality
Tests Not Converted - Missing Core Features
High Priority Missing Features
These tests cannot be converted because the underlying JraphX features don’t exist yet:
Convolution Layers
test_appnp.py - APPNP layer not implemented
test_cheb_conv.py - Chebyshev convolution not implemented
test_arma_conv.py - ARMA filters not implemented
test_graph_conv.py - Basic graph convolution not implemented
test_nn_conv.py - NN-based convolution not implemented
test_spline_conv.py - Spline-based convolution not implemented
test_pna_conv.py - Principal Neighbourhood Aggregation not implemented
test_film_conv.py - FiLM layers not implemented
Message Passing Framework
test_message_passing.py - Only partial message passing framework exists
Advanced aggregation functions beyond basic scatter operations
Custom message/update function hooks
Heterogeneous Graphs
test_hetero_conv.py - No heterogeneous graph support
test_hgt_conv.py - Heterogeneous Graph Transformer not implemented
test_han_conv.py - Heterogeneous Attention Network not implemented
Pooling Operations
test_topk_pool.py - TopK pooling not implemented
test_sag_pool.py - SAG pooling not implemented
test_asap.py - ASAP pooling not implemented
test_diff_pool.py - DiffPool not implemented
test_edge_pool.py - Edge pooling not implemented
Advanced Models
test_autoencoder.py - Autoencoder models not implemented
test_deep_graph_infomax.py - Deep Graph Infomax not implemented
test_node2vec.py - Node2Vec not implemented
test_metapath2vec.py - Metapath2Vec not implemented
Tests Not Converted - PyTorch-Specific Features
These tests rely on PyTorch-specific functionality that has no direct JAX equivalent:
JIT Compilation & TorchScript
Tests involving torch.jit.script() compilation
TorchScript-specific functionality tests
Dynamic graph compilation features
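Although TorchScript tests cannot be ported as written, a JAX test suite can still check that a function behaves the same with and without compilation by using `jax.jit`. The sketch below is one possible shape for such a check; `propagate` is a hypothetical plain-JAX function written for illustration, not a JraphX layer.

```python
import jax
import jax.numpy as jnp

def propagate(x, edge_index):
    """Sum source-node features into target nodes (hypothetical stand-in)."""
    src, dst = edge_index[0], edge_index[1]
    # x.shape[0] is static under jit, so num_segments is a Python int.
    return jax.ops.segment_sum(x[src], dst, num_segments=x.shape[0])

x = jnp.arange(6.0).reshape(3, 2)               # node features [[0,1],[2,3],[4,5]]
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])  # edges 0->1, 1->2, 2->0

eager_out = propagate(x, edge_index)
jit_out = jax.jit(propagate)(x, edge_index)

# The compiled and eager results should agree.
assert jnp.allclose(eager_out, jit_out)
```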
CUDA-Specific Tests
test_fused_gat_conv.py - CUDA kernel optimizations
CuGraph integration tests (tests/nn/conv/cugraph/)
GPU memory management tests
PyTorch Lightning Integration
test_datamodule.py - PyTorch Lightning specific features
Lightning-based training loop tests
Sparse Tensor Operations
test_sparse.py - PyTorch sparse tensor operations
torch_sparse library integration tests
SparseTensor class functionality
Tests Not Converted - Out of Scope
These test areas are deliberately excluded from JraphX scope:
Explainability
tests/explain/ - Entire explainability module (47 test files)
GNN explainability is a specialized domain outside core GNN functionality
Visualization
tests/visualization/ - Graph visualization functionality
Visualization is typically handled by specialized libraries
Natural Language Processing
tests/nn/nlp/ - NLP-specific functionality
Sentence transformers and language model integration
3D Point Clouds & Molecular Data
Point cloud convolutions (test_point_conv.py, test_ppf_conv.py)
Molecular property prediction models
3D geometry-specific operations
Advanced Optimization
test_correct_and_smooth.py - Advanced optimization techniques
Label propagation algorithms
Advanced training strategies
Database & Storage Integration
tests/data/test_database.py - Database integration
tests/data/test_remote_backend_utils.py - Remote storage
Complex data pipeline functionality
Conversion Notes and Patterns
When converting tests, the following patterns were applied:
Import Conversions
# PyTorch Geometric
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
# JraphX
import jax.numpy as jnp
from jax import random
from flax import nnx
from jraphx.nn.conv import GCNConv
from jraphx.data import Data, Batch
Tensor Operations
# PyTorch
x = torch.randn(4, 16)
assert torch.allclose(x, x)
# JAX
x = random.normal(random.key(0), (4, 16))
assert jnp.allclose(x, x)
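The main behavioural difference in this pattern is that JAX random numbers come from explicit keys rather than hidden global state. A minimal self-contained sketch of how a converted test might handle this follows; the shapes match the snippet above, and the tolerance value is an illustrative assumption.

```python
import jax.numpy as jnp
from jax import random

# Split one root key so each random tensor gets its own independent stream.
key = random.key(0)
key_a, key_b = random.split(key)

x = random.normal(key_a, (4, 16))
y = random.normal(key_b, (4, 16))

# Reusing the same key reproduces the same values exactly.
x_again = random.normal(key_a, (4, 16))
assert jnp.array_equal(x, x_again)

# Different keys give different draws.
assert not jnp.allclose(x, y)

# Floating-point comparisons usually need an explicit tolerance.
assert jnp.allclose(x, x + 1e-9, atol=1e-6)
```

Because results depend only on the key, converted tests are reproducible without a global seed call such as `torch.manual_seed`.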
Model Initialization
# PyTorch Geometric
conv = GCNConv(16, 32)
# JraphX
conv = GCNConv(16, 32, rngs=nnx.Rngs(0))
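The extra `rngs` argument exists because parameters are initialized from an explicit random source when the layer is constructed. The idea can be sketched without any layer library; `init_linear` below is a hypothetical helper written for illustration and is not part of JraphX or Flax.

```python
import jax.numpy as jnp
from jax import random

def init_linear(key, in_features, out_features):
    """Hypothetical helper: build linear-layer parameters from an explicit key."""
    scale = 1.0 / jnp.sqrt(in_features)
    return {
        "kernel": scale * random.normal(key, (in_features, out_features)),
        "bias": jnp.zeros((out_features,)),
    }

# Same seed, same parameters; a different seed gives different parameters.
params_a = init_linear(random.key(0), 16, 32)
params_b = init_linear(random.key(0), 16, 32)
params_c = init_linear(random.key(1), 16, 32)
```

In a converted test, fixing the seed passed to the constructor plays the role that a global seed plays in the PyTorch original.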
Testing Strategy
Maintain original test logic and assertions
Comment out unportable sections with clear TODO notes
Document conversion limitations in test docstrings
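Put together, a converted test following this strategy might look like the sketch below. The function under test, `mean_aggregate`, is a hypothetical plain-JAX stand-in for a real layer, and the commented-out TorchScript check is an invented example of an unportable section.

```python
import jax
import jax.numpy as jnp

def mean_aggregate(x, edge_index):
    """Hypothetical stand-in for the layer under test: mean of incoming neighbour features."""
    src, dst = edge_index[0], edge_index[1]
    num_nodes = x.shape[0]
    total = jax.ops.segment_sum(x[src], dst, num_segments=num_nodes)
    count = jax.ops.segment_sum(jnp.ones(src.shape[0]), dst, num_segments=num_nodes)
    return total / jnp.maximum(count, 1.0)[:, None]

def test_mean_aggregate():
    """Converted from a PyG-style test.

    Conversion limitations: the TorchScript check from the original
    test has no counterpart here and is kept only as a TODO note.
    """
    x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    edge_index = jnp.array([[0, 1, 2], [2, 2, 0]])  # edges 0->2, 1->2, 2->0

    out = mean_aggregate(x, edge_index)
    assert out.shape == (3, 2)
    assert jnp.allclose(out[2], jnp.array([2.0, 3.0]))  # mean of nodes 0 and 1
    assert jnp.allclose(out[1], jnp.zeros(2))           # node 1 has no incoming edges

    # TODO: not portable (TorchScript). Original check was:
    # jit = torch.jit.script(conv)
    # assert torch.allclose(jit(x, edge_index), out)

test_mean_aggregate()
```

Keeping the unportable lines as comments next to a TODO makes it easy to revisit them if an equivalent feature is added later.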
Future Work
This document should be updated as new features are added to JraphX. Priority should be given to:
Core Convolution Layers - APPNP, Chebyshev, Graph convolutions
Advanced Pooling - TopK, SAG, hierarchical pooling
Heterogeneous Support - Multi-relation and multi-node-type graphs
Advanced Aggregation - Beyond basic scatter operations
Each new feature implementation should be accompanied by converted tests from the corresponding PyG test files.
Contributing Test Conversions
When contributing new test conversions:
Follow the established conversion patterns above
Document any limitations or missing functionality
Update this document with the conversion status
Ensure tests pass with python -m pytest
Add TODO comments for unportable test sections
For questions about test conversion priorities or implementation approaches, refer to the main JraphX documentation and the docs/source/missing_features.rst file.