jraphx.nn.norm
This module contains normalization layers specifically designed for graph neural networks.
Normalization Layers
BatchNorm
- class BatchNorm(*args: Any, **kwargs: Any)[source]
Bases:
Module
Applies batch normalization over a batch of node features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]
The mean and standard deviation are calculated per-dimension over all nodes inside the mini-batch.
- Parameters:
num_features (int) – Size of each input sample.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – Decay rate for the exponential moving average of the batch statistics. Higher values mean slower adaptation (more weight on past values). (default: 0.99)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance; if set to False, it does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
use_running_average (bool, optional) – If set to True, use running statistics instead of batch statistics during evaluation. (default: False)
axis (int, optional) – The feature or non-batch axis of the input. (default: -1)
dtype (Dtype, optional) – The dtype of the result. (default: None, inferred from input and params)
param_dtype (Dtype, optional) – The dtype passed to parameter initializers. (default: jax.numpy.float32)
use_bias (bool, optional) – If True, bias (beta) is added. (default: True)
use_scale (bool, optional) – If True, multiply by scale (gamma). (default: True)
bias_init (Initializer, optional) – Initializer for bias. (default: zeros)
scale_init (Initializer, optional) – Initializer for scale. (default: ones)
axis_name (str, optional) – The axis name used to combine batch statistics from multiple devices. (default: None)
axis_index_groups (optional) – Groups of axis indices within that named axis. (default: None)
use_fast_variance (bool, optional) – If True, use a faster but less numerically stable variance calculation. (default: True)
rngs (Rngs, optional) – Random number generators for initialization. (default: None)
Batch normalization layer with running statistics for graph data.
Key Features:
Maintains running mean and variance for inference
Configurable momentum and epsilon parameters
Supports both training and evaluation modes
Compatible with graph batching
Example:
from jraphx.nn.norm import BatchNorm
import flax.nnx as nnx

# Create batch norm layer
norm = BatchNorm(
    num_features=64,
    eps=1e-5,
    momentum=0.99,
    track_running_stats=True,
    rngs=nnx.Rngs(0),
)

# Apply normalization
x_normalized = norm(x)  # Training mode

# Switch to eval mode
norm.eval()
x_eval = norm(x)  # Uses running statistics
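For intuition, the per-feature statistics described by the formula above can be written out directly in jax.numpy. This is only a conceptual sketch of what the layer computes in training mode, not jraphx's implementation; running statistics and optional flags are omitted, and the helper name is hypothetical.

import jax.numpy as jnp

def batch_norm_reference(x, gamma, beta, eps=1e-5):
    # x: [num_nodes, num_features] node features for the whole mini-batch
    mean = x.mean(axis=0)                     # E[x] per feature, over all nodes
    var = x.var(axis=0)                       # Var[x] per feature, over all nodes
    x_hat = (x - mean) / jnp.sqrt(var + eps)  # normalize per feature
    return x_hat * gamma + beta               # affine transform; gamma, beta: [num_features]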
LayerNorm
- class LayerNorm(*args: Any, **kwargs: Any)[source]
Bases:
Module
Applies layer normalization over each individual example in a batch of node features as described in the “Layer Normalization” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} \odot \gamma + \beta\]
The mean and standard deviation are calculated across all nodes and all node channels separately for each object in a mini-batch.
- Parameters:
num_features (int or list) – Size of each input sample, or list of dimensions to normalize.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
elementwise_affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
mode (str, optional) – The normalization mode to use for layer normalization ("graph" or "node"). If "graph" is used, each graph will be considered as an element to be normalized. If "node" is used, each node will be considered as an element to be normalized. (default: "node")
dtype (Dtype, optional) – The dtype of the result. (default: None, inferred from input and params)
param_dtype (Dtype, optional) – The dtype passed to parameter initializers. (default: jax.numpy.float32)
use_bias (bool, optional) – If True, bias (beta) is added. (default: True)
use_scale (bool, optional) – If True, multiply by scale (gamma). (default: True)
bias_init (Initializer, optional) – Initializer for bias. (default: zeros)
scale_init (Initializer, optional) – Initializer for scale. (default: ones)
reduction_axes (int or Sequence[int], optional) – Axes for computing normalization statistics. (default: -1)
feature_axes (int or Sequence[int], optional) – Feature axes for learned bias and scaling. (default: -1)
axis_name (str, optional) – The axis name used to combine batch statistics from multiple devices. (default: None)
axis_index_groups (optional) – Groups of axis indices within that named axis. (default: None)
use_fast_variance (bool, optional) – If True, use a faster but less numerically stable variance calculation. (default: True)
rngs (Rngs, optional) – Random number generators for initialization. (default: None)
Layer normalization for graph neural networks with node-wise or graph-wise modes.
Normalization Modes:
node: Normalize across feature dimensions for each node independently
graph: Normalize across all nodes and features in a graph
Example:
from jraphx.nn.norm import LayerNorm
import flax.nnx as nnx

# Node-wise normalization
norm = LayerNorm(
    num_features=64,
    mode="node",
    eps=1e-5,
    elementwise_affine=True,
    rngs=nnx.Rngs(0),
)
x_normalized = norm(x)

# Graph-wise normalization (requires batch index)
norm_graph = LayerNorm(
    num_features=64,
    mode="graph",
    rngs=nnx.Rngs(0),
)
x_normalized = norm_graph(x, batch=batch)
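To make the "graph" mode concrete, the per-graph statistics can be sketched with segment operations over the batch index. This is a conceptual reference under the assumption of a [num_nodes, num_features] feature matrix and an integer batch vector; the affine parameters are omitted and the helper name is hypothetical.

import jax.numpy as jnp
from jax.ops import segment_sum

def graph_layer_norm_reference(x, batch, num_graphs, eps=1e-5):
    # "graph" mode: one mean/variance per graph, computed over all nodes and channels
    num_nodes, num_features = x.shape
    count = segment_sum(jnp.ones(num_nodes), batch, num_segments=num_graphs) * num_features
    mean = segment_sum(x.sum(axis=1), batch, num_segments=num_graphs) / count
    centered = x - mean[batch][:, None]
    var = segment_sum((centered ** 2).sum(axis=1), batch, num_segments=num_graphs) / count
    return centered / jnp.sqrt(var + eps)[batch][:, None]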
GraphNorm
- class GraphNorm(*args: Any, **kwargs: Any)[source]
Bases:
Module
Applies graph normalization over individual graphs as described in the “GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training” paper.
\[\mathbf{x}^{\prime}_i = \frac{\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x} - \alpha \odot \textrm{E}[\mathbf{x}]] + \epsilon}} \odot \gamma + \beta\]
where \(\alpha\) denotes parameters that learn how much information to keep in the mean.
- Parameters:
num_features (int) – Size of each input sample.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
rngs (Rngs, optional) – Random number generators for initialization. (default: None)
Graph normalization layer that normalizes node features across the graph structure.
Algorithm:
Compute mean and variance per graph
Normalize features within each graph
Apply learnable affine transformation
Example:
from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx

norm = GraphNorm(
    num_features=64,
    eps=1e-5,
    rngs=nnx.Rngs(0),
)

# For batched graphs
x_normalized = norm(x, batch=batch)

# For single graph (batch=None)
x_normalized = norm(x)
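The formula above can also be sketched with segment operations: per-graph, per-feature statistics plus the learnable mean scale \(\alpha\). This is a conceptual reference, not the layer's actual implementation; alpha, gamma, and beta stand in for the learnable [num_features] parameters and the helper name is hypothetical.

import jax.numpy as jnp
from jax.ops import segment_sum

def graph_norm_reference(x, batch, num_graphs, alpha, gamma, beta, eps=1e-5):
    # Per-graph, per-feature statistics with a learnable mean scale alpha
    count = segment_sum(jnp.ones(x.shape[0]), batch, num_segments=num_graphs)[:, None]
    mean = segment_sum(x, batch, num_segments=num_graphs) / count   # E[x] per graph and feature
    shifted = x - alpha * mean[batch]                               # x - alpha * E[x]
    var = segment_sum(shifted ** 2, batch, num_segments=num_graphs) / count
    return shifted / jnp.sqrt(var[batch] + eps) * gamma + beta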
Normalization Selection Guide
When to Use Each Normalization
- BatchNorm:
Best with large, consistent batch sizes
Effective for deep networks
Requires sufficient batch statistics
Good for training stability
- LayerNorm:
Works well with variable batch sizes
Effective for attention-based models
Node mode: Best for heterogeneous graphs
Graph mode: Best for homogeneous graphs
- GraphNorm:
Specifically designed for graph data
Handles varying graph sizes well
Good for graph-level tasks
Robust to batch size variations
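As a rough illustration of these guidelines, a model could pick its normalization layer from a configuration string. The make_norm helper below is hypothetical (jraphx's pre-built models accept a similar norm argument, shown later on this page), and it assumes num_features can be passed positionally.

import flax.nnx as nnx
from jraphx.nn.norm import BatchNorm, LayerNorm, GraphNorm

def make_norm(kind: str, num_features: int, rngs: nnx.Rngs):
    # Hypothetical helper mapping a config string to a normalization layer
    if kind == "batch":   # large, consistent batch sizes
        return BatchNorm(num_features, rngs=rngs)
    if kind == "layer":   # variable batch sizes, attention-based models
        return LayerNorm(num_features, mode="node", rngs=rngs)
    if kind == "graph":   # graph-level tasks, varying graph sizes
        return GraphNorm(num_features, rngs=rngs)
    raise ValueError(f"Unknown normalization: {kind}")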
Performance Comparison
| Normalization | Batch Size Sensitivity | Graph Size Sensitivity | Best Use Case |
|---|---|---|---|
| BatchNorm | High | Low | Large batch training |
| LayerNorm (node) | None | None | Node-level tasks |
| LayerNorm (graph) | None | Medium | Small graphs |
| GraphNorm | Low | Low | Graph-level tasks |
Implementation Details
Running Statistics
BatchNorm maintains running statistics for inference:
# During training
norm.train() # Use batch statistics
output = norm(x)
# During inference
norm.eval() # Use running statistics
output = norm(x)
Affine Transformations
All normalization layers support learnable affine parameters:
# With affine transformation (default)
norm = LayerNorm(64, elementwise_affine=True)
# Learns gamma and beta parameters
# Without affine transformation
norm = LayerNorm(64, elementwise_affine=False)
# Pure normalization only
Numerical Stability
All layers use epsilon for numerical stability:
# Adjust epsilon for precision
norm = GraphNorm(64, eps=1e-5) # Default
norm = GraphNorm(64, eps=1e-8) # Higher precision
Integration with GNN Models
Using with Pre-built Models
from jraphx.nn.models import GCN
import flax.nnx as nnx

# GCN with layer normalization
model = GCN(
    in_features=16,
    hidden_features=64,
    num_layers=3,
    out_features=10,
    norm="layer_norm",  # or "batch_norm", "graph_norm"
    rngs=nnx.Rngs(0),
)
Custom Model Integration
from jraphx.nn.conv import GCNConv
from jraphx.nn.norm import GraphNorm
import flax.nnx as nnx

class CustomGNN(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.conv = GCNConv(in_features, 64, rngs=rngs)
        self.norm = GraphNorm(64, rngs=rngs)
        self.linear = nnx.Linear(64, out_features, rngs=rngs)

    def __call__(self, x, edge_index, batch=None):
        x = self.conv(x, edge_index)
        x = self.norm(x, batch)
        x = nnx.relu(x)
        return self.linear(x)
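A quick usage sketch for the module above, assuming the usual [2, num_edges] edge_index layout and toy inputs invented for illustration:

import jax.numpy as jnp
import flax.nnx as nnx

model = CustomGNN(in_features=16, out_features=10, rngs=nnx.Rngs(0))

x = jnp.ones((4, 16))                   # 4 nodes with 16 features each
edge_index = jnp.array([[0, 1, 2, 3],
                        [1, 0, 3, 2]])  # [2, num_edges] COO connectivity
batch = jnp.array([0, 0, 1, 1])         # two graphs with two nodes each
out = model(x, edge_index, batch)       # per-node outputs, shape [4, 10]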