jraphx.nn.norm
This module contains normalization layers specifically designed for graph neural networks.
Normalization Layers
BatchNorm
- class BatchNorm(*args: Any, **kwargs: Any)
Bases: Module
Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]
The mean and standard-deviation are calculated per-dimension over all nodes inside the mini-batch.
- Parameters:
num_features (int) – Size of each input sample.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – Decay rate for the exponential moving average of the batch statistics. Higher values mean slower adaptation (more weight on past values). (default: 0.99)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance; if set to False, it does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
use_running_average (bool, optional) – If set to True, use running statistics instead of batch statistics during evaluation. (default: False)
axis (int, optional) – The feature or non-batch axis of the input. (default: -1)
dtype (Union[str, type[Any], dtype, SupportsDType, Any, None], optional) – The dtype of the result. (default: None, i.e. inferred from input and parameters)
param_dtype (Union[str, type[Any], dtype, SupportsDType, Any], optional) – The dtype passed to parameter initializers. (default: jax.numpy.float32)
use_bias (bool, optional) – If set to True, bias (beta) is added. (default: True)
use_scale (bool, optional) – If set to True, multiply by scale (gamma). (default: True)
bias_init (Union[Initializer, Callable[..., Any]], optional) – Initializer for bias. (default: zeros)
scale_init (Union[Initializer, Callable[..., Any]], optional) – Initializer for scale. (default: ones)
axis_name (Optional[str], optional) – The axis name used to combine batch statistics from multiple devices. (default: None)
axis_index_groups (Any, optional) – Groups of axis indices within that named axis. (default: None)
use_fast_variance (bool, optional) – If set to True, use a faster but less numerically stable variance calculation. (default: True)
rngs (Optional[Rngs], optional) – Random number generators for initialization. (default: None)
Batch normalization layer with running statistics for graph data.
Key Features:
Maintains running mean and variance for inference
Configurable momentum and epsilon parameters
Supports both training and evaluation modes
Compatible with graph batching
Example:
from jraphx.nn.norm import BatchNorm
import flax.nnx as nnx

# Create batch norm layer
norm = BatchNorm(
    num_features=64,
    eps=1e-5,
    momentum=0.99,
    track_running_stats=True,
    rngs=nnx.Rngs(0),
)

# Apply normalization
x_normalized = norm(x)  # Training mode

# Switch to eval mode
norm.eval()
x_eval = norm(x)  # Uses running statistics
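The momentum parameter follows the exponential-moving-average convention described above: a higher value keeps more of the previous running value. A minimal sketch of the assumed update rule (illustrative only, not jraphx source code):

def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.99):
    # Assumed exponential-moving-average update; higher momentum = slower adaptation.
    new_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * running_var + (1.0 - momentum) * batch_var
    return new_mean, new_var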
LayerNorm
- class LayerNorm(*args: Any, **kwargs: Any)
Bases: Module
Applies layer normalization over each individual example in a batch of node features as described in the “Layer Normalization” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]
The mean and standard-deviation are calculated across all nodes and all node channels separately for each object in a mini-batch.
- Parameters:
num_features (int or list) – Size of each input sample, or list of dimensions to normalize.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
elementwise_affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
mode (str, optional) – The normalization mode to use for layer normalization ("graph" or "node"). If "graph" is used, each graph will be considered as an element to be normalized. If "node" is used, each node will be considered as an element to be normalized. (default: "node")
dtype (Union[str, type[Any], dtype, SupportsDType, Any, None], optional) – The dtype of the result. (default: None, i.e. inferred from input and parameters)
param_dtype (Union[str, type[Any], dtype, SupportsDType, Any], optional) – The dtype passed to parameter initializers. (default: jax.numpy.float32)
use_bias (bool, optional) – If set to True, bias (beta) is added. (default: True)
use_scale (bool, optional) – If set to True, multiply by scale (gamma). (default: True)
bias_init (Union[Initializer, Callable[..., Any]], optional) – Initializer for bias. (default: zeros)
scale_init (Union[Initializer, Callable[..., Any]], optional) – Initializer for scale. (default: ones)
reduction_axes (Union[int, Sequence[int]], optional) – Axes for computing normalization statistics. (default: -1)
feature_axes (Union[int, Sequence[int]], optional) – Feature axes for learned bias and scaling. (default: -1)
axis_name (Optional[str], optional) – The axis name used to combine batch statistics from multiple devices. (default: None)
axis_index_groups (Any, optional) – Groups of axis indices within that named axis. (default: None)
use_fast_variance (bool, optional) – If set to True, use a faster but less numerically stable variance calculation. (default: True)
rngs (Optional[Rngs], optional) – Random number generators for initialization. (default: None)
Layer normalization for graph neural networks with node-wise or graph-wise modes.
Normalization Modes:
node: Normalize across feature dimensions for each node independently
graph: Normalize across all nodes and features in a graph
Example:
from jraphx.nn.norm import LayerNorm
import flax.nnx as nnx

# Node-wise normalization
norm = LayerNorm(
    num_features=64,
    mode="node",
    eps=1e-5,
    elementwise_affine=True,
    rngs=nnx.Rngs(0),
)
x_normalized = norm(x)

# Graph-wise normalization (requires batch index)
norm_graph = LayerNorm(
    num_features=64,
    mode="graph",
    rngs=nnx.Rngs(0),
)
x_normalized = norm_graph(x, batch=batch)
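For intuition, "node" mode is equivalent to standard layer normalization applied independently to each node's feature vector. A minimal reference computation (not jraphx internals; the affine \(\gamma\) and \(\beta\) are omitted):

import jax.numpy as jnp

def layer_norm_node_reference(x, eps=1e-5):
    # Normalize each node (row) over its feature dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)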
GraphNorm
- class GraphNorm(*args: Any, **kwargs: Any)
Bases: Module
Applies graph normalization over individual graphs as described in the “GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + \epsilon}} \odot \gamma + \beta\]
where \(\alpha\) denotes parameters that learn how much information to keep in the mean.
- Parameters:
num_features (int) – Size of each input sample.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
rngs (Optional[Rngs], optional) – Random number generators for initialization. (default: None)
Graph normalization layer that normalizes node features across the graph structure.
Algorithm:
Compute mean and variance per graph
Normalize features within each graph
Apply learnable affine transformation
Example:
from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx

norm = GraphNorm(
    num_features=64,
    eps=1e-5,
    rngs=nnx.Rngs(0),
)

# For batched graphs
x_normalized = norm(x, batch=batch)

# For single graph (batch=None)
x_normalized = norm(x)
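For intuition, the per-graph statistics in the algorithm above can be expressed with segment reductions over the batch vector. A minimal reference sketch (not jraphx internals; the learnable \(\alpha\), \(\gamma\), and \(\beta\) are omitted):

import jax.numpy as jnp
from jax.ops import segment_sum

def graph_norm_reference(x, batch, num_graphs, eps=1e-5):
    # Per-graph node counts, means, and variances via segment sums.
    counts = segment_sum(jnp.ones(x.shape[0]), batch, num_segments=num_graphs)[:, None]
    mean = segment_sum(x, batch, num_segments=num_graphs) / counts
    centered = x - mean[batch]                    # alpha treated as 1 here
    var = segment_sum(centered ** 2, batch, num_segments=num_graphs) / counts
    return centered / jnp.sqrt(var[batch] + eps)  # gamma/beta omitted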
Normalization Selection Guide
When to Use Each Normalization
- BatchNorm:
Best with large, consistent batch sizes
Effective for deep networks
Requires sufficient batch statistics
Good for training stability
- LayerNorm:
Works well with variable batch sizes
Effective for attention-based models
Node mode: Best for heterogeneous graphs
Graph mode: Best for homogeneous graphs
- GraphNorm:
Specifically designed for graph data
Handles varying graph sizes well
Good for graph-level tasks
Robust to batch size variations
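When the normalization type is chosen from configuration, a small helper can map a name to a layer, mirroring the trade-offs listed above. This make_norm helper is hypothetical and not part of jraphx:

from jraphx.nn.norm import BatchNorm, LayerNorm, GraphNorm
import flax.nnx as nnx

def make_norm(kind: str, num_features: int, rngs: nnx.Rngs):
    # Hypothetical factory mapping a config string to a normalization layer.
    if kind == "batch_norm":
        return BatchNorm(num_features=num_features, rngs=rngs)
    if kind == "layer_norm":
        return LayerNorm(num_features=num_features, mode="node", rngs=rngs)
    if kind == "graph_norm":
        return GraphNorm(num_features=num_features, rngs=rngs)
    raise ValueError(f"Unknown normalization: {kind}")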
Performance Comparison
| Normalization | Batch Size Sensitivity | Graph Size Sensitivity | Best Use Case |
|---|---|---|---|
| BatchNorm | High | Low | Large batch training |
| LayerNorm (node) | None | None | Node-level tasks |
| LayerNorm (graph) | None | Medium | Small graphs |
| GraphNorm | Low | Low | Graph-level tasks |
Implementation Details
Running Statistics
BatchNorm maintains running statistics for inference:
# During training
norm.train() # Use batch statistics
output = norm(x)
# During inference
norm.eval() # Use running statistics
output = norm(x)
Affine Transformations
All normalization layers support learnable affine parameters:
# With affine transformation (default)
norm = LayerNorm(64, elementwise_affine=True)
# Learns gamma and beta parameters
# Without affine transformation
norm = LayerNorm(64, elementwise_affine=False)
# Pure normalization only
Numerical Stability
All layers use epsilon for numerical stability:
# Adjust epsilon for precision
norm = GraphNorm(64, eps=1e-5) # Default
norm = GraphNorm(64, eps=1e-8) # Higher precision
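To see why the epsilon matters, consider constant node features, where the variance is exactly zero. This standalone illustration is not jraphx code:

import jax.numpy as jnp

x = jnp.ones((4, 8))                     # constant features: variance is exactly 0
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
out = (x - mean) / jnp.sqrt(var + 1e-5)  # eps keeps the denominator non-zero
# Without eps this would be 0 / 0 and produce NaNs.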
Integration with GNN Models
Using with Pre-built Models
from jraphx.nn.models import GCN
import flax.nnx as nnx

# GCN with layer normalization
model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    norm="layer_norm",  # or "batch_norm", "graph_norm"
    rngs=nnx.Rngs(0),
)
Custom Model Integration
from jraphx.nn.conv import GCNConv
from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx
class CustomGNN(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.conv = GCNConv(in_features, 64, rngs=rngs)
        self.norm = GraphNorm(64, rngs=rngs)
        self.linear = nnx.Linear(64, out_features, rngs=rngs)

    def __call__(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)
        x = self.norm(x, batch)
        x = nnx.relu(x)
        return self.linear(x)
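A minimal usage sketch for the model above, assuming edge_index follows the usual [2, num_edges] layout with source and target rows:

import jax.numpy as jnp

model = CustomGNN(in_features=16, out_features=10, rngs=nnx.Rngs(0))

x = jnp.ones((5, 16))                    # 5 nodes, 16 features each
edge_index = jnp.array([[0, 1, 2, 3],    # source nodes
                        [1, 2, 3, 4]])   # target nodes
out = model(x, edge_index)               # expected shape: (5, 10)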