jraphx.utils
Utility functions and operations for graph processing and manipulation.
Scatter Operations
The scatter module provides efficient implementations of scatter operations for aggregating node features.
scatter_sum
scatter_mean
- scatter_mean(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array
Scatter mean operation: averages the values from src that share the same entry in index.
- Parameters:
- Returns:
  Array – Tensor with scattered values
Scatter mean operation for averaging values by index.
Example:
```python
from jraphx.utils.scatter import scatter_mean
import jax.numpy as jnp

src = jnp.array([1.0, 2.0, 3.0, 4.0])
index = jnp.array([0, 0, 1, 1])

# Average values by index
out = scatter_mean(src, index, dim_size=2)
# Result: [1.5, 3.5]
```
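A scatter mean can be expressed as a per-index sum divided by a per-index count. The sketch below illustrates that idea with `jax.ops.segment_sum`, scattering along the first axis only; `scatter_mean_sketch` is a hypothetical helper, not part of jraphx, and returning 0 for indices that receive no values is an assumption that the library may not share.

```python
import jax
import jax.numpy as jnp


def scatter_mean_sketch(src, index, dim_size):
    """Hypothetical helper (not jraphx API): average rows of `src` grouped by `index`."""
    # Sum the values that land in each output slot.
    total = jax.ops.segment_sum(src, index, num_segments=dim_size)
    # Count how many entries landed in each slot.
    ones = jnp.ones_like(index, dtype=src.dtype)
    count = jax.ops.segment_sum(ones, index, num_segments=dim_size)
    # Empty slots would divide by zero; clamping the count to 1 makes them 0 (assumption).
    count = jnp.maximum(count, 1)
    # Broadcast the per-slot count over any trailing feature dimensions.
    return total / count.reshape((-1,) + (1,) * (src.ndim - 1))
```

With `dim_size=3` and the inputs from the example above, the third slot receives no values and comes out as 0.0 under this sketch.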
scatter_max
- scatter_max(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array
Scatter max operation: takes the maximum of the values from src that share the same entry in index.
- Parameters:
- Returns:
  Array – Tensor with scattered values
Scatter max operation for finding maximum values by index.
Example:
```python
from jraphx.utils.scatter import scatter_max
import jax.numpy as jnp

src = jnp.array([1.0, 3.0, 2.0, 4.0])
index = jnp.array([0, 0, 1, 1])

# Find max values by index
out = scatter_max(src, index, dim_size=2)
# Result: [3.0, 4.0]
```
scatter_min
- scatter_min(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array
Scatter min operation: takes the minimum of the values from src that share the same entry in index.
- Parameters:
- Returns:
  Array – Tensor with scattered values
Scatter min operation for finding minimum values by index.
scatter_std
- scatter_std(src: Array, index: Array, dim_size: int | None = None, dim: int = -2) -> Array
Scatter standard deviation: computes the standard deviation of the values from src that share the same entry in index.
Uses the formula \(\mathrm{std} = \sqrt{\mathbb{E}[X^2] - \mathbb{E}[X]^2}\).
- Parameters:
- Returns:
  Array – Tensor with scattered standard deviations
Scatter standard deviation operation.
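The formula above needs only two scatter sums per index: one of \(X\) and one of \(X^2\). Below is a minimal sketch of that computation with `jax.ops.segment_sum` along the first axis. `scatter_std_sketch` is a hypothetical helper, not the jraphx implementation. It assumes the population (divide-by-count) form the formula implies, and it clamps the variance at zero, since floating-point cancellation in \(\mathbb{E}[X^2] - \mathbb{E}[X]^2\) can produce tiny negative values.

```python
import jax
import jax.numpy as jnp


def scatter_std_sketch(src, index, dim_size):
    """Hypothetical helper (not jraphx API): per-slot std via sqrt(E[X^2] - E[X]^2)."""
    ones = jnp.ones_like(index, dtype=src.dtype)
    # Per-slot counts, clamped so empty slots do not divide by zero.
    count = jnp.maximum(jax.ops.segment_sum(ones, index, num_segments=dim_size), 1)
    count = count.reshape((-1,) + (1,) * (src.ndim - 1))
    mean = jax.ops.segment_sum(src, index, num_segments=dim_size) / count          # E[X]
    mean_sq = jax.ops.segment_sum(src ** 2, index, num_segments=dim_size) / count  # E[X^2]
    # Cancellation can push the difference slightly below 0; clamp before sqrt.
    var = jnp.maximum(mean_sq - mean ** 2, 0.0)
    return jnp.sqrt(var)
```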
scatter
- scatter(src: Array, index: Array, dim_size: int | None = None, dim: int = -2, reduce: str = 'add') -> Array
Generic scatter operation using JAX’s optimized segment operations.
This function scatters the values of the src tensor into the output positions given by the index tensor, combining values that land on the same position with the specified reduction. It is built on JAX’s segment operations, which XLA compiles into efficient kernels on GPU/TPU.
- Parameters:
  src (Array) – Source tensor to scatter [*, N, *]
  index (Array) – Indices where to scatter [N] or same shape as src
  dim_size (Optional[int], default: None) – Size of the output dimension (inferred if None)
  dim (int, default: -2) – Dimension along which to scatter (default: -2, which maps to 0)
  reduce (str, default: 'add') – Reduction operation: "add", "mean", "max", "min"
- Returns:
  Array – Output tensor with scattered values [*, dim_size, *]
Generic scatter operation with configurable reduction.
Parameters:
src: Source tensor to scatter
index: Index tensor for scattering
dim: Dimension to scatter along
dim_size: Size of the output dimension
reduce: Reduction operation (‘sum’, ‘mean’, ‘max’, ‘min’, ‘mul’)
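A generic scatter can be sketched as a dispatch from the reduce name to a JAX segment op, after moving the scatter dimension to the front (segment ops reduce over the leading axis). This sketch makes several assumptions. `scatter_sketch` and its lookup table are hypothetical, not the jraphx code. It accepts both "add" and "sum" because the lists above use both spellings. It defaults to `dim=0` rather than -2. It leaves out "mean", which can be built as a sum divided by a count.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from reduce names to JAX segment ops.
_SEGMENT_OPS = {
    "add": jax.ops.segment_sum,
    "sum": jax.ops.segment_sum,
    "max": jax.ops.segment_max,
    "min": jax.ops.segment_min,
    "mul": jax.ops.segment_prod,
}


def scatter_sketch(src, index, dim_size=None, dim=0, reduce="add"):
    """Hypothetical helper (not jraphx API): scatter `src` along `dim` using 1-D `index`."""
    if reduce not in _SEGMENT_OPS:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    if dim_size is None:
        # Inferring the size needs a concrete value, so this branch is not jit-friendly.
        dim_size = int(index.max()) + 1
    # Segment ops reduce over the leading axis, so move `dim` to the front.
    moved = jnp.moveaxis(src, dim, 0)
    out = _SEGMENT_OPS[reduce](moved, index, num_segments=dim_size)
    # Put the scattered dimension back in its original position.
    return jnp.moveaxis(out, 0, dim)
```

Passing `dim_size` explicitly keeps the output shape static, which matters when the function is traced by `jax.jit`.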
Graph Utilities
degree
- degree(index: Array, num_nodes: int | None = None, dtype: numpy.dtype | None = None) -> Array
Computes the (unweighted) degree of a given one-dimensional index tensor.
- Parameters:
- Returns:
  Node degrees.
- Return type:
  jax.Array
Example:
```python
>>> import jax.numpy as jnp
>>> from jraphx.utils import degree
>>> row = jnp.array([0, 1, 0, 2, 0])
>>> degree(row, dtype=jnp.int32)
Array([3, 1, 1], dtype=int32)
```
Compute the degree of each node in a graph.
Example:
```python
from jraphx.utils import degree
import jax.numpy as jnp

edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])

# Compute in-degree (from targets) and out-degree (from sources)
in_deg = degree(edge_index[1], num_nodes=3)
out_deg = degree(edge_index[0], num_nodes=3)
```
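An unweighted degree is just the number of times each node id occurs in the index tensor. A minimal sketch of that count uses `jnp.bincount`, where a fixed `length` gives a static output shape. `degree_sketch` is a hypothetical helper, not the jraphx implementation. Leaving the result in bincount's integer dtype when `dtype` is None is an assumption about the default.

```python
import jax.numpy as jnp


def degree_sketch(index, num_nodes=None, dtype=None):
    """Hypothetical helper (not jraphx API): count occurrences of each node id."""
    if num_nodes is None:
        # Inferred size depends on the data, so pass num_nodes under jit.
        num_nodes = int(index.max()) + 1
    # bincount with a fixed `length` pads nodes that never appear with 0.
    counts = jnp.bincount(index, length=num_nodes)
    return counts if dtype is None else counts.astype(dtype)
```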
to_undirected
- to_undirected(edge_index: Array, edge_attr: jax.Array | None = None, num_nodes: int | None = None, reduce: str = 'add') -> tuple[jax.Array, jax.Array | None]
Converts the graph given by edge_index to an undirected graph such that \((j,i) \in \mathcal{E}\) for every edge \((i,j) \in \mathcal{E}\).
- Parameters:
  edge_index (jax.Array) – The edge indices.
  edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)
  num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
  reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max"). (default: "add")
- Returns:
  tuple[Array, Optional[Array]] – Tuple of (undirected edge_index, undirected edge_attr).
Convert a directed graph to undirected by adding reverse edges.
Example:
```python
from jraphx.utils import to_undirected
import jax.numpy as jnp

edge_index = jnp.array([[0, 1], [1, 2]])
edge_attr = jnp.array([[1.0], [2.0]])

# Convert to undirected
edge_index_undirected, edge_attr_undirected = to_undirected(
    edge_index, edge_attr
)
```
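One way to realize the definition above is in two steps. First append every edge \((i,j)\) reversed as \((j,i)\). Then merge duplicates by encoding each edge as the integer `row * num_nodes + col`. The sketch below assumes the default `reduce="add"` and handles no other reductions. `to_undirected_sketch` is a hypothetical helper, not the jraphx code. It runs eagerly, because `jnp.unique` without a `size` argument has a data-dependent output shape.

```python
import jax
import jax.numpy as jnp


def to_undirected_sketch(edge_index, edge_attr=None, num_nodes=None):
    """Hypothetical helper (not jraphx API): add reverse edges, then merge duplicates by summing."""
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1
    # Append every edge (i, j) as (j, i) as well.
    row = jnp.concatenate([edge_index[0], edge_index[1]])
    col = jnp.concatenate([edge_index[1], edge_index[0]])
    # Encode each edge as one integer so duplicates collapse under unique().
    key = row * num_nodes + col
    uniq, inverse = jnp.unique(key, return_inverse=True)
    new_index = jnp.stack([uniq // num_nodes, uniq % num_nodes])
    if edge_attr is None:
        return new_index, None
    # A reverse edge reuses the attribute of the edge it mirrors; duplicates are summed.
    attr = jnp.concatenate([edge_attr, edge_attr], axis=0)
    merged = jax.ops.segment_sum(attr, inverse.reshape(-1), num_segments=uniq.shape[0])
    return new_index, merged
```

Note that with summation, an input that already contains both \((0,1)\) and \((1,0)\) ends up with doubled attributes, since each direction is added twice.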
add_self_loops
- add_self_loops(edge_index: Array, edge_attr: jax.Array | None = None, fill_value: Union[float, str] = 1.0, num_nodes: int | None = None) -> tuple[jax.Array, jax.Array | None]
Adds a self-loop \((i,i) \in \mathcal{E}\) to every node \(i \in \mathcal{V}\) in the graph given by edge_index. If the graph is weighted and already contains self-loops, only the missing self-loops are added, with edge weights given by fill_value.
- Parameters:
  edge_index (jax.Array) – The edge indices.
  edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)
  fill_value (float or str, optional) – How to generate edge features for the self-loops. If a float, edge features are set to this value. If a str, edge features are computed by aggregating the existing edge features that point to each node with the specified reduction ('mean', 'add', 'max', 'min'). (default: 1.0)
  num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
- Returns:
  tuple[Array, Optional[Array]] – Tuple of (edge_index with self-loops, edge_attr with self-loops).
Note
For JIT compatibility, num_nodes should be provided as a static integer when possible.
Add self-loop edges to a graph.
Example:
```python
from jraphx.utils import add_self_loops
import jax.numpy as jnp

edge_index = jnp.array([[0, 1], [1, 2]])

# Add self-loops
edge_index_with_loops, edge_attr = add_self_loops(
    edge_index, num_nodes=3
)
```
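The note about a static `num_nodes` makes sense once you see that the loops themselves are just `arange(num_nodes)` paired with itself, so the output shape depends only on `num_nodes`. The sketch below covers only the float `fill_value` case and appends a loop for every node. It does not skip loops that already exist, and it does not handle the string-reduction mode described above. `add_self_loops_sketch` is a hypothetical helper, not the jraphx implementation.

```python
import jax.numpy as jnp


def add_self_loops_sketch(edge_index, edge_attr=None, fill_value=1.0, num_nodes=None):
    """Hypothetical helper (not jraphx API): append a loop (i, i) for every node i."""
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1
    # One loop per node; the shape depends only on num_nodes, so it is static under jit.
    loops = jnp.arange(num_nodes, dtype=edge_index.dtype)
    new_index = jnp.concatenate([edge_index, jnp.stack([loops, loops])], axis=1)
    if edge_attr is None:
        return new_index, None
    # Loop features are a constant fill_value (float case only).
    loop_attr = jnp.full((num_nodes,) + edge_attr.shape[1:], fill_value, dtype=edge_attr.dtype)
    return new_index, jnp.concatenate([edge_attr, loop_attr], axis=0)
```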
remove_self_loops
- remove_self_loops(edge_index: Array, edge_attr: jax.Array | None = None) -> tuple[jax.Array, jax.Array | None]
Remove self-loops from edge indices.
- Parameters:
  edge_index (Array) – Edge indices [2, num_edges]
  edge_attr (Optional[Array], default: None) – Optional edge attributes [num_edges, *]
- Returns:
  tuple[Array, Optional[Array]] – Tuple of (edge_index without self-loops, edge_attr without self-loops)
Remove self-loop edges from a graph.
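Removing self-loops amounts to keeping the columns of `edge_index` whose source and target differ. In the minimal sketch below, the kept positions come from `jnp.nonzero`, whose output length depends on the data. This version therefore runs eagerly and would need a fixed `size=` (plus padding) to work under `jax.jit`. `remove_self_loops_sketch` is a hypothetical helper, and how jraphx itself deals with the dynamic shape is not stated here.

```python
import jax.numpy as jnp


def remove_self_loops_sketch(edge_index, edge_attr=None):
    """Hypothetical helper (not jraphx API): keep edges whose endpoints differ."""
    # Positions of edges that are not self-loops (data-dependent length).
    keep = jnp.nonzero(edge_index[0] != edge_index[1])[0]
    new_index = edge_index[:, keep]
    new_attr = None if edge_attr is None else edge_attr[keep]
    return new_index, new_attr
```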
contains_self_loops
is_undirected
coalesce
- coalesce(edge_index: Array, edge_attr: jax.Array | None = None, num_nodes: int | None = None, reduce: str = 'add') -> tuple[jax.Array, jax.Array | None]
Sorts edge_index row-wise and removes duplicate entries. Duplicate entries in edge_attr are merged by scattering them together according to the given reduce option.
- Parameters:
  edge_index (jax.Array) – The edge indices.
  edge_attr (jax.Array, optional) – Edge weights or multi-dimensional edge features. (default: None)
  num_nodes (int, optional) – The number of nodes, i.e. max_val + 1 of edge_index. (default: None)
  reduce (str, optional) – The reduce operation to use for merging edge features ("add", "mean", "min", "max"). (default: "add")
- Returns:
  tuple[Array, Optional[Array]] – Tuple of (coalesced edge_index, coalesced edge_attr).
Note
For JIT compatibility, num_nodes should be provided as a static integer when possible.
Remove duplicate edges and merge their attributes (summed by default).
Example:
```python
from jraphx.utils import coalesce
import jax.numpy as jnp

# Graph with duplicate edges: (0, 1) appears twice
edge_index = jnp.array([[0, 0, 1], [1, 1, 2]])
edge_attr = jnp.array([[1.0], [2.0], [3.0]])

# Remove duplicates and sum their attributes
edge_index_clean, edge_attr_clean = coalesce(
    edge_index, edge_attr, reduce='add'
)
```
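Row-wise sorting and deduplication can both come from one trick: encode each edge as `row * num_nodes + col`. Sorting these keys orders edges by row, then by column, and equal keys mark duplicates. The sketch below merges attributes with the four documented reductions via JAX segment ops. `coalesce_sketch` is a hypothetical helper, not the jraphx implementation. It runs eagerly, because `jnp.unique` without `size=` has a data-dependent output shape.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from reduce names to JAX segment ops ("mean" handled separately).
_REDUCERS = {
    "add": jax.ops.segment_sum,
    "max": jax.ops.segment_max,
    "min": jax.ops.segment_min,
}


def coalesce_sketch(edge_index, edge_attr=None, num_nodes=None, reduce="add"):
    """Hypothetical helper (not jraphx API): sort edges row-wise and merge duplicates."""
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1
    # row * num_nodes + col sorts in row-major order and collides only for duplicates.
    key = edge_index[0] * num_nodes + edge_index[1]
    uniq, inverse = jnp.unique(key, return_inverse=True)
    inverse = inverse.reshape(-1)
    new_index = jnp.stack([uniq // num_nodes, uniq % num_nodes])
    if edge_attr is None:
        return new_index, None
    n = uniq.shape[0]
    if reduce == "mean":
        # Every unique edge occurs at least once, so the count is never zero.
        total = jax.ops.segment_sum(edge_attr, inverse, num_segments=n)
        ones = jnp.ones_like(inverse, dtype=edge_attr.dtype)
        count = jax.ops.segment_sum(ones, inverse, num_segments=n)
        return new_index, total / count.reshape((-1,) + (1,) * (edge_attr.ndim - 1))
    return new_index, _REDUCERS[reduce](edge_attr, inverse, num_segments=n)
```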
Conversion Utilities
to_dense_adj
- to_dense_adj(edge_index: Array, edge_attr: jax.Array | None = None, max_num_nodes: int | None = None) -> Array
Convert edge indices to dense adjacency matrix.
- Parameters:
  edge_index (Array) – Edge indices [2, num_edges]
  edge_attr (Optional[Array], default: None) – Optional edge attributes [num_edges] or [num_edges, num_features]
  max_num_nodes (Optional[int], default: None) – Maximum number of nodes (for padding)
- Returns:
  Array – Dense adjacency matrix [num_nodes, num_nodes] or [num_nodes, num_nodes, num_features]
Convert edge indices to a dense adjacency matrix.
Example:
```python
from jraphx.utils import to_dense_adj
import jax.numpy as jnp

edge_index = jnp.array([[0, 1, 2], [1, 2, 0]])

# Convert to dense adjacency matrix
adj = to_dense_adj(edge_index, max_num_nodes=3)
```
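Densifying an edge list is a single scatter into a zero matrix: each edge \((i,j)\) writes its weight (or 1 if unweighted) into entry `[i, j]`. The sketch below uses JAX's functional `.at[...].add(...)` update. Accumulating duplicate edges by addition is an assumption, and jraphx may overwrite instead. `to_dense_adj_sketch` is a hypothetical helper, not the library's implementation.

```python
import jax.numpy as jnp


def to_dense_adj_sketch(edge_index, edge_attr=None, max_num_nodes=None):
    """Hypothetical helper (not jraphx API): scatter edges into an [N, N] or [N, N, F] array."""
    if max_num_nodes is None:
        max_num_nodes = int(edge_index.max()) + 1
    row, col = edge_index[0], edge_index[1]
    # Unweighted edges contribute 1; weighted edges contribute their attribute row.
    values = jnp.ones(row.shape[0]) if edge_attr is None else edge_attr
    shape = (max_num_nodes, max_num_nodes) + values.shape[1:]
    # .at[].add() accumulates, so repeated edges sum into the same entry (assumption).
    return jnp.zeros(shape, dtype=values.dtype).at[row, col].add(values)
```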
to_edge_index
- to_edge_index(adj: Array) -> tuple[jax.Array, jax.Array | None]
Convert adjacency matrix to edge indices.
- Parameters:
  adj (Array) – Adjacency matrix [num_nodes, num_nodes] or [num_nodes, num_nodes, num_features]
- Returns:
  tuple[Array, Optional[Array]] – Tuple of (edge_index [2, num_edges], edge_attr [num_edges] or [num_edges, num_features])
Convert adjacency representation to edge index format.
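The reverse conversion lists the non-zero positions of the adjacency matrix and gathers their values as edge attributes. In the sketch below, a 3-D `[N, N, F]` input counts as having an edge wherever any feature is non-zero. That criterion is an assumption, as is the name `to_edge_index_sketch`, a hypothetical helper rather than the jraphx code. Because `jnp.nonzero` has a data-dependent output length, the sketch runs eagerly; under `jax.jit` it would need an explicit `size=`.

```python
import jax.numpy as jnp


def to_edge_index_sketch(adj):
    """Hypothetical helper (not jraphx API): list the non-zero entries of a dense adjacency."""
    if adj.ndim == 2:
        mask = adj != 0
    else:
        # [N, N, F]: treat (i, j) as an edge if any of its features is non-zero.
        mask = jnp.any(adj != 0, axis=-1)
    # Row-major order of the non-zero (row, col) pairs.
    row, col = jnp.nonzero(mask)
    edge_index = jnp.stack([row, col])
    # Gather the values: [num_edges] for 2-D input, [num_edges, F] for 3-D input.
    edge_attr = adj[row, col]
    return edge_index, edge_attr
```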