Source code for jraphx.nn.conv.gatv2_conv

"""Graph Attention Network v2 (GATv2) layer implementation."""

from typing import Union

from flax.nnx import Dropout, Linear, Param, Rngs, initializers, leaky_relu
from jax import numpy as jnp

from jraphx.nn.conv.message_passing import MessagePassing
from jraphx.utils import scatter_add, scatter_softmax
from jraphx.utils.loop import add_self_loops as add_self_loops_fn


class GATv2Conv(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention
    Networks?" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the
    static attention problem of the standard GAT layer. Since the linear
    layers in the standard GAT are applied right after each other, the
    ranking of attended nodes is unconditioned on the query node. In
    contrast, in GATv2, every node can attend to any other node.

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
        \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k
        \right)\right)}.

    If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
    the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j
        + \mathbf{\Theta}_{e} \mathbf{e}_{i,j}
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(
        \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k
        + \mathbf{\Theta}_{e} \mathbf{e}_{i,k}
        \right)\right)}.

    Args:
        in_features (int or tuple): Size of each input sample, or tuple for
            bipartite graphs. A tuple corresponds to the sizes of source and
            target dimensionalities.
        out_features (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a
            stochastically sampled neighborhood during training.
            (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        edge_dim (int, optional): Edge feature dimensionality (in case there
            are any). (default: :obj:`None`)
        fill_value (float, optional): The way to generate edge features of
            self-loops (in case :obj:`edge_dim != None`).
            (default: :obj:`0.0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same
            matrix will be applied to the source and the target node of
            every edge. (default: :obj:`False`)
        residual (bool, optional): If set to :obj:`True`, the layer will add
            a learnable skip-connection. (default: :obj:`False`)
        rngs: Random number generators for initialization.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})`
          where :math:`H` is the number of heads.
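
    Example (a minimal sketch; the feature sizes, node count, edge list, and
    RNG seed below are illustrative assumptions, not part of the API)::

        rngs = Rngs(0)
        conv = GATv2Conv(in_features=16, out_features=32, heads=4, rngs=rngs)

        x = jnp.ones((5, 16))                    # 5 nodes, 16 features each
        edge_index = jnp.array([[0, 1, 2, 3],    # source nodes
                                [1, 2, 3, 4]])   # target nodes

        out = conv(x, edge_index)                # shape (5, 4 * 32)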
""" def __init__( self, in_features: Union[int, tuple[int, int]], out_features: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim: int | None = None, fill_value: float = 0.0, bias: bool = True, share_weights: bool = False, residual: bool = False, rngs: Rngs | None = None, ): """Initialize the GATv2 layer.""" super().__init__(aggr="add") self.in_features = in_features self.out_features = out_features self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout_rate = dropout self._add_self_loops = add_self_loops self.edge_dim = edge_dim self.fill_value = fill_value self.share_weights = share_weights self.residual = residual # Linear transformations if isinstance(in_features, int): self.lin_l = Linear( in_features, heads * out_features, use_bias=bias, rngs=rngs, ) if share_weights: self.lin_r = self.lin_l else: self.lin_r = Linear( in_features, heads * out_features, use_bias=bias, rngs=rngs, ) else: # Bipartite graph with different source and target features self.lin_l = Linear( in_features[0], heads * out_features, use_bias=bias, rngs=rngs, ) if share_weights: self.lin_r = self.lin_l else: self.lin_r = Linear( in_features[1], heads * out_features, use_bias=bias, rngs=rngs, ) # Attention parameter (single vector per head) self.att = Param(initializers.glorot_uniform()(rngs.params(), (heads, out_features))) # Edge feature transformation if edge_dim is not None: self.lin_edge = Linear( edge_dim, heads * out_features, use_bias=False, rngs=rngs, ) else: self.lin_edge = None # Residual connection total_out_features = heads * out_features if concat else out_features if residual: res_in_features = in_features if isinstance(in_features, int) else in_features[1] self.res = Linear( res_in_features, total_out_features, use_bias=False, rngs=rngs, ) else: self.res = None # Bias (applied after aggregation) if bias and not isinstance(in_features, int): # For bipartite graphs, bias is handled by lin_l and lin_r self.bias = None elif bias: self.bias = Param(jnp.zeros((total_out_features,))) else: self.bias = None # Dropout if dropout > 0: self.dropout = Dropout(dropout, rngs=rngs) else: self.dropout = None def __call__( self, x: Union[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]], edge_index: jnp.ndarray, edge_attr: jnp.ndarray | None = None, return_attention_weights: bool = False, ) -> Union[jnp.ndarray, tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]]: """Forward pass of the GATv2 layer. Args: x: Node features [num_nodes, in_features] or tuple for bipartite graphs edge_index: Edge indices [2, num_edges] edge_attr: Optional edge features [num_edges, edge_dim] return_attention_weights: If True, return attention weights Returns: Updated node features [num_nodes, heads * out_features] if concat or [num_nodes, out_features] if not concat. If return_attention_weights is True, also returns (out, (edge_index, alpha)). 
""" H, C = self.heads, self.out_features # Handle input types and compute residual res = None if isinstance(x, tuple): x_l, x_r = x num_nodes = x_r.shape[0] if x_r is not None else x_l.shape[0] # Residual connection for target nodes if self.res is not None and x_r is not None: res = self.res(x_r) # Linear transformation x_l = self.lin_l(x_l).reshape(-1, H, C) x_r = self.lin_r(x_r).reshape(-1, H, C) if x_r is not None else x_l else: num_nodes = x.shape[0] # Residual connection if self.res is not None: res = self.res(x) # Linear transformation x_l = self.lin_l(x).reshape(-1, H, C) x_r = self.lin_r(x).reshape(-1, H, C) # Add self-loops if self._add_self_loops: edge_index, edge_attr = add_self_loops_fn( edge_index, edge_attr=edge_attr, fill_value=self.fill_value, num_nodes=num_nodes ) # Get edge endpoints row, col = edge_index[0], edge_index[1] # Get source and target features for edges x_i = x_r[col] # [num_edges, heads, out_features] x_j = x_l[row] # [num_edges, heads, out_features] # Key difference from GAT: combine features BEFORE applying attention x_combined = x_i + x_j # [num_edges, heads, out_features] # Add edge features if available if edge_attr is not None and self.lin_edge is not None: if edge_attr.ndim == 1: edge_attr = edge_attr.reshape(-1, 1) edge_feat = self.lin_edge(edge_attr) edge_feat = edge_feat.reshape(-1, H, C) x_combined = x_combined + edge_feat # Apply LeakyReLU (this is the key difference - applied after combination) x_combined = leaky_relu(x_combined, negative_slope=self.negative_slope) # Compute attention scores alpha = jnp.sum(x_combined * self.att.value, axis=-1) # [num_edges, heads] # Apply softmax using our optimized scatter_softmax num_edges = alpha.shape[0] alpha_flat = alpha.reshape(-1) # [num_edges * heads] # Create expanded index for each head col_expanded = jnp.repeat(col, self.heads) # Apply softmax alpha_flat = scatter_softmax(alpha_flat, col_expanded, dim_size=num_nodes) alpha = alpha_flat.reshape(num_edges, self.heads) # Apply dropout to attention coefficients if self.dropout is not None: alpha = self.dropout(alpha) # Apply attention weights to features weighted_features = x_j * alpha.reshape( -1, self.heads, 1 ) # [num_edges, heads, out_features] # Aggregate messages weighted_features_flat = weighted_features.reshape(-1, self.heads * self.out_features) out_flat = scatter_add(weighted_features_flat, col, dim_size=num_nodes) out = out_flat.reshape(num_nodes, self.heads, self.out_features) # Concatenate or average heads if self.concat: out = out.reshape(num_nodes, self.heads * self.out_features) else: out = out.mean(axis=1) # Add residual connection if res is not None: out = out + res # Add bias if self.bias is not None: out = out + self.bias.value if return_attention_weights: return out, (edge_index, alpha) return out