jraphx.nn.norm

This module contains normalization layers specifically designed for graph neural networks.

Normalization Layers

BatchNorm

class BatchNorm(*args: Any, **kwargs: Any)

Bases: Module

Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch.
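
As a minimal sketch of the formula with plain jax.numpy (training-time batch statistics only; the shapes are illustrative, not part of the library API):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (100, 64))  # [num_nodes, num_features]
gamma, beta, eps = 1.0, 0.0, 1e-5

mean = x.mean(axis=0)   # E[x], one value per feature dimension
var = x.var(axis=0)     # Var[x], one value per feature dimension
x_norm = (x - mean) / jnp.sqrt(var + eps) * gamma + beta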

Parameters:
  • num_features (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • momentum (float, optional) – Decay rate for the exponential moving average of the batch statistics. Higher values adapt more slowly, i.e. give more weight to past values (see the update sketch after this parameter list). (default: 0.99)

  • track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)

  • use_running_average (bool, optional) – If set to True, use running statistics instead of batch statistics during evaluation. (default: False)

  • axis (int, optional) – The feature or non-batch axis of the input. (default: -1)

  • dtype (Dtype, optional) – The dtype of the result. (default: infer from input and parameters)

  • param_dtype (Dtype, optional) – The dtype passed to parameter initializers. (default: jnp.float32)

  • use_bias (bool, optional) – If True, bias (beta) is added. (default: True)

  • use_scale (bool, optional) – If True, multiply by scale (gamma). (default: True)

  • bias_init (Initializer, optional) – Initializer for the bias (beta). (default: zeros)

  • scale_init (Initializer, optional) – Initializer for the scale (gamma). (default: ones)

  • axis_name (str, optional) – The axis name used to combine batch statistics from multiple devices. (default: None)

  • axis_index_groups (optional) – Groups of axis indices within that named axis. (default: None)

  • use_fast_variance (bool, optional) – If set to True, use a faster but less numerically stable variance calculation. (default: True)

  • rngs (Rngs, optional) – Random number generators for parameter initialization. (default: None)
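
As noted for the momentum parameter, the running statistics are updated with an exponential moving average. Under the convention described above (higher momentum keeps more of the past value), the update after each training step has the form (a sketch of the convention, not a verbatim excerpt of the implementation):

\[\hat{\mu} \leftarrow m \cdot \hat{\mu} + (1 - m) \cdot \mu_{\textrm{batch}}, \qquad \hat{\sigma}^2 \leftarrow m \cdot \hat{\sigma}^2 + (1 - m) \cdot \sigma^2_{\textrm{batch}}\]

where \(m\) is the momentum, \(\hat{\mu}, \hat{\sigma}^2\) are the running statistics, and \(\mu_{\textrm{batch}}, \sigma^2_{\textrm{batch}}\) are the statistics of the current mini-batch.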

Batch normalization layer with running statistics for graph data.

Key Features:

  • Maintains running mean and variance for inference

  • Configurable momentum and epsilon parameters

  • Supports both training and evaluation modes

  • Compatible with graph batching

Example:

from jraphx.nn.norm import BatchNorm
import flax.nnx as nnx
import jax.numpy as jnp

# Create batch norm layer
norm = BatchNorm(
    num_features=64,
    eps=1e-5,
    momentum=0.99,
    track_running_stats=True,
    rngs=nnx.Rngs(0)
)

# Node features: [num_nodes, num_features]
x = jnp.ones((100, 64))

# Apply normalization
x_normalized = norm(x)  # Training mode (batch statistics)

# Switch to eval mode
norm.eval()
x_eval = norm(x)  # Uses running statistics

LayerNorm

class LayerNorm(*args: Any, **kwargs: Any)

Bases: Module

Applies layer normalization over each individual example in a batch of node features as described in the “Layer Normalization” paper.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]

In "graph" mode, the mean and standard deviation are calculated across all nodes and all node channels separately for each graph in a mini-batch; in "node" mode, they are calculated per node over its feature channels.

Parameters:
  • num_features (int or list) – Size of each input sample, or list of dimensions to normalize.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • elementwise_affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)

  • mode (str, optional) – The normalization mode to use for layer normalization ("graph" or "node"). If "graph" is used, each graph will be considered as an element to be normalized. If "node" is used, each node will be considered as an element to be normalized. (default: "node")

  • dtype (Dtype, optional) – The dtype of the result. (default: infer from input and parameters)

  • param_dtype (Dtype, optional) – The dtype passed to parameter initializers. (default: jnp.float32)

  • use_bias (bool, optional) – If True, bias (beta) is added. (default: True)

  • use_scale (bool, optional) – If True, multiply by scale (gamma). (default: True)

  • bias_init (Initializer, optional) – Initializer for the bias (beta). (default: zeros)

  • scale_init (Initializer, optional) – Initializer for the scale (gamma). (default: ones)

  • reduction_axes (int or Sequence[int], optional) – Axes over which to compute the normalization statistics. (default: -1)

  • feature_axes (int or Sequence[int], optional) – Feature axes for the learned bias and scale. (default: -1)

  • axis_name (str, optional) – The axis name used to combine batch statistics from multiple devices. (default: None)

  • axis_index_groups (optional) – Groups of axis indices within that named axis. (default: None)

  • use_fast_variance (bool, optional) – If set to True, use a faster but less numerically stable variance calculation. (default: True)

  • rngs (Rngs, optional) – Random number generators for parameter initialization. (default: None)

Layer normalization for graph neural networks with node-wise or graph-wise modes.

Normalization Modes:

  • node: Normalize across feature dimensions for each node independently

  • graph: Normalize across all nodes and features in a graph (see the sketch below)
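
A minimal sketch of the difference between the two modes with plain jax.numpy, assuming a single graph; this illustrates which axes the statistics are taken over and is not the library's implementation:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (100, 64))  # [num_nodes, num_features]

# "node" mode: statistics per node, over its feature channels
node_mean = x.mean(axis=-1, keepdims=True)   # [num_nodes, 1]
node_var = x.var(axis=-1, keepdims=True)     # [num_nodes, 1]
x_node = (x - node_mean) / jnp.sqrt(node_var + 1e-5)

# "graph" mode (single graph): statistics over all nodes and all features
graph_mean = x.mean()
graph_var = x.var()
x_graph = (x - graph_mean) / jnp.sqrt(graph_var + 1e-5)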

Example:

from jraphx.nn.norm import LayerNorm
import flax.nnx as nnx
import jax.numpy as jnp

# Node features: [num_nodes, num_features]
x = jnp.ones((100, 64))

# Node-wise normalization
norm = LayerNorm(
    num_features=64,
    mode="node",
    eps=1e-5,
    elementwise_affine=True,
    rngs=nnx.Rngs(0)
)

x_normalized = norm(x)

# Graph-wise normalization (requires batch index)
norm_graph = LayerNorm(
    num_features=64,
    mode="graph",
    rngs=nnx.Rngs(0)
)

# Batch vector assigning each node to a graph
batch = jnp.zeros(100, dtype=jnp.int32)
x_normalized = norm_graph(x, batch=batch)

GraphNorm

class GraphNorm(*args: Any, **kwargs: Any)

Bases: Module

Applies graph normalization over individual graphs as described in the “GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training” paper.

\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]} {\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + \epsilon}} \odot \gamma + \beta\]

where \(\alpha\) denotes parameters that learn how much information to keep in the mean.

Parameters:
  • num_features (int) – Size of each input sample.

  • eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)

  • rngs (Rngs, optional) – Random number generators for parameter initialization. (default: None)

Graph normalization layer that normalizes node features across the graph structure.

Algorithm:

  1. Compute mean and variance per graph

  2. Normalize features within each graph

  3. Apply learnable affine transformation (see the sketch below)
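
A minimal sketch of these steps using plain jax.numpy scatter adds, assuming a batch vector that maps each node to its graph; alpha, gamma, and beta stand in for the learnable parameters. This illustrates the math above, not the library's actual implementation:

import jax.numpy as jnp

def graph_norm_sketch(x, batch, num_graphs, alpha, gamma, beta, eps=1e-5):
    # Number of nodes per graph, shape [num_graphs, 1]
    counts = jnp.zeros(num_graphs).at[batch].add(1.0)[:, None]
    # 1. Per-graph mean of each feature
    mean = jnp.zeros((num_graphs, x.shape[-1])).at[batch].add(x) / counts
    shifted = x - alpha * mean[batch]  # learnably scaled mean shift
    # 2. Per-graph variance of the shifted features
    var = jnp.zeros((num_graphs, x.shape[-1])).at[batch].add(shifted**2) / counts
    # 3. Normalize and apply the affine transformation
    return shifted / jnp.sqrt(var[batch] + eps) * gamma + beta

# Example call (hypothetical shapes): 100 nodes, 64 features, 2 graphs
x = jnp.ones((100, 64))
batch = jnp.concatenate([jnp.zeros(60, dtype=jnp.int32), jnp.ones(40, dtype=jnp.int32)])
out = graph_norm_sketch(x, batch, num_graphs=2,
                        alpha=jnp.ones(64), gamma=jnp.ones(64), beta=jnp.zeros(64))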

Example:

from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx
import jax.numpy as jnp

norm = GraphNorm(
    num_features=64,
    eps=1e-5,
    rngs=nnx.Rngs(0)
)

# Node features and batch assignment vector
x = jnp.ones((100, 64))
batch = jnp.zeros(100, dtype=jnp.int32)

# For batched graphs
x_normalized = norm(x, batch=batch)

# For a single graph (batch=None)
x_normalized = norm(x)

Normalization Selection Guide

When to Use Each Normalization

BatchNorm:
  • Best with large, consistent batch sizes

  • Effective for deep networks

  • Requires sufficient batch statistics

  • Good for training stability

LayerNorm:
  • Works well with variable batch sizes

  • Effective for attention-based models

  • Node mode: Best for heterogeneous graphs

  • Graph mode: Best for homogeneous graphs

GraphNorm:
  • Specifically designed for graph data

  • Handles varying graph sizes well

  • Good for graph-level tasks

  • Robust to batch size variations

Performance Comparison

Normalization Characteristics

Normalization       Batch Size Sensitivity   Graph Size Sensitivity   Best Use Case
------------------  -----------------------  -----------------------  --------------------
BatchNorm           High                     Low                      Large batch training
LayerNorm (node)    None                     None                     Node-level tasks
LayerNorm (graph)   None                     Medium                   Small graphs
GraphNorm           Low                      Low                      Graph-level tasks

Implementation Details

Running Statistics

BatchNorm maintains running statistics for inference:

# During training
norm.train()  # Use batch statistics
output = norm(x)

# During inference
norm.eval()  # Use running statistics
output = norm(x)

Affine Transformations

All normalization layers support learnable affine parameters:

# With affine transformation (default)
norm = LayerNorm(64, elementwise_affine=True)
# Learns gamma and beta parameters

# Without affine transformation
norm = LayerNorm(64, elementwise_affine=False)
# Pure normalization only

Numerical Stability

All layers use epsilon for numerical stability:

# Adjust epsilon for precision
norm = GraphNorm(64, eps=1e-5)  # Default
norm = GraphNorm(64, eps=1e-8)  # Higher precision

Integration with GNN Models

Using with Pre-built Models

from jraphx.nn.models import GCN
import flax.nnx as nnx

# GCN with layer normalization
model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    norm="layer_norm",  # or "batch_norm", "graph_norm"
    rngs=nnx.Rngs(0)
)

Custom Model Integration

from jraphx.nn.conv import GCNConv
from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx

class CustomGNN(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.conv = GCNConv(in_features, 64, rngs=rngs)
        self.norm = GraphNorm(64, rngs=rngs)
        self.linear = nnx.Linear(64, out_features, rngs=rngs)

    def __call__(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)
        x = self.norm(x, batch)
        x = nnx.relu(x)
        return self.linear(x)
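
A quick usage sketch of the CustomGNN defined above (the node count, edge list, and batch vector are illustrative):

import jax.numpy as jnp

model = CustomGNN(in_features=16, out_features=10, rngs=nnx.Rngs(0))

x = jnp.ones((100, 16))                          # 100 nodes, 16 features
edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])   # [2, num_edges] connectivity
batch = jnp.zeros(100, dtype=jnp.int32)          # all nodes in one graph

out = model(x, edge_index, batch)                # [100, 10]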