jraphx.utils

Utility functions and operations for graph processing and manipulation.

Scatter Operations

The scatter module provides efficient implementations of scatter operations for aggregating node features.

scatter_sum

scatter_mean

scatter_mean(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array

Scatter mean operation: averages the values of src that share the same entry in index.

Parameters:
  • src (Array) – Source tensor to scatter

  • index (Array) – Indices where to scatter

  • dim_size (Optional[int], default: None) – Size of the output dimension

  • dim (int, default: -2) – Dimension along which to scatter

Returns:

Array – Tensor with scattered values

Scatter mean operation for averaging values by index.

Example:

from jraphx.utils.scatter import scatter_mean
import jax.numpy as jnp

src = jnp.array([1.0, 2.0, 3.0, 4.0])
index = jnp.array([0, 0, 1, 1])

# Average values by index
out = scatter_mean(src, index, dim_size=2)
# Result: [1.5, 3.5]
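The mean reduction can be written with JAX's `jax.ops.segment_sum`: sum the values that share an index, count how many entries land in each slot, and divide. The sketch below is not the jraphx implementation. `scatter_mean_sketch` is a hypothetical name, it assumes a 1-D `index` scattering along the first axis of `src`, and returning 0 for empty slots (rather than NaN) is an assumed convention.

```python
import jax
import jax.numpy as jnp


def scatter_mean_sketch(src, index, dim_size):
    """Hypothetical helper: average rows of `src` that share an `index` value (axis 0)."""
    # Per-slot sums of the source values.
    sums = jax.ops.segment_sum(src, index, num_segments=dim_size)
    # Per-slot counts: scatter a 1 for every entry of `index`.
    counts = jax.ops.segment_sum(
        jnp.ones_like(index, dtype=src.dtype), index, num_segments=dim_size
    )
    # Clamp counts to at least 1 so empty slots yield 0 instead of NaN (assumption).
    counts = jnp.maximum(counts, 1)
    # Broadcast counts over any trailing feature dimensions.
    counts = counts.reshape((dim_size,) + (1,) * (src.ndim - 1))
    return sums / counts


src = jnp.array([1.0, 2.0, 3.0, 4.0])
index = jnp.array([0, 0, 1, 1])
out = scatter_mean_sketch(src, index, dim_size=2)  # [1.5, 3.5]
```

The same function works for node-feature matrices of shape [N, F], since the counts are reshaped to broadcast over the feature axis.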

scatter_max

scatter_max(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array

Scatter max operation: takes the maximum of the values of src that share the same entry in index.

Parameters:
  • src (Array) – Source tensor to scatter

  • index (Array) – Indices where to scatter

  • dim_size (Optional[int], default: None) – Size of the output dimension

  • dim (int, default: -2) – Dimension along which to scatter

Returns:

Array – Tensor with scattered values

Scatter max operation for finding maximum values by index.

Example:

from jraphx.utils.scatter import scatter_max
import jax.numpy as jnp

src = jnp.array([1.0, 3.0, 2.0, 4.0])
index = jnp.array([0, 0, 1, 1])

# Find max values by index
out = scatter_max(src, index, dim_size=2)
# Result: [3.0, 4.0]
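A per-index maximum maps onto `jax.ops.segment_max`. One detail worth handling explicitly is empty slots: JAX leaves them at the reduction identity (-inf for floats). The sketch below, using the hypothetical name `scatter_max_sketch`, overwrites empty slots with 0; that fill value is an assumed convention, not something the jraphx documentation states.

```python
import jax
import jax.numpy as jnp


def scatter_max_sketch(src, index, dim_size):
    """Hypothetical helper: per-slot maximum along axis 0, empty slots set to 0."""
    out = jax.ops.segment_max(src, index, num_segments=dim_size)
    # Count entries per slot so slots that received nothing can be detected.
    counts = jax.ops.segment_sum(jnp.ones_like(index), index, num_segments=dim_size)
    empty = (counts == 0).reshape((dim_size,) + (1,) * (src.ndim - 1))
    # Replace the identity value left in empty slots with 0 (assumed convention).
    return jnp.where(empty, jnp.zeros_like(out), out)


src = jnp.array([1.0, 3.0, 2.0, 4.0])
index = jnp.array([0, 0, 1, 1])
out = scatter_max_sketch(src, index, dim_size=2)  # [3.0, 4.0]
```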

scatter_min

scatter_min(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array

Scatter min operation: takes the minimum of the values of src that share the same entry in index.

Parameters:
  • src (Array) – Source tensor to scatter

  • index (Array) – Indices where to scatter

  • dim_size (Optional[int], default: None) – Size of the output dimension

  • dim (int, default: -2) – Dimension along which to scatter

Returns:

Array – Tensor with scattered values

Scatter min operation for finding minimum values by index.
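No example is given for scatter_min. As a reference for the expected result, the snippet below computes the per-index minimum of the scatter_max example data with JAX's own `jax.ops.segment_min`; that scatter_min agrees with it on non-empty slots is an assumption. It also checks the identity min(x) = -max(-x), which lets a max kernel be reused for minima.

```python
import jax
import jax.numpy as jnp

src = jnp.array([1.0, 3.0, 2.0, 4.0])
index = jnp.array([0, 0, 1, 1])

# Direct per-slot minimum.
out_min = jax.ops.segment_min(src, index, num_segments=2)  # [1.0, 2.0]

# Same result via the identity min(x) = -max(-x).
out_via_max = -jax.ops.segment_max(-src, index, num_segments=2)
```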

scatter_std

scatter_std(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array

Scatter standard deviation: computes the standard deviation of the values of src that share the same entry in index.

Uses the formula \(\mathrm{std} = \sqrt{\mathbb{E}[X^2] - \mathbb{E}[X]^2}\), i.e. the population (biased) standard deviation.

Parameters:
  • src (Array) – Source tensor to scatter

  • index (Array) – Indices where to scatter

  • dim_size (Optional[int], default: None) – Size of the output dimension

  • dim (int, default: -2) – Dimension along which to scatter

Returns:

Array – Tensor with scattered standard deviations

Scatter standard deviation operation.
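The stated formula needs only two segment sums, one of the values and one of their squares, plus the per-slot counts. The sketch below uses the hypothetical name `scatter_std_sketch` and scatters along axis 0. It clamps the variance at zero because floating-point rounding can make \(\mathbb{E}[X^2] - \mathbb{E}[X]^2\) slightly negative; whether jraphx does the same is an assumption.

```python
import jax
import jax.numpy as jnp


def scatter_std_sketch(src, index, dim_size):
    """Hypothetical helper: population std per slot, std = sqrt(E[X^2] - E[X]^2)."""
    ones = jnp.ones_like(index, dtype=src.dtype)
    counts = jnp.maximum(jax.ops.segment_sum(ones, index, num_segments=dim_size), 1)
    counts = counts.reshape((dim_size,) + (1,) * (src.ndim - 1))
    # E[X] and E[X^2] per slot.
    mean = jax.ops.segment_sum(src, index, num_segments=dim_size) / counts
    mean_sq = jax.ops.segment_sum(src * src, index, num_segments=dim_size) / counts
    # Rounding can push the difference slightly below zero; clamp before sqrt.
    var = jnp.maximum(mean_sq - mean * mean, 0.0)
    return jnp.sqrt(var)


src = jnp.array([1.0, 3.0, 2.0, 2.0])
index = jnp.array([0, 0, 1, 1])
out = scatter_std_sketch(src, index, dim_size=2)  # [1.0, 0.0]
```

This one-pass form is cheap, but it can lose precision when the values are large relative to their spread; a two-pass variant that subtracts the per-slot mean first is more robust.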

scatter

scatter(src: Array, index: Array, dim_size: int | None = None, dim: int = -2, reduce: str = 'add') -> Array

Generic scatter operation using JAX’s optimized segment operations.

This function scatters values from the src tensor into the positions given by the index tensor, combining values that share an index with the chosen reduction. It relies on JAX's built-in segment operations, which are optimized by XLA for GPU/TPU.

Parameters:
  • src (Array) – Source tensor to scatter [*, N, *]

  • index (Array) – Indices where to scatter [N] or same shape as src

  • dim_size (Optional[int], default: None) – Size of the output dimension (inferred if None)

  • dim (int, default: -2) – Dimension along which to scatter (default: -2, which maps to 0)

  • reduce (str, default: 'add') – Reduction operation - “add”, “mean”, “max”, “min”

Returns:

Array – Output tensor with scattered values [*, dim_size, *]

Generic scatter operation with configurable reduction.

Parameters:

  • src: Source tensor to scatter

  • index: Index tensor for scattering

  • dim: Dimension to scatter along

  • dim_size: Size of the output dimension

  • reduce: Reduction operation (‘sum’, ‘mean’, ‘max’, ‘min’, ‘mul’)
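A generic scatter can be sketched as a dispatch from the reduce name to a JAX segment operation. The sketch below is hypothetical: `scatter_sketch` and its dispatch table are illustrative, the exact set of accepted names in jraphx may differ (the two parameter lists above disagree on "add" versus "sum" and on "mul"), and it assumes scattering along the leading node axis, which is what the default dim=-2 means for an [N, F] feature matrix.

```python
import jax
import jax.numpy as jnp

# Hypothetical dispatch table; accepts both "add" and "sum" as spellings of summation.
_SEGMENT_OPS = {
    "add": jax.ops.segment_sum,
    "sum": jax.ops.segment_sum,
    "max": jax.ops.segment_max,
    "min": jax.ops.segment_min,
    "mul": jax.ops.segment_prod,
}


def scatter_sketch(src, index, dim_size=None, reduce="add"):
    """Scatter rows of `src` (axis 0) into `dim_size` slots using `reduce`."""
    if dim_size is None:
        # Inferring the size needs a concrete index, so this branch is not jit-safe.
        dim_size = int(index.max()) + 1
    if reduce == "mean":
        sums = jax.ops.segment_sum(src, index, num_segments=dim_size)
        ones = jnp.ones_like(index, dtype=src.dtype)
        counts = jax.ops.segment_sum(ones, index, num_segments=dim_size)
        counts = jnp.maximum(counts, 1).reshape((dim_size,) + (1,) * (src.ndim - 1))
        return sums / counts
    if reduce not in _SEGMENT_OPS:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    return _SEGMENT_OPS[reduce](src, index, num_segments=dim_size)


src = jnp.array([1.0, 2.0, 3.0, 4.0])
index = jnp.array([0, 0, 1, 1])
summed = scatter_sketch(src, index)                 # [3.0, 7.0]
averaged = scatter_sketch(src, index, reduce="mean")  # [1.5, 3.5]
```

Passing dim_size explicitly matters under jax.jit, because the number of output slots must be a static Python integer there.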

Graph Utilities

degree

degree(index: Array, num_nodes: int | None = None, dtype: numpy.dtype | None = None) -> Array

Computes the (unweighted) degree of a given one-dimensional index tensor.

Parameters:
  • index (jax.Array) – Index tensor.

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of index. (default: None)

  • dtype (jax.dtype, optional) – The desired data type of the returned tensor.

Returns:

Node degrees.

Return type:

jax.Array

Example

>>> import jax.numpy as jnp
>>> from jraphx.utils import degree
>>> row = jnp.array([0, 1, 0, 2, 0])
>>> degree(row, dtype=jnp.int32)
Array([3, 1, 1], dtype=int32)

Compute the degree of each node in a graph.

Example:

from jraphx.utils import degree
import jax.numpy as jnp

edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])

# Compute in-degree and out-degree
in_deg = degree(edge_index[1], num_nodes=3)
out_deg = degree(edge_index[0], num_nodes=3)
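An unweighted degree is a segment sum of ones: every occurrence of a node id in the index row contributes 1 to that node. The sketch below uses the hypothetical name `degree_sketch`; its float32 default dtype is an assumption, since the documentation does not state what degree returns when dtype is omitted.

```python
import jax
import jax.numpy as jnp


def degree_sketch(index, num_nodes, dtype=jnp.float32):
    """Hypothetical helper: count how often each node id appears in `index`."""
    ones = jnp.ones(index.shape[0], dtype=dtype)
    # num_nodes must be a static int under jit, as segment_sum needs a fixed output size.
    return jax.ops.segment_sum(ones, index, num_segments=num_nodes)


edge_index = jnp.array([[0, 0, 1, 2], [1, 2, 2, 0]])
out_deg = degree_sketch(edge_index[0], num_nodes=3)  # [2., 1., 1.]
in_deg = degree_sketch(edge_index[1], num_nodes=3)   # [1., 1., 2.]
```

Row 0 of edge_index holds the source nodes, so counting it gives out-degrees; row 1 holds the targets and gives in-degrees.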

to_undirected

to_undirected(edge_index: Array, edge_attr: jax.Array | None = None, num_nodes: int | None = None, reduce: str = 'add') -> tuple[jax.Array, jax.Array | None]

Converts the graph given by edge_index to an undirected graph such that \((j,i) \in \mathcal{E}\) for every edge \((i,j) \in \mathcal{E}\).

Parameters:
  • edge_index (jax.Array) – The edge indices.

  • edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)

  • reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max"). (default: "add")

Returns:

tuple[Array, Optional[Array]] – Tuple of (undirected edge_index, undirected edge_attr).

Convert a directed graph to undirected by adding reverse edges.

Example:

from jraphx.utils import to_undirected
import jax.numpy as jnp

edge_index = jnp.array([[0, 1], [1, 2]])
edge_attr = jnp.array([[1.0], [2.0]])

# Convert to undirected
edge_index_undirected, edge_attr_undirected = to_undirected(
    edge_index, edge_attr
)
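Making a graph undirected amounts to appending every edge reversed and then merging the duplicates that this creates. The sketch below, under the hypothetical name `to_undirected_sketch`, merges by encoding each edge (i, j) as the integer key i * num_nodes + j and grouping with `jnp.unique`; that encoding is an illustrative technique, not necessarily how jraphx does it. It implements only reduce="add" and runs eagerly, because `jnp.unique` has a data-dependent output size unless a static `size=` is passed.

```python
import jax
import jax.numpy as jnp


def to_undirected_sketch(edge_index, edge_attr, num_nodes):
    """Hypothetical helper: add reversed edges, then sum attributes of duplicates."""
    row, col = edge_index
    # Every (i, j) gets a partner (j, i); the attributes are duplicated to match.
    row_all = jnp.concatenate([row, col])
    col_all = jnp.concatenate([col, row])
    attr_all = jnp.concatenate([edge_attr, edge_attr], axis=0)
    # Encode each edge as one integer; sorted keys are in row-major edge order.
    key = row_all * num_nodes + col_all
    unique_key, inverse = jnp.unique(key, return_inverse=True)
    # Sum the attributes of edges that share a key (reduce="add").
    merged = jax.ops.segment_sum(
        attr_all, inverse.reshape(-1), num_segments=unique_key.shape[0]
    )
    new_index = jnp.stack([unique_key // num_nodes, unique_key % num_nodes])
    return new_index, merged


edge_index = jnp.array([[0, 1], [1, 2]])
edge_attr = jnp.array([[1.0], [2.0]])
und_index, und_attr = to_undirected_sketch(edge_index, edge_attr, num_nodes=3)
# und_index: [[0, 1, 1, 2], [1, 0, 2, 1]]; und_attr: [[1.0], [1.0], [2.0], [2.0]]
```

If both directions of an edge are already present, their attributes are summed, which is what reduce="add" asks for.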

add_self_loops

add_self_loops(edge_index: Array, edge_attr: jax.Array | None = None, fill_value: Union[float, str] = 1.0, num_nodes: int | None = None) -> tuple[jax.Array, jax.Array | None]

Adds a self-loop \((i,i) \in \mathcal{E}\) to every node \(i \in \mathcal{V}\) in the graph given by edge_index. In case the graph is weighted and already contains self-loops, only non-existent self-loops will be added with edge weights denoted by fill_value.

Parameters:
  • edge_index (jax.Array) – The edge indices.

  • edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)

  • fill_value (float or str, optional) – The way to generate edge features of self-loops. If float, edge features are set to this value. If str, edge features are computed by aggregating existing edge features that point to each node using the specified reduction (‘mean’, ‘add’, ‘max’, ‘min’). (default: 1.0)

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)

Returns:

tuple[Array, Optional[Array]] – Tuple of (edge_index with self-loops, edge_attr with self-loops).

Note

For JIT compatibility, num_nodes should be provided as a static integer when possible.

Add self-loop edges to a graph.

Example:

from jraphx.utils import add_self_loops
import jax.numpy as jnp

edge_index = jnp.array([[0, 1], [1, 2]])

# Add self-loops
edge_index_with_loops, edge_attr = add_self_loops(
    edge_index, num_nodes=3
)
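For a float fill_value, adding self-loops is a concatenation: append the pairs (i, i) for i in 0..num_nodes-1 and, when edge attributes exist, one constant feature row per loop. The sketch below, named `add_self_loops_sketch` for illustration, is a simplification under stated assumptions: it does not check for existing self-loops (the documented function adds only missing ones on weighted graphs) and does not support the string fill_value modes.

```python
import jax.numpy as jnp


def add_self_loops_sketch(edge_index, edge_attr=None, fill_value=1.0, num_nodes=None):
    """Hypothetical helper: append (i, i) for every node with constant loop features."""
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1  # needs concrete values; not jit-safe
    loop = jnp.arange(num_nodes, dtype=edge_index.dtype)
    new_index = jnp.concatenate([edge_index, jnp.stack([loop, loop])], axis=1)
    if edge_attr is None:
        return new_index, None
    # One constant feature row per self-loop, matching the trailing feature shape.
    loop_attr = jnp.full(
        (num_nodes,) + edge_attr.shape[1:], fill_value, dtype=edge_attr.dtype
    )
    return new_index, jnp.concatenate([edge_attr, loop_attr], axis=0)


edge_index = jnp.array([[0, 1], [1, 2]])
edge_index_with_loops, edge_attr = add_self_loops_sketch(edge_index, num_nodes=3)
# edge_index_with_loops: [[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]]; edge_attr: None
```

Passing num_nodes as a Python int, as the note above recommends, keeps the output shape static and the function traceable under jax.jit.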

remove_self_loops

remove_self_loops(edge_index: Array, edge_attr: jax.Array | None = None) -> tuple[jax.Array, jax.Array | None]

Remove self-loops from edge indices.

Parameters:
  • edge_index (Array) – Edge indices [2, num_edges]

  • edge_attr (Optional[Array], default: None) – Optional edge attributes [num_edges, *]

Returns:

tuple[Array, Optional[Array]] – Tuple of (edge_index without self-loops, edge_attr without self-loops)

Remove self-loop edges from a graph.
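Removing self-loops keeps exactly the edges whose source differs from their target. The sketch below uses the hypothetical name `remove_self_loops_sketch`. Because the number of surviving edges depends on the data, it selects them with `jnp.nonzero` in eager mode; under jax.jit one would instead pass a static `size=` to `jnp.nonzero` or keep all edges and mask them out.

```python
import jax.numpy as jnp


def remove_self_loops_sketch(edge_index, edge_attr=None):
    """Hypothetical helper: drop every edge whose source equals its target."""
    # Positions of edges that are not self-loops (data-dependent size, eager only).
    keep = jnp.nonzero(edge_index[0] != edge_index[1])[0]
    new_index = edge_index[:, keep]
    new_attr = None if edge_attr is None else edge_attr[keep]
    return new_index, new_attr


edge_index = jnp.array([[0, 1, 1, 2], [0, 2, 1, 0]])
edge_attr = jnp.array([10.0, 20.0, 30.0, 40.0])
clean_index, clean_attr = remove_self_loops_sketch(edge_index, edge_attr)
# clean_index: [[1, 2], [2, 0]]; clean_attr: [20.0, 40.0]
```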

contains_self_loops

is_undirected

coalesce

coalesce(edge_index: Array, edge_attr: jax.Array | None = None, num_nodes: int | None = None, reduce: str = 'add') -> tuple[jax.Array, jax.Array | None]

Row-wise sorts edge_index and removes its duplicated entries. Duplicate entries in edge_attr are merged by scattering them together according to the given reduce option.

Parameters:
  • edge_index (jax.Array) – The edge indices.

  • edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)

  • num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)

  • reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max"). (default: "add")

Returns:

tuple[Array, Optional[Array]] – Tuple of (coalesced edge_index, coalesced edge_attr).

Note

For JIT compatibility, num_nodes should be provided as a static integer when possible.

Remove duplicate edges and merge their attributes (summed by default).

Example:

from jraphx.utils import coalesce
import jax.numpy as jnp

# Graph with duplicate edges
edge_index = jnp.array([[0, 0, 1], [1, 1, 2]])
edge_attr = jnp.array([[1.0], [2.0], [3.0]])

# Remove duplicates and sum attributes ("add" is the documented reduce name)
edge_index_clean, edge_attr_clean = coalesce(
    edge_index, edge_attr, reduce='add'
)
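Coalescing can be sketched as: encode each edge (i, j) as the integer key i * num_nodes + j, take the sorted unique keys (which gives the row-wise sort), and reduce the attributes of edges that map to the same key. The sketch below uses the hypothetical name `coalesce_sketch` and supports only "add", "max" and "min"; "mean" is omitted for brevity, and this key-based approach is an illustration, not a description of jraphx internals. It runs eagerly because `jnp.unique` has a data-dependent output size.

```python
import jax
import jax.numpy as jnp


def coalesce_sketch(edge_index, edge_attr, num_nodes, reduce="add"):
    """Hypothetical helper: sort edges row-wise and merge duplicates with `reduce`."""
    ops = {"add": jax.ops.segment_sum, "max": jax.ops.segment_max, "min": jax.ops.segment_min}
    # Row-major key: sorting it orders edges by source, then by target.
    key = edge_index[0] * num_nodes + edge_index[1]
    unique_key, inverse = jnp.unique(key, return_inverse=True)
    merged = ops[reduce](edge_attr, inverse.reshape(-1), num_segments=unique_key.shape[0])
    new_index = jnp.stack([unique_key // num_nodes, unique_key % num_nodes])
    return new_index, merged


edge_index = jnp.array([[0, 0, 1], [1, 1, 2]])
edge_attr = jnp.array([[1.0], [2.0], [3.0]])
idx_sum, attr_sum = coalesce_sketch(edge_index, edge_attr, num_nodes=3)
# idx_sum: [[0, 1], [1, 2]]; attr_sum: [[3.0], [3.0]]
```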

Conversion Utilities

to_dense_adj

to_dense_adj(edge_index: Array, edge_attr: jax.Array | None = None, max_num_nodes: int | None = None) -> Array

Convert edge indices to dense adjacency matrix.

Parameters:
  • edge_index (Array) – Edge indices [2, num_edges]

  • edge_attr (Optional[Array], default: None) – Optional edge attributes [num_edges] or [num_edges, num_features]

  • max_num_nodes (Optional[int], default: None) – Maximum number of nodes (for padding)

Returns:

Array – Dense adjacency matrix [num_nodes, num_nodes] or [num_nodes, num_nodes, num_features]

Convert edge indices to a dense adjacency matrix.

Example:

from jraphx.utils import to_dense_adj
import jax.numpy as jnp

edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])

# Convert to dense adjacency matrix
adj = to_dense_adj(edge_index, max_num_nodes=3)
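Building a dense adjacency is a scatter-add into a zero matrix, which JAX expresses with the `.at[...].add(...)` indexed-update syntax. The sketch below uses the hypothetical name `to_dense_adj_sketch`; using 1.0 as the weight when edge_attr is None, and accumulating duplicate edges rather than overwriting them, are assumptions about the intended semantics.

```python
import jax.numpy as jnp


def to_dense_adj_sketch(edge_index, edge_attr=None, max_num_nodes=None):
    """Hypothetical helper: scatter edges into an [N, N] or [N, N, F] matrix."""
    if max_num_nodes is None:
        max_num_nodes = int(edge_index.max()) + 1  # needs concrete values; not jit-safe
    if edge_attr is None:
        edge_attr = jnp.ones(edge_index.shape[1], dtype=jnp.float32)
    shape = (max_num_nodes, max_num_nodes) + edge_attr.shape[1:]
    adj = jnp.zeros(shape, dtype=edge_attr.dtype)
    # .at[row, col].add(...) sums contributions from repeated (row, col) pairs.
    return adj.at[edge_index[0], edge_index[1]].add(edge_attr)


edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])
adj = to_dense_adj_sketch(edge_index, max_num_nodes=3)
# [[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]
```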

to_edge_index

to_edge_index(adj: Array) -> tuple[jax.Array, jax.Array | None]

Convert adjacency matrix to edge indices.

Parameters:

adj (Array) – Adjacency matrix [num_nodes, num_nodes] or [num_nodes, num_nodes, num_features]

Returns:

tuple[Array, Optional[Array]] – Tuple of (edge_index [2, num_edges], edge_attr [num_edges] or [num_edges, num_features])

Convert adjacency representation to edge index format.
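The reverse conversion collects the coordinates of the non-zero entries with `jnp.nonzero`, which returns them in row-major order, and gathers the entries as edge attributes. The sketch below uses the hypothetical name `to_edge_index_sketch`. Treating a [N, N, F] entry as an edge when any of its features is non-zero is an assumption, and the data-dependent output size means it runs eagerly rather than under jax.jit.

```python
import jax.numpy as jnp


def to_edge_index_sketch(adj):
    """Hypothetical helper: (edge_index, edge_attr) for every non-zero adjacency entry."""
    # For [N, N, F] inputs, count an entry as an edge if any feature is non-zero.
    is_edge = (adj != 0) if adj.ndim == 2 else jnp.any(adj != 0, axis=-1)
    row, col = jnp.nonzero(is_edge)  # row-major order; data-dependent size
    edge_index = jnp.stack([row, col])
    edge_attr = adj[row, col]
    return edge_index, edge_attr


adj = jnp.array([[0.0, 2.0], [3.0, 0.0]])
edge_index, edge_attr = to_edge_index_sketch(adj)
# edge_index: [[0, 1], [1, 0]]; edge_attr: [2.0, 3.0]
```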