jraphx.nn
This module contains neural network layers and operations for graph neural networks.
Overview
The jraphx.nn module provides a comprehensive set of neural network components for building graph neural networks:
Core Components:

- Message Passing Framework (jraphx.nn.conv): Base class and implementations for graph convolutions
- Pre-built Models (jraphx.nn.models): Ready-to-use GNN architectures (GCN, GAT, GraphSAGE, GIN)
- Normalization Layers (jraphx.nn.norm): BatchNorm, LayerNorm, and GraphNorm for GNNs
- Pooling Operations (jraphx.nn.pool): Global and hierarchical pooling methods
Quick Start
Using Pre-built Models
    from jraphx.nn.models import GCN
    import flax.nnx as nnx

    # Create a 3-layer GCN model
    model = GCN(
        in_features=16,
        hidden_features=64,
        num_layers=3,
        out_features=10,
        dropout_rate=0.5,
        norm="layer_norm",
        rngs=nnx.Rngs(0),
    )
# Forward pass
out = model(x, edge_index, batch=batch)
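The call above assumes the edge_index/batch convention used throughout this page: x is a [num_nodes, in_features] feature matrix, edge_index a [2, num_edges] integer array of (source, target) node indices, and batch a [num_nodes] graph-assignment vector. A minimal, illustrative set of inputs (the data here is dummy data, and the [num_nodes, out_features] output shape is an assumption based on node-level prediction):

    import jax.numpy as jnp

    # 4 dummy nodes with 16 features each
    x = jnp.ones((4, 16))
    # Edges as a [2, num_edges] array of (source, target) indices
    edge_index = jnp.array([[0, 1, 1, 2],
                            [1, 0, 2, 1]])
    # Graph assignment per node; all four nodes belong to graph 0
    batch = jnp.zeros(4, dtype=jnp.int32)

    out = model(x, edge_index, batch=batch)  # node-level output, shape [4, 10]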
Building Custom Models
    from jraphx.nn.conv import GCNConv, GATConv
    from jraphx.nn.norm import GraphNorm
    from jraphx.nn.pool import TopKPooling, global_mean_pool
    import flax.nnx as nnx

    class CustomGNN(nnx.Module):
        def __init__(self, in_features, out_features, rngs):
            self.conv1 = GCNConv(in_features, 64, rngs=rngs)
            self.norm1 = GraphNorm(64, rngs=rngs)
            self.pool1 = TopKPooling(64, ratio=0.8, rngs=rngs)
            self.conv2 = GATConv(64, 64, heads=4, rngs=rngs)
            self.norm2 = GraphNorm(256, rngs=rngs)  # 64 features * 4 heads
            self.classifier = nnx.Linear(256, out_features, rngs=rngs)
            self.dropout = nnx.Dropout(0.5, rngs=rngs)

        def __call__(self, x, edge_index, batch):
            # First conv block
            x = self.conv1(x, edge_index)
            x = self.norm1(x, batch)
            x = nnx.relu(x)

            # Pooling: keep the top 80% of nodes
            x, edge_index, _, batch, _ = self.pool1(x, edge_index, batch=batch)

            # Second conv block (GAT, 4 heads concatenated -> 256 features)
            x = self.conv2(x, edge_index)
            x = self.norm2(x, batch)
            x = nnx.relu(x)

            # Global pooling and classification
            x = global_mean_pool(x, batch)
            x = self.dropout(x)
            return self.classifier(x)
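A quick usage sketch for CustomGNN; the graph below is illustrative, and the one-row-per-graph output shape follows from the global_mean_pool call in __call__:

    import jax.numpy as jnp

    model = CustomGNN(in_features=16, out_features=3, rngs=nnx.Rngs(0))

    x = jnp.ones((10, 16))                  # 10 nodes, 16 features each
    edge_index = jnp.array([[0, 1, 2, 3],
                            [1, 2, 3, 4]])  # 4 directed edges
    batch = jnp.zeros(10, dtype=jnp.int32)  # a single graph

    logits = model(x, edge_index, batch)    # shape [1, 3]: one row per graph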
Module Organization
- Convolution Layers (jraphx.nn.conv):
  - MessagePassing: Base class for custom layers (see the sketch after this list)
  - GCNConv: Graph Convolutional Network
  - GATConv: Graph Attention Network
  - GATv2Conv: Improved GAT with dynamic attention
  - SAGEConv: GraphSAGE with multiple aggregations
  - GINConv: Graph Isomorphism Network
  - EdgeConv: Edge convolution for point clouds
  - DynamicEdgeConv: Dynamic edge convolution (requires pre-computed k-NN)
  - TransformerConv: Graph Transformer with multi-head attention
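As a minimal sketch of a custom layer built on MessagePassing, assuming jraphx follows the PyG-style propagate/message API (the aggr argument, the propagate entry point, and the x_j parameter name are assumptions; check the MessagePassing reference for exact signatures):

    from jraphx.nn.conv import MessagePassing
    import flax.nnx as nnx

    class MeanConv(MessagePassing):
        """Toy layer: mean-aggregate neighbor features, then apply a linear map."""

        def __init__(self, in_features, out_features, rngs):
            super().__init__(aggr="mean")  # assumed PyG-style aggregation argument
            self.lin = nnx.Linear(in_features, out_features, rngs=rngs)

        def __call__(self, x, edge_index):
            # Assumed PyG-style entry point: gathers x_j per edge, then aggregates
            out = self.propagate(edge_index, x=x)
            return self.lin(out)

        def message(self, x_j):
            # Identity message: forward the source node's features unchanged
            return x_j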
- Pre-built Models (jraphx.nn.models):
  - GCN: Multi-layer GCN architecture
  - GAT: Multi-layer GAT architecture
  - GraphSAGE: Multi-layer GraphSAGE architecture (see the example after this list)
  - GIN: Multi-layer GIN architecture
  - MLP: Multi-layer perceptron
  - JumpingKnowledge: Layer aggregation module
  - BasicGNN: Abstract base class for GNN models
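Since the pre-built architectures derive from BasicGNN, they should be drop-in replacements for one another in the Quick Start example; a sketch, assuming GraphSAGE accepts the same constructor arguments shown there for GCN, and reusing the x and edge_index arrays from above:

    from jraphx.nn.models import GraphSAGE
    import flax.nnx as nnx

    model = GraphSAGE(
        in_features=16,
        hidden_features=64,
        num_layers=2,
        out_features=10,
        rngs=nnx.Rngs(0),
    )
    out = model(x, edge_index)  # node-level output, shape [num_nodes, 10]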
- Normalization (jraphx.nn.norm):
  - BatchNorm: Batch normalization with running statistics
  - LayerNorm: Layer normalization (node-wise or graph-wise)
  - GraphNorm: Graph-specific normalization (see the example after this list)
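GraphNorm computes its statistics per graph, which is why it takes the batch vector alongside x, as in the CustomGNN example above. A minimal standalone usage (shapes illustrative):

    from jraphx.nn.norm import GraphNorm
    import flax.nnx as nnx
    import jax.numpy as jnp

    norm = GraphNorm(64, rngs=nnx.Rngs(0))
    x = jnp.ones((10, 64))                  # 10 nodes, 64 features
    batch = jnp.zeros(10, dtype=jnp.int32)  # all nodes belong to graph 0
    x = norm(x, batch)                      # statistics computed per graph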
- Pooling (jraphx.nn.pool):
  - global_add_pool: Sum aggregation
  - global_mean_pool: Mean aggregation (worked example after this list)
  - global_max_pool: Max aggregation
  - global_min_pool: Min aggregation
  - TopKPooling: Select top-k important nodes
  - SAGPooling: Self-attention graph pooling
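The global pooling functions reduce node features to one vector per graph, indexed by the batch vector, using the two-argument form shown in the CustomGNN example. A worked example with values chosen so the results are easy to check by hand:

    from jraphx.nn.pool import global_add_pool, global_mean_pool
    import jax.numpy as jnp

    x = jnp.array([[1.0, 2.0],
                   [3.0, 4.0],
                   [5.0, 6.0]])
    batch = jnp.array([0, 0, 1])  # nodes 0-1 in graph 0, node 2 in graph 1

    global_mean_pool(x, batch)    # [[2., 3.], [5., 6.]]
    global_add_pool(x, batch)     # [[4., 6.], [5., 6.]]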