jraphx.nn

This module contains neural network layers and operations for graph neural networks.

Overview

The jraphx.nn module provides a comprehensive set of neural network components for building graph neural networks:

Core Components:

  • Message Passing Framework (jraphx.nn.conv): Base class and implementations for graph convolutions

  • Pre-built Models (jraphx.nn.models): Ready-to-use GNN architectures (GCN, GAT, GraphSAGE, GIN)

  • Normalization Layers (jraphx.nn.norm): BatchNorm, LayerNorm, and GraphNorm for GNNs

  • Pooling Operations (jraphx.nn.pool): Global and hierarchical pooling methods

Quick Start

Using Pre-built Models

from jraphx.nn.models import GCN
import flax.nnx as nnx

# Create a 3-layer GCN model
model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    dropout_rate=0.5,
    norm="layer_norm",
    rngs=nnx.Rngs(0)
)

# Forward pass
out = model(x, edge_index, batch=batch)

Building Custom Models

from jraphx.nn.conv import GCNConv, GATConv
from jraphx.nn.norm import GraphNorm
from jraphx.nn.pool import TopKPooling, global_mean_pool
import flax.nnx as nnx

class CustomGNN(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.conv1 = GCNConv(in_features, 64, rngs=rngs)
        self.norm1 = GraphNorm(64, rngs=rngs)
        self.pool1 = TopKPooling(64, ratio=0.8, rngs=rngs)

        self.conv2 = GATConv(64, 64, heads=4, rngs=rngs)
        self.norm2 = GraphNorm(256, rngs=rngs)  # 64 * 4 heads

        self.classifier = nnx.Linear(256, out_features, rngs=rngs)
        self.dropout = nnx.Dropout(0.5, rngs=rngs)

    def __call__(self, x, edge_index, batch):
        # First conv block
        x = self.conv1(x, edge_index)
        x = self.norm1(x, batch)
        x = nnx.relu(x)

        # Pooling
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, batch=batch)

        # Second conv block (GAT)
        x = self.conv2(x, edge_index)
        x = self.norm2(x, batch)
        x = nnx.relu(x)

        # Global pooling and classification
        x = global_mean_pool(x, batch)
        x = self.dropout(x)
        return self.classifier(x)

Module Organization

Convolution Layers (jraphx.nn.conv):
  • MessagePassing: Base class for custom layers

  • GCNConv: Graph Convolutional Network

  • GATConv: Graph Attention Network

  • GATv2Conv: Improved GAT with dynamic attention

  • SAGEConv: GraphSAGE with multiple aggregations

  • GINConv: Graph Isomorphism Network

  • EdgeConv: Edge convolution for point clouds

  • DynamicEdgeConv: Dynamic edge convolution (requires pre-computed k-NN)

  • TransformerConv: Graph Transformer with multi-head attention

Pre-built Models (jraphx.nn.models):
  • GCN: Multi-layer GCN architecture

  • GAT: Multi-layer GAT architecture

  • GraphSAGE: Multi-layer GraphSAGE architecture

  • GIN: Multi-layer GIN architecture

  • MLP: Multi-layer perceptron

  • JumpingKnowledge: Layer aggregation module

  • BasicGNN: Abstract base class for GNN models

Normalization (jraphx.nn.norm):
  • BatchNorm: Batch normalization with running statistics

  • LayerNorm: Layer normalization (node-wise or graph-wise)

  • GraphNorm: Graph-specific normalization

Pooling (jraphx.nn.pool):
  • global_add_pool: Sum aggregation

  • global_mean_pool: Mean aggregation

  • global_max_pool: Max aggregation

  • global_min_pool: Min aggregation

  • TopKPooling: Select top-k important nodes

  • SAGPooling: Self-attention graph pooling