jraphx.nn.pool

This module contains pooling operations for graph neural networks, including global pooling and hierarchical pooling methods.

Global Pooling Operations

Global pooling aggregates node features across entire graphs, producing graph-level representations.
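
A small worked sketch of this batched convention (the values are only for illustration; the pooling functions themselves are documented below):

import jax.numpy as jnp
from jraphx.nn.pool import global_add_pool

# Two graphs packed into one node-feature matrix: graph 0 has 3 nodes, graph 1 has 2
x = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0],
               [7.0, 8.0],
               [9.0, 10.0]])
batch = jnp.array([0, 0, 0, 1, 1])  # batch vector: node index -> graph index

out = global_add_pool(x, batch)  # shape [2, 2]
# out[0] == [9., 12.]   (sum over graph 0's nodes)
# out[1] == [16., 18.]  (sum over graph 1's nodes)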

global_add_pool

global_add_pool(x: Array, batch: jax.Array | None = None, size: int | None = None) → Array

Returns batch-wise graph-level outputs by adding node features across the node dimension.

For a single graph \(\mathcal{G}_i\), its output is computed by

\[\mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n.\]
Parameters:
  • x (jax.Array) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (jax.Array, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)

Returns:

Graph-level features \(\mathbf{R} \in \mathbb{R}^{B \times F}\).

Return type:

jax.Array

Sum all node features in each graph.

Example:

from jraphx.nn.pool import global_add_pool

# For batched graphs
graph_features = global_add_pool(x, batch)
# Output shape: [num_graphs, features]

global_mean_pool

global_mean_pool(x: Array, batch: jax.Array | None = None, size: int | None = None) → Array

Returns batch-wise graph-level outputs by averaging node features across the node dimension.

For a single graph \(\mathcal{G}_i\), its output is computed by

\[\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.\]
Parameters:
  • x (jax.Array) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}\).

  • batch (jax.Array, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)

Returns:

Graph-level features \(\mathbf{R} \in \mathbb{R}^{B \times F}\).

Return type:

jax.Array

Average all node features in each graph.

Example:

from jraphx.nn.pool import global_mean_pool

# Most common for graph classification
graph_features = global_mean_pool(x, batch)

global_max_pool

global_max_pool(x: Array, batch: jax.Array | None = None, size: int | None = None) → Array

Optimized global max pooling over a batch of graphs.

Computes the maximum of node features for each graph in the batch.

Parameters:
  • x (Array) – Node features [num_nodes, num_features]

  • batch (Optional[Array], default: None) – Batch indices for each node [num_nodes]

  • size (Optional[int], default: None) – Number of graphs in the batch. If given, avoids computing it from the batch vector.

Returns:

Array – Graph-level features [batch_size, num_features]

Take element-wise maximum across all nodes in each graph.

Example:

from jraphx.nn.pool import global_max_pool

# Good for capturing dominant features
graph_features = global_max_pool(x, batch)

global_min_pool

global_min_pool(x: Array, batch: jax.Array | None = None, size: int | None = None) → Array

Optimized global min pooling over a batch of graphs.

Computes the minimum of node features for each graph in the batch.

Parameters:
  • x (Array) – Node features [num_nodes, num_features]

  • batch (Optional[Array], default: None) – Batch indices for each node [num_nodes]

  • size (Optional[int], default: None) – Number of graphs in the batch. If given, avoids computing it from the batch vector.

Returns:

Array – Graph-level features [batch_size, num_features]

Take element-wise minimum across all nodes in each graph.

Example:

from jraphx.nn.pool import global_min_pool

# Less common but useful for specific tasks
graph_features = global_min_pool(x, batch)

Hierarchical Pooling Layers

Hierarchical pooling layers select important nodes and create coarsened graph representations.

TopKPooling

class TopKPooling(*args: Any, **kwargs: Any)

Bases: Module

\(\mathrm{top}_k\) pooling operator from the “Graph U-Nets”, “Towards Sparse Hierarchical Graph Classifiers” and “Understanding Attention and Generalization in Graph Neural Networks” papers.

If min_score \(\tilde{\alpha}\) is None, computes:

\[\begin{aligned}
\mathbf{y} &= \sigma \left( \frac{\mathbf{X}\mathbf{p}}{\| \mathbf{p} \|} \right)\\
\mathbf{i} &= \mathrm{top}_k(\mathbf{y})\\
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}\\
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}
\end{aligned}\]

If min_score \(\tilde{\alpha}\) is a value in [0, 1], computes:

\[\begin{aligned}
\mathbf{y} &= \mathrm{softmax}(\mathbf{X}\mathbf{p})\\
\mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}\\
\mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}\\
\mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},
\end{aligned}\]

where nodes are dropped based on a learnable projection score \(\mathbf{p}\).

Parameters:
  • num_features (int) – Size of each input sample.

  • ratio (float or int, optional) – The graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), or the value of \(k\) itself, depending on whether the type of ratio is float or int. This value is ignored if min_score is not None. (default: 0.5)

  • min_score (float, optional) – Minimal node score \(\tilde{\alpha}\) which is used to compute indices of pooled nodes \(\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}\). When this value is not None, the ratio argument is ignored. (default: None)

  • multiplier (float, optional) – Coefficient by which the features are multiplied after pooling. (default: 1.0)

  • nonlinearity (str, optional) – The nonlinearity to use. (default: "tanh")

  • rngs (Optional[Rngs], default: None) – Random number generators for initialization.

Top-K pooling layer that selects the most important nodes based on learnable scores.

Algorithm:

  1. Compute node scores using a learnable projection

  2. Select top-k nodes based on scores

  3. Update node features by multiplying with scores

  4. Filter edges to maintain graph connectivity

Example:

from jraphx.nn.pool import TopKPooling
import flax.nnx as nnx

# Select top 50% of nodes
pool = TopKPooling(
    in_features=64,
    ratio=0.5,
    min_score=None,  # Optional minimum score threshold
    multiplier=1.0,  # Score multiplier
    rngs=nnx.Rngs(0)
)

# Apply pooling
x_pool, edge_index_pool, edge_attr_pool, batch_pool, perm = pool(
    x, edge_index, edge_attr=edge_attr, batch=batch
)

# perm contains indices of selected nodes

Parameters Explained (contrasted in the sketch after this list):

  • ratio: If < 1, fraction of nodes to keep; if >= 1, exact number of nodes

  • min_score: Minimum score threshold (nodes below are filtered)

  • multiplier: Multiply scores before selection (affects gradients)
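
A short sketch of the three selection modes described above (the constructor call style follows the earlier example and the defaults listed in the parameter table):

from jraphx.nn.pool import TopKPooling
import flax.nnx as nnx

# float ratio: keep a fraction of the nodes (here 50%)
pool_half = TopKPooling(64, ratio=0.5, rngs=nnx.Rngs(0))

# int ratio: keep exactly k nodes (here k = 10)
pool_k10 = TopKPooling(64, ratio=10, rngs=nnx.Rngs(0))

# min_score: keep every node whose score exceeds the threshold; ratio is ignored
pool_thresh = TopKPooling(64, min_score=0.2, rngs=nnx.Rngs(0))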

SAGPooling

class SAGPooling(*args: Any, **kwargs: Any)

Bases: TopKPooling

Self-Attention Graph Pooling layer.

From “Self-Attention Graph Pooling” (https://arxiv.org/abs/1904.08082)

An extension of TopKPooling that uses graph convolution to compute scores, making them aware of the graph structure.

Parameters:
  • num_features (int) – Number of input features

  • ratio (Union[float, int], default: 0.5) – Pooling ratio

  • gnn (str, default: 'gcn') – Type of GNN to use for score computation (‘gcn’, ‘gat’, ‘sage’)

  • min_score (Optional[float], default: None) – Minimum score threshold

  • multiplier (float, default: 1.0) – Score multiplier for features

  • nonlinearity (str, default: 'tanh') – Activation function

  • rngs (Optional[Rngs], default: None) – Random number generators

Self-attention graph pooling using GNN layers to compute importance scores.

Key Features:

  • Uses GNN layers (GCN, GAT, or SAGE) to compute scores

  • More expressive than simple projection

  • Can capture graph structure in scoring

Example:

from jraphx.nn.pool import SAGPooling
import flax.nnx as nnx

# SAGPooling with GCN scoring
pool = SAGPooling(
    in_features=64,
    ratio=0.5,
    gnn="gcn",  # Options: "gcn", "gat", "sage"
    min_score=None,
    multiplier=1.0,
    rngs=nnx.Rngs(0)
)

x_pool, edge_index_pool, edge_attr_pool, batch_pool, perm = pool(
    x, edge_index, edge_attr=edge_attr, batch=batch
)

# Using GAT for attention-based scoring
pool_gat = SAGPooling(
    in_features=64,
    ratio=0.3,
    gnn="gat",
    rngs=nnx.Rngs(0)
)

Pooling Strategies

Graph Classification Pipeline

from jraphx.nn.conv import GCNConv
from jraphx.nn.pool import TopKPooling, global_mean_pool
import flax.nnx as nnx

class GraphClassifier(nnx.Module):
    def __init__(self, in_features, num_classes, rngs):
        self.conv1 = GCNConv(in_features, 64, rngs=rngs)
        self.pool1 = TopKPooling(64, ratio=0.8, rngs=rngs)
        self.conv2 = GCNConv(64, 64, rngs=rngs)
        self.pool2 = TopKPooling(64, ratio=0.8, rngs=rngs)
        self.conv3 = GCNConv(64, 64, rngs=rngs)
        self.classifier = nnx.Linear(64, num_classes, rngs=rngs)
        self.dropout = nnx.Dropout(0.5, rngs=rngs)

    def __call__(self, x, edge_index, batch):
        # First GNN layer
        x = nnx.relu(self.conv1(x, edge_index))

        # First pooling
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, batch=batch)

        # Second GNN layer
        x = nnx.relu(self.conv2(x, edge_index))

        # Second pooling
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, batch=batch)

        # Third GNN layer
        x = nnx.relu(self.conv3(x, edge_index))

        # Global pooling
        x = global_mean_pool(x, batch)

        # Classification
        x = self.dropout(x)
        return self.classifier(x)
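
A minimal usage sketch (assuming x, edge_index, and batch come from a batched graph dataset with 16-dimensional node features; the numbers are placeholders):

model = GraphClassifier(in_features=16, num_classes=3, rngs=nnx.Rngs(0))
logits = model(x, edge_index, batch)  # shape: [num_graphs, num_classes]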

Multi-Scale Pooling

# Reuses the GCNConv, TopKPooling, global_mean_pool, and nnx imports from the previous example
import jax.numpy as jnp

class MultiScaleGNN(nnx.Module):
    def __init__(self, in_features, out_features, rngs):
        self.conv1 = GCNConv(in_features, 64, rngs=rngs)
        self.conv2 = GCNConv(64, 64, rngs=rngs)
        self.conv3 = GCNConv(64, 64, rngs=rngs)
        self.pool = TopKPooling(64, ratio=0.5, rngs=rngs)
        self.lin = nnx.Linear(192, out_features, rngs=rngs)

    def __call__(self, x, edge_index, batch):
        # Compute representations at multiple scales
        x1 = nnx.relu(self.conv1(x, edge_index))
        g1 = global_mean_pool(x1, batch)

        # Pool and compute second scale
        x2, edge_index2, _, batch2, _ = self.pool(x1, edge_index, batch=batch)
        x2 = nnx.relu(self.conv2(x2, edge_index2))
        g2 = global_mean_pool(x2, batch2)

        # Pool again and compute third scale
        x3, edge_index3, _, batch3, _ = self.pool(x2, edge_index2, batch=batch2)
        x3 = nnx.relu(self.conv3(x3, edge_index3))
        g3 = global_mean_pool(x3, batch3)

        # Concatenate multi-scale features
        out = jnp.concatenate([g1, g2, g3], axis=-1)
        return self.lin(out)
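
Note that self.lin takes 192 input features because three 64-dimensional readouts (g1, g2, g3) are concatenated. Usage mirrors the classifier above (same placeholder assumptions):

model = MultiScaleGNN(in_features=16, out_features=3, rngs=nnx.Rngs(0))
out = model(x, edge_index, batch)  # shape: [num_graphs, out_features]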

Pooling Selection Guide

Global Pooling

Global Pooling Methods

  Method      Properties           Advantages           Use Cases
  add_pool    Sum aggregation      Preserves magnitude  Counting tasks
  mean_pool   Average aggregation  Size invariant       Most common, stable
  max_pool    Maximum values       Captures peaks       Dominant features
  min_pool    Minimum values       Captures valleys     Outlier detection

Hierarchical Pooling

Hierarchical Pooling Methods

  Method             Scoring Method     Complexity  Best For
  TopKPooling        Linear projection  Low         Fast coarsening
  SAGPooling (GCN)   GCN layer          Medium      Structure-aware
  SAGPooling (GAT)   GAT layer          High        Attention-based
  SAGPooling (SAGE)  SAGE layer         Medium      Neighbor aggregation

Performance Considerations

Memory Efficiency

# Aggressive pooling for memory efficiency
pool = TopKPooling(64, ratio=0.1, rngs=nnx.Rngs(42))  # Keep only 10% of nodes

# Gradual pooling for better gradients
pool1 = TopKPooling(64, ratio=0.8, rngs=nnx.Rngs(42))  # First layer: 80%
pool2 = TopKPooling(64, ratio=0.6, rngs=nnx.Rngs(42))  # Second layer: 60%

Batch Processing

from jraphx.data import Batch
from jraphx.nn.pool import TopKPooling
import flax.nnx as nnx

# Efficient batched pooling
pool = TopKPooling(64, ratio=0.5, rngs=nnx.Rngs(0))
batch_data = Batch.from_data_list(graphs)

# Pool entire batch at once
pooled_x, pooled_edge_index, _, pooled_batch, perm = pool(
    batch_data.x,
    batch_data.edge_index,
    batch=batch_data.batch
)

JIT Compilation

import jax

@jax.jit
def pool_and_classify(x, edge_index, batch):
    # Pooling operations are JIT-compatible
    x_pool, edge_pool, _, batch_pool, _ = pool(x, edge_index, batch=batch)
    graph_features = global_mean_pool(x_pool, batch_pool)
    return classifier(graph_features)
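
Under jit the number of graphs cannot be inferred from a traced batch vector, so the size argument of the global pooling functions is the natural way to pin the output shape. A hedged sketch, assuming size behaves as documented above (it fixes the number of output rows):

from functools import partial

@partial(jax.jit, static_argnames="num_graphs")
def readout(x, batch, num_graphs):
    # A static size keeps the output shape [num_graphs, num_features] known at trace time
    return global_mean_pool(x, batch, size=num_graphs)

graph_features = readout(x, batch, num_graphs=32)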

Common Patterns

Differentiable Pooling

All pooling operations maintain differentiability:

def loss_fn(x, edge_index, batch, y):
    # Pooling stays inside the differentiated computation graph
    x_pool, edge_pool, _, batch_pool, _ = pool(x, edge_index, batch=batch)
    graph_rep = global_mean_pool(x_pool, batch_pool)
    pred = classifier(graph_rep)
    return jnp.mean((pred - y) ** 2)

# Gradients flow through pooling (here taken with respect to the node features x)
grads = jax.grad(loss_fn)(x, edge_index, batch, y)

Attention Visualization

# Get pooling scores for visualization
pool = TopKPooling(64, ratio=0.5, rngs=nnx.Rngs(42))
x_pool, _, _, _, perm = pool(x, edge_index)

# perm contains the indices of the selected (top-scoring) nodes,
# so we can mark which nodes survived pooling
num_nodes = x.shape[0]
selected_nodes = jnp.zeros(num_nodes)
selected_nodes = selected_nodes.at[perm].set(1.0)