The Intelligible and Effective Graph Neural Additive Networks

Maya Bechler-Speicher
Blavatnik School of Computer Science
Tel-Aviv University
Amir Globerson
Blavatnik School of Computer Science
Tel-Aviv University
Ran Gilad-Bachrach
Department of Bio-Medical Engineering
Edmond J. Safra Center for Bioinformatics
Tel-Aviv University
Abstract

Graph Neural Networks (GNNs) have emerged as the predominant approach for learning over graph-structured data. However, most GNNs operate as black-box models and require post-hoc explanations, which may not suffice in high-stakes scenarios where transparency is crucial. In this paper, we present a GNN that is interpretable by design. Our model, the Graph Neural Additive Network (GNAN), is a novel extension of the interpretable class of Generalized Additive Models and can be visualized and fully understood by humans. GNAN allows both global and local explanations at the feature and graph levels through direct visualization of the model. These visualizations describe exactly how the model uses the relationships between the target variable, the features, and the graph. We demonstrate the intelligibility of GNANs in a series of examples on different tasks and datasets. In addition, we show that the accuracy of GNAN is on par with black-box GNNs, making it suitable for critical applications that require both transparency and high accuracy.

1 Introduction

In many domains, ranging from biology to fraud detection, Artificial Intelligence (AI) is applied to data with graph structure. Neural Networks, and specifically Graph Neural Networks (GNNs), have emerged as the predominant approach in these applications (see, for example, Zhou et al. [1]). While GNNs achieve high predictive accuracy, they often function as black-box models whose decision-making processes are opaque. Transparency is vital for assessing potential biases or safety risks and is particularly critical in high-stakes areas such as criminal justice, healthcare, and finance, where decisions significantly impact individuals’ lives. In such contexts, interpretable models, despite sometimes being less accurate, may be preferred over complex black-box models [2]. Furthermore, the transparency of automated decision-making processes is increasingly becoming a legal mandate. While there is ongoing debate over whether the European Union’s General Data Protection Regulation (GDPR) implies a “right to explanation” [3, 4], the proposed European AI Act explicitly addresses this issue, stating that “To address concerns related to opacity and complexity of certain AI systems and help deployers to fulfill their obligations under this regulation, transparency should be required for high-risk AI systems before they are placed on the market or put into service” [5].

In this context, interpretability refers to the ease with which a human can understand the reasoning behind model decisions or the general logic of a model’s operation. It is important to distinguish between interpretability and explainability [2]. Interpretability relates to models that are inherently comprehensible by design, while explainability pertains to post-hoc methods that elucidate aspects of black-box models [6]. These explanations often come without correctness guarantees [7, 8] and may not provide a complete description of the model and its predictions, potentially failing to expose hidden pitfalls [9, 10, 11].

Methods for model explainability or interpretability can be categorized into local and global types. Local methods, such as SHAP [12] and LIME [13], elucidate individual predictions made by a model, whereas global methods, such as feature importance [14] and partial dependence plots [15], provide holistic insights about the model, i.e., explain the overarching logic of its decision making [16]. However, it has been noted that local explainability methods may not consistently align with their global counterparts [17]. Moreover, local explanations may be inadequate for verifying fairness and other risks [8].

In this work, we introduce the Graph Neural Additive Networks (GNAN), an interpretable-by-design GNN that offers both transparency and accuracy. GNAN is a glass-box model [18] that allows for both local and global interpretability. GNAN extends the family of Generalized Additive Models (GAMs) [19] to accommodate graph data. GAMs are known for their ability to fit complex, nonlinear functions while remaining interpretable and have proven effective across various domains [20, 21, 22, 23, 24]. They operate by learning a shape function for each feature and then linearly combining these functions. This makes them easy to interpret, as the influence of each feature on the prediction is independent of the other features and can be visualized through its shape function. Similarly, GNAN achieves interpretability through an architecture that restricts the use of cross-products of features and the graph topology, thereby reducing its complexity compared to other GNNs. Nonetheless, we demonstrate that GNAN, despite its limited capacity, matches the performance of more expressive GNNs on several real-world datasets. Additionally, GNAN does not rely on iterative local message passing, avoiding the computational bottlenecks commonly associated with such GNNs [25].

In Section 4, we showcase through a series of examples how users can interpret GNAN and gain precise insights into the connections between the target and the graph, the target and the features, and the interplay between features and graph information. In some cases, an exact description of the model can be visualized through only a few figures. We also demonstrate how the interpretability of GNAN allows users to debug their model, a process that can be used for ensuring consistency with prior knowledge and avoiding biases and safety risks. In Section 5, we compare the performance of GNAN with other GNN architectures. This comparison underscores that sacrificing performance for intelligibility is not necessary, as the performance of GNAN is comparable to that of commonly used black-box GNNs.

The main contributions of this work are:

  1. The extension of Generalized Additive Models (GAMs) to graph data.

  2. The introduction of a fully interpretable-by-design model for graph prediction tasks, demonstrating that its explanations provide both global and local insights, through visualizations of the model itself, and include debugging capabilities.

  3. The demonstration that GNAN achieves good performance on common real-world graph datasets, despite its limited capacity. This observation supports previous findings that some real-world graph problems are simple and do not require the capacity of other GNNs.

Thus we argue that GNAN is suitable for high-stakes applications due to its interpretability and performance.

2 Related work

Generalized Additive Models

Generalized Additive Models (GAMs) are a class of statistical models that build upon generalized linear models by incorporating non-linear functions for each variable while maintaining additivity [19, 20, 21]. Essentially, GAMs model the expected value of the target variable as a sum of univariate functions of the features. Formally, in GAMs, a prediction for an input $x$ is computed by $\sigma\left(\sum_k f_k(x_k)\right)$, where $\sigma$ is a predefined activation function, such as the sigmoid (in the context of Generalized Linear Models (GLMs), $\sigma$ can be thought of as the inverse of the link function), and the $f_k$'s are shape functions learned during training. This approach extends generalized linear models, in which predictions are computed by $\sigma\left(\sum_k w_k x_k\right)$, where $w$ is a learned weight vector.
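To make the contrast concrete, the following minimal Python sketch compares the two prediction rules. The shape functions are hypothetical, hand-picked stand-ins for learned ones; the first is non-monotone, which a generalized linear model cannot express with a single weight per feature.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def glm_predict(x, w):
    """Generalized linear model: sigma(sum_k w_k * x_k)."""
    return sigmoid(sum(w_k * x_k for w_k, x_k in zip(w, x)))


def gam_predict(x, shape_functions):
    """GAM: sigma(sum_k f_k(x_k)); each f_k sees only its own feature."""
    return sigmoid(sum(f_k(x_k) for f_k, x_k in zip(shape_functions, x)))


# Hypothetical shape functions, chosen for illustration (not learned).
shape_functions = [
    lambda v: (v - 1.0) ** 2 - 1.0,  # non-monotone: lowest at v = 1
    lambda v: 0.5 * v,               # linear, as a GLM would model it
]

print(gam_predict([1.0, 2.0], shape_functions))  # 0.5, since -1.0 + 1.0 = 0
print(glm_predict([1.0, 2.0], [0.3, -0.15]))     # 0.5, since 0.3 - 0.3 = 0
```

Because each $f_k$ depends on a single feature, plotting $f_k$ over its input range shows that feature's entire effect on the prediction.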

GAMs are more expressive than generalized linear models while remaining interpretable, as the effect of each predictor is modeled separately. For example, they can capture non-monotone effects of features, which generalized linear models cannot achieve without feature engineering. Traditionally, GAMs utilize splines or other smooth shape functions to model the non-linear relationships between each feature and the target variable. However, other methods, such as trees, have been proposed to fit the shape functions [24]. Recently, Agarwal et al. [26] suggested using neural networks to learn the shape functions. This approach combines the representational power of deep learning with the interpretability of additive models.

Graph Neural Networks

Graph Neural Networks (GNNs) [27, 28, 29, 30] have emerged as the leading approach for learning over graph data. The fundamental idea behind GNNs is to use neural networks that combine node features with the graph structure. A commonly used family of GNNs is message-passing GNNs, where node representations are updated iteratively through neighborhood aggregation. This aggregation is done, for example, through a convolution-like operation or an attention mechanism [31, 32, 33].
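For contrast with the non-message-passing approaches discussed next, the sketch below shows one round of a simplified message-passing update in plain Python. It is an illustrative assumption, not any particular published layer: each node averages its own vector with its neighbors' vectors, with no learned weights or nonlinearity. It shows how information travels only one hop per iteration.

```python
def mean_aggregation_layer(features, adjacency):
    """One simplified message-passing round: each node's new vector is the
    mean of its own vector and its neighbors' vectors (no learned weights)."""
    updated = []
    for i, x_i in enumerate(features):
        group = [x_i] + [features[j] for j in adjacency[i]]
        updated.append([sum(v[k] for v in group) / len(group)
                        for k in range(len(x_i))])
    return updated


adjacency = [[1], [0, 2], [1]]  # path graph 0 - 1 - 2
x = [[1.0], [0.0], [0.0]]       # only node 0 carries a signal
layer1 = mean_aggregation_layer(x, adjacency)
layer2 = mean_aggregation_layer(layer1, adjacency)
print(layer1)     # [[0.5], [0.3333333333333333], [0.0]]
print(layer2[2])  # nonzero: node 0's signal reaches node 2 after two rounds
```

After one round, node 2 is still unaware of node 0's feature; it only receives that signal in the second round, which is why stacks of message-passing layers are needed to reach distant nodes.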

Various non-message-passing approaches have been explored to disentangle the node features from the graph structure. Such approaches were shown to enhance performance across diverse applications [34, 35, 36]. Disentanglement can also reduce overfitting, as popular GNNs, which do entangle features and graph structure, were shown to tend to overfit non-informative graph information [37]. GNAN uses these concepts to achieve a model that is both high-performing and fully interpretable.

There are different prediction tasks on graphs [38]. In graph tasks, the goal is to predict properties of entire graphs. For example, a graph could represent a molecule, and the goal would be to predict its toxicity level. In node tasks, the goal is to predict a property of a node (vertex) within a graph. An example of a node task is predicting whether a user in a social network is a human or a bot. In link prediction tasks, the goal is to determine whether there is an edge between two nodes of a graph. In this work, we focus on graph tasks and node tasks. Although link prediction tasks are not within the scope of this work, it is possible to view these problems as node tasks on the dual line graph [39].

GNNs explanations

The inherent complexity of graph-structured data poses unique challenges for explainability. Most approaches for explaining black-box GNNs focus on providing a subgraph or a similar structure that can explain a certain example. This is done either as a post-hoc explanation for GNNs [40, 41, 42] or by adjusting the data a priori [43, 44, 45]. For example, the method suggested in Ying et al. [41] identifies both important subgraph structures and node features influencing the GNN’s predictions by maximizing the mutual information between the prediction and the distribution of possible subgraph structures and node features. Yin et al. [43] suggested a structural pattern learning module that is learned through pre-training. In contrast to these methods, GNAN does not aim to provide an explanation through a proxy object such as a subgraph, nor does it require modifications to the data or the training process. Instead, GNAN is interpretable by design, and its exact description can be visualized through its learned shape functions. In particular, the exact relation between the target, the features, and the graph can be visualized and conveyed to users.

3 Graph Neural Additive Networks

In this section, we introduce the Graph Neural Additive Networks (GNAN). We begin by defining some essential notation. A graph $G$ has a set of $N$ vertices, where each vertex is associated with a $d$-dimensional feature vector. Specifically, $\mathbf{x}_i \in \mathbb{R}^d$ represents the feature vector of the $i$'th node in $G$. We define the distance $\text{dist}(j,i)$ between node $j$ and node $i$ within the graph $G$ as the number of edges in the shortest path from $j$ to $i$. This definition implies that the distance from a node to itself is zero. In cases where no path exists from $j$ to $i$, we set $\text{dist}(j,i)=\infty$. For enhanced readability, we denote vectors in boldface, and entry $k$ of a vector $\mathbf{x}$ is denoted by $[\mathbf{x}]_k$. We start by describing GNAN for applications such as binary classification and regression, where the model output is one-dimensional. At the end of this section, we discuss extensions to scenarios such as multi-class classification, where the model output is multi-dimensional.
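The distances $\text{dist}(j,i)$ can be obtained with a breadth-first search from every node. The sketch below assumes an unweighted graph stored as a Python adjacency list (a hypothetical input format, not one prescribed here) and uses `math.inf` for unreachable nodes, matching the convention above.

```python
from collections import deque
import math


def shortest_path_distances(adjacency, source):
    """Hop counts of shortest paths from `source` to every node.

    Returns 0 for the source itself and math.inf for unreachable nodes.
    `adjacency[u]` lists the nodes reachable from u by one edge.
    """
    dist = [math.inf] * len(adjacency)
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if dist[v] == math.inf:  # first visit gives the shortest hop count
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def all_pairs_distances(adjacency):
    """D[j][i] = dist(j, i), computed with one BFS per source node j."""
    return [shortest_path_distances(adjacency, j) for j in range(len(adjacency))]


# Path graph 0 - 1 - 2 plus an isolated node 3.
adjacency = [[1], [0, 2], [1], []]
D = all_pairs_distances(adjacency)
print(D[0])  # [0, 1, 2, inf]
```

One BFS per node costs $O(N(N+|E|))$ time overall, which is modest for small graphs such as molecules; a real implementation may compute distances differently.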

GNAN generates a representation $\mathbf{h}_i \in \mathbb{R}^d$ for each node $i$ by learning a distance function $\rho(x;\theta): \mathbb{R}\rightarrow\mathbb{R}$ and a set of feature shape functions $\{f_k\}_{k=1}^{d}$, $f_k(x;\theta_k): \mathbb{R}\rightarrow\mathbb{R}$. Each of these functions is a neural network and is optimized through back-propagation. For brevity, we omit the parameterizations $\theta$ and $\theta_k$ for the remainder of this section. In GNAN, the $k$'th entry of the representation $\mathbf{h}_i$ of the $i$'th node is defined as follows:

$$[\mathbf{h}_i]_k = \sum_{j=1}^{N} \frac{1}{\#\text{dist}_i(j,i)} \cdot \rho\left(\frac{1}{1+\text{dist}(j,i)}\right) \cdot f_k\left([\mathbf{x}_j]_k\right)$$

where $\#\text{dist}_i(j,i)$ represents the number of nodes at distance $\text{dist}(j,i)$ from node $i$. The underlying rationale for this definition is as follows: each node's $k$'th feature is transformed by a shape function $f_k$, independently of other features. The effect that the $k$'th feature value of node $j$ has on node $i$'s representation depends on their distance. Specifically, if $\text{dist}(j,i)=l$, then all nodes at distance $l$ from node $i$ share a total weight of $\rho(1/(1+l))$. This is achieved by the normalization term $1/\#\text{dist}_i(j,i)$. Here, $\rho$'s argument $1/(1+l)$ rescales the distance so that a distance of $0$ (the self-distance of a node) is mapped to $1$, and an infinite distance, which implies that no path exists, is mapped to $0$. Thus, the domain of $\rho$ is the interval $[0,1]$.
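The following sketch evaluates the definition of $[\mathbf{h}_i]_k$ literally. It assumes a precomputed distance matrix `dist[j][i]` and uses plain Python callables in place of the trained networks $\rho$ and $f_k$; the identity $\rho$ and the constant shape function in the example are hypothetical choices that make the normalization easy to check by hand.

```python
import math
from collections import Counter


def gnan_node_representation(i, features, dist, rho, shape_functions):
    """Compute h_i entry by entry:
    [h_i]_k = sum_j (1 / #dist_i(j, i)) * rho(1 / (1 + dist(j, i))) * f_k([x_j]_k)

    features[j][k] holds [x_j]_k, dist[j][i] holds dist(j, i) (math.inf if
    there is no path), and rho / shape_functions[k] are univariate callables.
    """
    n = len(features)
    # Number of nodes at each distance from node i: #dist_i(., i).
    counts = Counter(dist[j][i] for j in range(n))
    h_i = []
    for k, f_k in enumerate(shape_functions):
        total = 0.0
        for j in range(n):
            d = dist[j][i]
            scaled = 0.0 if d == math.inf else 1.0 / (1.0 + d)  # in [0, 1]
            total += rho(scaled) * f_k(features[j][k]) / counts[d]
        h_i.append(total)
    return h_i


INF = math.inf
# Distances for the path graph 0 - 1 - 2 plus an isolated node 3.
D = [[0, 1, 2, INF], [1, 0, 1, INF], [2, 1, 0, INF], [INF, INF, INF, 0]]
features = [[1.0, 2.0], [1.0, 0.0], [1.0, 6.0], [1.0, 3.0]]
rho = lambda t: t                               # hypothetical distance function
shape_functions = [lambda v: 1.0, lambda v: v]  # hypothetical f_1, f_2
print(gnan_node_representation(1, features, D, rho, shape_functions))  # [1.5, 2.0]
```

With the constant shape function, node 1's two neighbors at distance 1 together contribute exactly $\rho(1/2)=0.5$, illustrating that each distance layer receives a total weight of $\rho(1/(1+l))$ regardless of how many nodes it contains.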

The representation of each node depends on the entire graph, yet the function $\rho$ enables weighting the influence of nodes based on their distance. This enables, for example, diminishing the impact of distant nodes, or of close neighbors. For each node $i$, a weighted sum of the transformed feature vectors of all nodes, including $i$ itself, is computed, with weights assigned according to their distance from $i$. This weighted sum is computed after the shape functions are applied to the distances and to the features of each node.

Given the node representations, both node prediction tasks and graph prediction tasks can be implemented. To predict for the $i$'th node, the computation is as follows:

$$\sigma\left(\sum_{k=1}^{d} [\mathbf{h}_i]_k\right),$$

where the entries of the representation vector $\mathbf{h}_i$ are summed and the result is passed through an activation function, such as the sigmoid for classification or the identity for regression. For a prediction over the entire graph, the node representations are aggregated into a graph representation via sum pooling:

$$\mathbf{h} = \sum_{i=1}^{N} \mathbf{h}_i.$$

Following this, the entries of the graph representation $\mathbf{h}$ are summed and the result is also passed through the activation function:

$$\sigma\left(\sum_{k=1}^{d} [\mathbf{h}]_k\right).$$
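Because both readouts are plain sums followed by an activation, they can be sketched in a few lines. The sketch below assumes the node representations are already computed (the numbers are hypothetical) and shows that, before the activation, the graph-level logit is simply the sum of the node-level logits.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def identity(z):
    return z


def node_prediction(h_i, activation=sigmoid):
    """sigma(sum_k [h_i]_k): sigmoid for classification, identity for regression."""
    return activation(sum(h_i))


def graph_representation(node_reps):
    """Sum pooling: h = sum_i h_i, computed entry by entry."""
    return [sum(column) for column in zip(*node_reps)]


def graph_prediction(node_reps, activation=sigmoid):
    """sigma(sum_k [h]_k) for the pooled graph representation h."""
    return activation(sum(graph_representation(node_reps)))


node_reps = [[1.5, 2.0], [0.5, -1.0], [-3.0, 0.0]]  # hypothetical h_i values
print(graph_representation(node_reps))        # [-1.0, 1.0]
print(graph_prediction(node_reps, identity))  # 0.0
print(node_prediction(node_reps[0]))          # about 0.9707, i.e. sigmoid(3.5)
```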

Once the model is trained, it can be fully described by its univariate functions $\rho$ and $\{f_k\}_{k=1}^{d}$.

From the definitions provided above, it follows that the entire model can be represented with just a few figures, thus providing global interpretability. For local explanations, it is feasible to examine the contribution of each feature and each node to the predictions. Furthermore, consider the following graph representation:

$$\begin{aligned}
[\mathbf{h}]_k = \sum_{i=1}^{N} [\mathbf{h}_i]_k &= \sum_{i=1}^{N}\sum_{j=1}^{N} \frac{1}{\#\text{dist}_i(j,i)} \cdot \rho\left(\frac{1}{1+\text{dist}(j,i)}\right) \cdot f_k\left([\mathbf{x}_j]_k\right) \\
&= \sum_{j=1}^{N} f_k\left([\mathbf{x}_j]_k\right) \sum_{i=1}^{N} \frac{1}{\#\text{dist}_i(j,i)} \cdot \rho\left(\frac{1}{1+\text{dist}(j,i)}\right).
\end{aligned}$$

This formulation reveals how each node and each feature contribute to the overall graph representation. The terms of the first expression convey the effect of every pair of nodes, while the second expression factors the contribution of node $j$ to feature $k$ into its transformed feature value $f_k([\mathbf{x}_j]_k)$ and a weight that depends only on the graph structure and $\rho$, thereby conveying the influence of individual features across all nodes. Therefore, the model facilitates a detailed understanding of local behavior from multiple perspectives.
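The factorization in the second line can be turned into a per-node attribution. The sketch below, again assuming a distance matrix `dist[j][i]` and plain-Python stand-ins for $\rho$ and $f_k$ (hypothetical names and values), computes a structural weight, written $w_j=\sum_{i}\frac{1}{\#\text{dist}_i(j,i)}\rho\left(\frac{1}{1+\text{dist}(j,i)}\right)$ here for illustration, once per node and multiplies it by $f_k([\mathbf{x}_j]_k)$. Summing the resulting table over nodes recovers $[\mathbf{h}]_k$, which for the toy graph equals the sum of the per-node representations computed directly from the definition.

```python
import math
from collections import Counter


def structural_weights(dist, rho):
    """w_j = sum_i (1 / #dist_i(j, i)) * rho(1 / (1 + dist(j, i))).

    Depends only on the graph distances and rho, not on node features.
    dist[j][i] holds dist(j, i), with math.inf when no path exists.
    """
    n = len(dist)
    # counts[i][l]: number of nodes at distance l from node i.
    counts = [Counter(dist[j][i] for j in range(n)) for i in range(n)]
    weights = []
    for j in range(n):
        w = 0.0
        for i in range(n):
            d = dist[j][i]
            scaled = 0.0 if d == math.inf else 1.0 / (1.0 + d)
            w += rho(scaled) / counts[i][d]
        weights.append(w)
    return weights


def node_contributions(features, dist, rho, shape_functions):
    """contrib[j][k] = f_k([x_j]_k) * w_j, so that [h]_k = sum_j contrib[j][k]."""
    w = structural_weights(dist, rho)
    return [[f_k(x_j[k]) * w[j] for k, f_k in enumerate(shape_functions)]
            for j, x_j in enumerate(features)]


INF = math.inf
# Path graph 0 - 1 - 2 plus an isolated node 3; hypothetical features and functions.
D = [[0, 1, 2, INF], [1, 0, 1, INF], [2, 1, 0, INF], [INF, INF, INF, 0]]
features = [[1.0, 2.0], [1.0, 0.0], [1.0, 6.0], [1.0, 3.0]]
rho = lambda t: t
shape_functions = [lambda v: 1.0, lambda v: v]
contrib = node_contributions(features, D, rho, shape_functions)
h = [sum(row[k] for row in contrib) for k in range(len(shape_functions))]
print(h)  # graph representation [h]_1, [h]_2 (37/6 and 47/3 here)
```

A row of `contrib` gives one node's contribution to every entry of the graph representation, and a column gives one feature's contribution across all nodes.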

The functions $\rho$ and $\{f_k\}_{k=1}^{d}$ may be implemented using a variety of neural network architectures. In our experiments, we employed multi-layer perceptrons (MLPs) with ReLU activations to implement these functions. Nonetheless, other alternatives are viable, such as learnable spline activations, which yield smoother shape functions [46]. Additionally, it is feasible to use a separate distance network for each feature to increase the model's capacity. Specifically, rather than a single function $\rho$, one can train a distinct function $\rho_k$ for each feature $k$, which weights the contribution of that feature according to its node's distance. For graph-level tasks, additional feature networks may be applied before aggregating the graph's representation vector, akin to a readout layer in GNNs. These extensions, along with a tensor formulation of the computation that enables efficient GPU utilization, are further elaborated in the Appendix.
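As a rough illustration of such a univariate network, the sketch below implements a tiny scalar-to-scalar ReLU MLP in plain Python. The hidden width, initialization, and class name `ScalarMLP` are assumptions made for illustration, and training is omitted; in practice one would use a deep learning framework and back-propagation.

```python
import random


class ScalarMLP:
    """Hypothetical one-hidden-layer ReLU MLP mapping a scalar to a scalar,
    standing in for a shape function f_k or the distance function rho.
    Weights are randomly initialized here; training is not shown."""

    def __init__(self, hidden=8, seed=0):
        rng = random.Random(seed)
        self.w1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
        self.b1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
        self.w2 = [rng.gauss(0.0, 1.0) / hidden ** 0.5 for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, x: float) -> float:
        activations = [max(0.0, w * x + b) for w, b in zip(self.w1, self.b1)]
        return sum(v * a for v, a in zip(self.w2, activations)) + self.b2


f = ScalarMLP(hidden=8, seed=1)
# Because f is univariate, evaluating it on a grid is all that is needed to plot it.
curve = [(x / 10, f(x / 10)) for x in range(11)]
print(curve[0], curve[-1])
```

A one-hidden-layer ReLU MLP is piecewise linear in its input, which is why the shape plots of such models consist of straight segments; smoother alternatives change only the network, not the additive structure.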

For multiclass classification involving $C$ classes, we configure the final layers of the feature shape functions, $f_k(x;\theta_k): \mathbb{R}\rightarrow\mathbb{R}^{C\times 1}$, and of the distance function, $\rho(x;\theta): \mathbb{R}\rightarrow\mathbb{R}^{C\times 1}$, to accommodate the required dimensionality. The transformed features and the distance weights are combined using an element-wise multiplication, denoted by $\odot$, as follows:

$$[\mathbf{h}_i]_k = \sum_{j=1}^{N} \frac{1}{\#\text{dist}_i(j,i)} \cdot \rho\left(\frac{1}{1+\text{dist}(j,i)}\right) \odot f_k\left([\mathbf{x}_j]_k\right).$$

For prediction purposes, the sum operator is applied independently across the dimensions corresponding to each class, and a softmax is employed as the activation function.
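The multiclass variant changes only the shapes of the function outputs. In the sketch below, $\rho$ and each $f_k$ are hypothetical Python callables returning lists of $C$ per-class values, and `dist[j][i]` is an assumed precomputed distance matrix; per-class sums over $k$ are followed by a numerically stable softmax (subtracting the maximum logit).

```python
import math
from collections import Counter


def multiclass_node_representation(i, features, dist, rho, shape_functions):
    """[h_i]_k = sum_j (1 / #dist_i(j, i)) * (rho(1 / (1 + dist(j, i))) elementwise-times f_k([x_j]_k)).

    rho and each f_k return a list of C per-class values.
    Returns d entries, each a list of length C.
    """
    n = len(features)
    counts = Counter(dist[j][i] for j in range(n))
    h_i = []
    for k, f_k in enumerate(shape_functions):
        acc = None
        for j in range(n):
            d = dist[j][i]
            scaled = 0.0 if d == math.inf else 1.0 / (1.0 + d)
            term = [r * f / counts[d] for r, f in zip(rho(scaled), f_k(features[j][k]))]
            acc = term if acc is None else [a + t for a, t in zip(acc, term)]
        h_i.append(acc)
    return h_i


def predict_class_probabilities(h_i):
    """Sum over features k separately for each class, then apply a softmax."""
    logits = [sum(per_class) for per_class in zip(*h_i)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Two connected nodes, one feature, C = 3 classes (all values hypothetical).
D2 = [[0, 1], [1, 0]]
rho = lambda t: [t, t, 1.0]
shape_functions = [lambda v: [v, 0.0, -v]]
h0 = multiclass_node_representation(0, [[1.0], [2.0]], D2, rho, shape_functions)
print(h0)  # [[2.0, 0.0, -3.0]]
print(predict_class_probabilities(h0))
```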

4 Intelligibility

In this section, we demonstrate the intelligibility of GNAN through visualizations. Each GNAN model is characterized by its learned univariate shape functions $\rho$ and $\{f_k\}_{k=1}^{d}$, and can thus be depicted as a set of illustrative figures. Below, we present examples of such figures and explain their utility in generating insights. Our focus in this section is on global interpretability, as local interpretability can be obtained in analogous ways. We showcase GNAN on two datasets, with additional examples detailed in the Appendix.

Figure 1: Visualization of the distance and feature functions learned on Mutagenicity. As the features are binary, the feature functions are evaluated only on the value $1$. These plots provide an exact description of how the functions transform their inputs and a global explanation of how the model uses the distances and features.

Our initial examples focus on the task of detecting mutation-causing molecules using the Mutagenicity dataset [47]. In this task, molecules are modeled as graphs where nodes correspond to atoms and edges to connections between these atoms. Each atom type is represented by a 14-dimensional one-hot encoding. A GNAN model trained on this dataset is illustrated in Figure 1. On the left, the function $\rho$ is presented, showing how distance impacts the prediction, with a clearly diminishing influence of more distant atoms. On the right, the shape functions for the features are displayed. Given that the features are binary, each shape function takes only two values: one when the feature is $0$ (indicating that the atom is not of the specified type), and another when the feature is $1$ (indicating that the atom is of the specified type). Defining $b=\sum_k f_k(0)$ as the bias term allows us to set $f_k(0)=0$ for each $k$, so that only $f_k(1)$ needs to be plotted. The plot reveals that atoms such as Ca, Na, and Li are predicted to correlate with an increased mutagenicity effect, whereas N and P atoms are predicted to be associated with a slight protective effect.

It is essential to emphasize that Figure 1 displays the entire model. Combined with the value of the bias term, which is $-5.6672$ in this case, every detail needed to understand and use this model for predictions is contained within this single figure. This stands in stark contrast to methods like feature importance, which offer only a limited view of a model. While the figure provides complete information about the model, presenting additional views can sometimes be helpful.

Figure 2 shows the products of the distance function and the feature shape functions as a heatmap. Each cell $(k,l)$ in the heatmap represents the value $\rho\left(1/(1+l)\right)\cdot f_k(1)$. This figure illustrates the interplay the model has learned between the graph's structure and the attributes of its nodes. As the task is binary classification, positive values in the heatmap contribute to classifying a molecule as mutagenic, whereas negative values contribute to classifying it as non-mutagenic.

This heatmap illustrates how atoms at specific distances influence the final outcome. For instance, it shows that the model has learned that the presence of a Ca atom (cell (Ca, 0)) or its proximity (cell (Ca, 1)) contributes to mutagenicity. The visualizations can also be used for debugging purposes. This can be crucial, for example, to ensure that the model is free from biases or to identify any discrepancies with existing scientific knowledge. If it is already known that Ca atoms actually have a negative effect on mutagenicity, users could identify and correct this misalignment in the model’s learning. Additionally, this detailed understanding allows users to select models that not only achieve high accuracy on the given samples but also align with prior knowledge, optimizing both performance and reliability.
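The heatmap in Figure 2 needs only $\rho$ evaluated at a few distances and the values $f_k(1)$. The sketch below builds such a table; the distance function $t\mapsto t^2$ and the feature values for Ca and N are hypothetical numbers chosen for illustration, not the ones learned on Mutagenicity.

```python
def contribution_heatmap(rho, f_at_one, max_distance):
    """Cell (k, l) = rho(1 / (1 + l)) * f_k(1) for binary features.

    rho: univariate distance function; f_at_one: dict mapping a feature
    name to f_k(1). Returns one row per feature over l = 0..max_distance.
    """
    rho_by_distance = [rho(1.0 / (1.0 + l)) for l in range(max_distance + 1)]
    return {name: [r * value for r in rho_by_distance]
            for name, value in f_at_one.items()}


# Hypothetical numbers for illustration only; not values learned on Mutagenicity.
heatmap = contribution_heatmap(lambda t: t ** 2, {"Ca": 2.0, "N": -0.5}, max_distance=2)
for name, row in heatmap.items():
    print(name, [round(v, 3) for v in row])
# Ca [2.0, 0.5, 0.222]
# N [-0.5, -0.125, -0.056]
```

Positive cells push toward the mutagenic class and negative cells away from it, so comparing a row against prior chemical knowledge is a direct check of what the model has learned.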

Figure 2: Visualization of products of the outputs of the distance function and the feature functions, trained on Mutagenicity. Each cell shows the exact contribution, positive or negative, of features at a certain distance to the prediction. Positive values (green) contribute to classifying a molecule as mutagenic, and negative values (red) contribute to classifying a molecule as non-mutagenic.

Interpreting multiclass prediction tasks poses significant challenges, as noted by Zhang et al. [6]. In this context, we showcase the interpretability of GNAN using the PubMed dataset [48]. This dataset comprises 19,717 scientific publications related to diabetes archived on PubMed and categorized into three distinct classes (type-1 diabetes, type-2 diabetes, and gestational diabetes). The dataset’s citation network includes 44,338 links. Each publication, represented as a node, is characterized by a TF/IDF weighted word vector derived from a dictionary containing 500 unique words.

Because there are three classes, we trained GNAN so that the outputs of the distance and feature functions are three-dimensional. In this setting it is informative to compare the three per-class functions, so, following [6], we draw them on the same figure. Figure 3 shows that the model relies mainly on the local neighborhood of each node: the more distant two nodes are, the less the model uses the information between them. The classes also differ: for type 2 diabetes, the contribution of a node decays toward 0 as its distance grows, whereas for type 1 and gestational diabetes, distant nodes have a negative effect.
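With $C$ classes, each shape function is a network with a $C$-dimensional output, and the per-class curves in Figure 3 are the columns of that output over a range of distances. A minimal NumPy sketch, assuming a one-hidden-layer ReLU MLP with random weights as a stand-in for the trained distance function; the class name `ShapeMLP` and its sizes are hypothetical:

```python
import numpy as np

class ShapeMLP:
    """Scalar-input shape function with a C-dimensional output (one value per class)."""

    def __init__(self, num_classes, hidden=16, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(size=(1, hidden))
        self.b1 = np.zeros(hidden)
        self.w2 = rng.normal(size=(hidden, num_classes))
        self.b2 = np.zeros(num_classes)

    def __call__(self, t):
        t = np.asarray(t, dtype=float).reshape(-1, 1)   # (n, 1)
        h = np.maximum(t @ self.w1 + self.b1, 0.0)      # ReLU hidden layer, (n, hidden)
        return h @ self.w2 + self.b2                    # (n, num_classes)

rho = ShapeMLP(num_classes=3)                  # untrained stand-in for the distance function
distances = np.arange(0, 11)                   # shortest-path distances 0..10
curves = rho(1.0 / (1.0 + distances))          # (11, 3): column c is the curve for class c
```

Plotting `curves[:, c]` against `distances` for each class `c` gives one curve per class on a shared figure.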

In Figure 4, we display the feature shape functions for nine selected features, demonstrating GNAN’s capability to learn complex, non-monotone functions such as those seen in the ’diet’ and ’hepat’ features. Observing these shape functions across the three classes simultaneously allows for an understanding of how different feature values are utilized by the model to distinguish among the classes. For instance, the shape function for the ’insulin’ feature reveals that the absence of this word in a document (i.e., feature values close to zero) does not significantly indicate the document’s class. However, as the frequency of ’insulin’ increases within the document, its impact on the prediction becomes more pronounced, though this effect varies distinctly between type 1 & 2 diabetes and gestational diabetes.

To visualize the contribution of a feature value at a specific distance, we employ a heatmap for each class, evaluating the products between the outputs of the feature function over the input range [0, 1] and the output of the corresponding distance function. Figure 5 exemplifies this visualization technique with the ’children’ feature. It is insightful to observe that the presence of the word ’children’ influences the predictions differently across the diabetes types. The model has learned that papers concerning type 1 diabetes seldom mention ’children’, nor do related papers. In contrast, the term frequently appears in the context of gestational diabetes. Further examples of GNAN visualizations are presented in the Appendix.
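Each per-class heatmap is an outer product, taken separately for every class, of a feature shape function evaluated on a grid over [0, 1] and the distance function evaluated at $1/(1+l)$. A NumPy sketch, assuming both functions are vectorized and return arrays of shape (n, C); `per_class_heatmaps`, `f_children`, and `rho` are hypothetical stand-ins, not the learned PubMed functions:

```python
import numpy as np

def per_class_heatmaps(feature_fn, rho, num_values=11, max_dist=5):
    """Return H with H[c, v, l] = f(x_v)[c] * rho(1 / (1 + l))[c]."""
    xs = np.linspace(0.0, 1.0, num_values)        # feature values in [0, 1]
    dists = np.arange(max_dist + 1)
    f_vals = feature_fn(xs)                       # (num_values, C)
    rho_vals = rho(1.0 / (1.0 + dists))           # (max_dist + 1, C)
    # Outer product over the value and distance axes, classes kept aligned.
    return np.einsum("vc,lc->cvl", f_vals, rho_vals)

# Hypothetical 3-class stand-ins for the learned functions.
f_children = lambda x: np.stack([-x, 0.2 * x, x], axis=1)
rho = lambda t: np.stack([t, t ** 2, np.sqrt(t)], axis=1)
H = per_class_heatmaps(f_children, rho)           # (3, 11, 6): one heatmap per class
```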

Figure 3: Visualization of the distance shape function learned on the PubMed dataset. Because the function’s output is three-dimensional, we plot it as three shape functions, one per class, on the same figure for comparison. The shape functions show that the model relies mainly on the local neighborhood of each node. They also show a difference between the classes: for type 2 diabetes, the contribution of a node decays toward 0 as its distance grows, whereas for type 1 and gestational diabetes, distant nodes have a negative effect.
Figure 4: Visualization of nine features’ shape functions, learned over the PubMed dataset.
Figure 5: Visualization of the products between the outputs of the ’children’ feature function over the input range [0, 1] and the outputs of the distance function, learned over the PubMed dataset.

5 Empirical evaluation

In this section, we evaluate GNAN on real-world graph and node labeling tasks, including large-scale, long-range, and heterophilous datasets (the implementation can be found at https://github.com/mayabechlerspeicher/Graph-Neural-Additive-Networks---GNAN). We compare GNAN to multiple commonly used black-box GNNs: GraphConv [49], GraphSAGE [30], the Graph Isomorphism Network (GIN) [33], the more expressive version of the Graph Attention Network (GATv2) [29, 50], and the Graph Transformer (GTransformer) [51]. We also evaluate FSGNN, which disentangles the node features from the graph structure [35]. Details of the hyperparameters tuned for each baseline are given in the Appendix. We used the following common benchmarks:

Node labeling tasks: Cora, Citeseer, PubMed, and ogb-arxiv [52, 53] are paper citation networks where the goal is to classify papers into one of several topics. The ogb-arxiv dataset is a large-scale network.
Cornell [54] and Tolokers [55] are heterophilous datasets. Cornell is a web-link network with the task of classifying nodes into one of five categories. The Tolokers dataset is based on data from the Toloka crowdsourcing platform. The nodes represent tolokers (workers) who have participated in at least one of 13 selected projects. An edge connects two tolokers if they have worked on the same task. The goal is to predict which tolokers have been banned in one of the projects. Node features are based on the workers’ profile information and task performance statistics.
Graph labeling tasks: NCI1, Proteins, Mutagen, and PTC [56] are datasets of chemical compounds and proteins. In each dataset, the goal is to classify graphs according to some property of interest.
The $\mu$, $\alpha$, and $\alpha_{HOMO}$ [57] datasets are long-range molecular property regression tasks on the large-scale QM9 molecular dataset.
Additional data information, including the data statistics, can be found in the Appendix.

Protocol

For all tasks, we used the splits, protocols, and metrics commonly used in the literature for each dataset; the complete protocols are given in the Appendix. We report accuracy for Cornell, Cora, Citeseer, PubMed, ogb-arxiv, Mutagenicity, PTC, NCI1, and Proteins; MAE for $\mu$, $\alpha$, and $\alpha_{HOMO}$; and ROC-AUC for Tolokers. For the node labeling tasks, we used the pre-defined splits and followed the common protocol for each dataset; the reported results are test-set averages over 5 or 10 random seeds. For the Proteins and NCI1 tasks, we followed the splits and the nested cross-validation protocol from [58]; the final reported result on these datasets is an average of 30 runs (10 folds and 3 random seeds). For Mutagenicity and PTC, we followed the splits and protocol from [39] and report the average accuracy and std of a 10-fold nested cross-validation.
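The aggregation over 30 runs (10 folds × 3 seeds) is a mean and standard deviation over all runs. A minimal sketch, assuming the per-run test scores are stored in an array of shape (seeds, folds); the text does not state whether the population (`ddof=0`) or sample (`ddof=1`) standard deviation is used, so it is left as a parameter, and both helper names are hypothetical:

```python
import numpy as np

def summarize_runs(scores, ddof=0):
    """Mean and std over all runs, e.g. scores of shape (3 seeds, 10 folds) = 30 runs."""
    flat = np.asarray(scores, dtype=float).reshape(-1)
    return flat.mean(), flat.std(ddof=ddof)

def format_result(mean, std, scale=100.0):
    """Format a fractional score as in Table 1, e.g. 0.732 with std 0.031 gives '73.2 ± 3.1'."""
    return f"{mean * scale:.1f} ± {std * scale:.1f}"
```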

Results

The results are presented in Table 1. GNAN was the best or second-best model in 9 of the 13 tasks we evaluated. In GNAN, each node gathers information directly from all other nodes, so information can flow between any pair of nodes, while the distance function $\rho$ modulates their influence according to distance. Consequently, GNAN avoids the information bottleneck (over-squashing) that Alon and Yahav [25] identified in message-passing GNNs, where information from an exponentially growing receptive field must be compressed into fixed-size node vectors. In particular, on the long-range tasks $\mu$, $\alpha$, and $\alpha_{HOMO}$, GNAN outperformed all evaluated baselines, in line with the findings of Alon and Yahav [25] that emphasize the benefits of capturing long-range information. Although intelligibility sometimes comes at the cost of accuracy, our findings suggest that enhancing intelligibility does not necessarily cause a significant loss in accuracy. It may seem surprising that GNAN, despite its limited capacity, matches the accuracy of more expressive GNNs. However, prior research indicates that even limited-capacity GNNs, such as linear GNNs, can achieve high accuracy on various real-world datasets [59, 58, 60], suggesting that some real-world graph problems are simpler than anticipated. Our results corroborate these observations.

Table 1: Evaluation of GNAN on node (top) and graph (bottom) tasks. The best and second-best models in each column are marked with * and †, respectively. We report accuracy and std for all tasks, except for the Tolokers dataset, where we report ROC-AUC and std, and the $\mu$, $\alpha$, and $\alpha_{HOMO}$ datasets, where we report MAE and std (lower is better).
Model Cornell Tolokers Cora Citeseer PubMed ogb-arxiv
GraphConv 65.9 ± 0.5 83.5 ± 0.7 81.3 ± 1.1 75.9 ± 2.0 85.9 ± 0.5 72.4 ± 0.1
GraphSAGE 75.9 ± 5.0 82.4 ± 0.4 81.4 ± 0.7 76.4 ± 0.8† 88.4 ± 0.4* 71.7 ± 0.2
GIN 69.0 ± 1.3 81.0 ± 0.4 80.0 ± 1.2 77.1 ± 1.9* 85.3 ± 0.9 73.8 ± 1.4
GATv2 72.5 ± 0.7 83.8 ± 1.1† 83.1 ± 0.9* 73.9 ± 1.5 84.4 ± 0.5 74.0 ± 2.1†
GTransformer 70.5 ± 1.7 83.3 ± 0.9 80.7 ± 0.5 76.0 ± 0.9 85.3 ± 1.6 73.1 ± 0.2
FSGNN 86.0 ± 4.1* 83.1 ± 0.6 83.0 ± 1.3† 76.2 ± 1.3 85.0 ± 1.3 72.9 ± 1.7
GNAN 85.7 ± 4.8† 84.5 ± 0.9* 81.1 ± 1.5 75.8 ± 0.6 86.9 ± 1.2† 74.1 ± 1.5*
Model $\mu$ $\alpha$ $\alpha_{HOMO}$ Proteins Mutagen PTC NCI1
GraphConv 2.91 ± 0.1 4.37 ± 0.5 1.46 ± 0.1 73.1 ± 1.6 64.3 ± 1.7 63.9 ± 5.0 76.5 ± 1.2
GraphSAGE 3.55 ± 0.2 4.51 ± 0.7 1.44 ± 0.2 73.0 ± 4.5 64.1 ± 0.3 67.1 ± 12.6* 76.0 ± 1.8
GIN 2.60 ± 0.1† 4.67 ± 0.5 1.42 ± 0.1 73.3 ± 4.0 69.4 ± 1.2 55.6 ± 11.1 80.0 ± 1.4
GATv2 2.72 ± 0.1 4.39 ± 0.6 1.41 ± 0.1† 73.5 ± 2.8† 72.0 ± 0.9 59.5 ± 2.1 80.4 ± 1.6†
GTransformer 2.90 ± 0.3 4.30 ± 0.5† 1.41 ± 0.2† 73.9 ± 1.5* 73.1 ± 0.9* 55.9 ± 3.5 80.5 ± 1.1*
FSGNN 3.57 ± 0.3 4.50 ± 0.4 1.44 ± 0.3 72.9 ± 2.1 66.9 ± 1.5 60.3 ± 7.2 79.7 ± 1.1
GNAN 2.55 ± 0.1* 4.28 ± 0.9* 1.40 ± 0.1* 73.2 ± 3.1 72.2 ± 1.0† 64.9 ± 7.1† 76.9 ± 1.2

6 Conclusion

In this work, we introduced the Graph Neural Additive Network (GNAN), a novel extension of the interpretable class of Generalized Additive Models, to accommodate graph data. GNAN is inherently interpretable, and provides both global and local explanations directly from its architecture, eliminating the need for post-hoc interpretations and thus enhancing transparency. Furthermore, GNAN demonstrates competitive performance with popular GNNs, showing that intelligibility does not necessarily entail a significant degradation in accuracy.

GNAN could be enhanced in several ways. To generate smoother shape functions, one could integrate techniques from the recently proposed Kolmogorov–Arnold Networks [46]. The capacity of GNAN could be increased by learning a separate distance function for each feature. Reducing capacity is also worth exploring, particularly in settings with many features, where regularization could be used to limit the number of shape functions used. Additionally, applying these techniques to biological network datasets, such as protein interaction networks, could support scientific discovery. We leave these and other directions for future work.

Acknowledgements

This work was supported by the Tel Aviv University Center for AI and Data Science (TAD) and the Israeli Science Foundation grants 1186/18 and 1437/22.

References

  • Zhou et al. [2020] Jie Zhou, Ganqu Cui, Shengding Hu, Zhengyan Zhang, Cheng Yang, Zhiyuan Liu, Lifeng Wang, Changcheng Li, and Maosong Sun. Graph neural networks: A review of methods and applications. AI open, 1:57–81, 2020.
  • Rudin [2019] Cynthia Rudin. Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead, 2019.
  • Goodman and Flaxman [2017] Bryce Goodman and Seth Flaxman. European union regulations on algorithmic decision-making and a “right to explanation”. AI magazine, 38(3):50–57, 2017.
  • Selbst and Powles [2018] Andrew Selbst and Julia Powles. “meaningful information” and the right to explanation. In conference on fairness, accountability and transparency, pages 48–48. PMLR, 2018.
  • European Parliament [2024] European Parliament. European parliament legislative resolution of 13 march 2024 on the proposal for a regulation of the european parliament and of the council on laying down harmonised rules on artificial intelligence (artificial intelligence act). OJ, 2024. URL https://www.europarl.europa.eu/doceo/document/TA-9-2024-0138_EN.pdf.
  • Zhang et al. [2019] Xuezhou Zhang, Sarah Tan, Paul Koch, Yin Lou, Urszula Chajewska, and Rich Caruana. Axiomatic interpretability for multiclass additive models. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 226–234, 2019.
  • Bilodeau et al. [2024] Blair Bilodeau, Natasha Jaques, Pang Wei Koh, and Been Kim. Impossibility theorems for feature attribution. Proceedings of the National Academy of Sciences, 121(2):e2304406120, 2024.
  • Vale et al. [2022] Daniel Vale, Ali El-Sharif, and Muhammed Ali. Explainable artificial intelligence (xai) post-hoc explainability methods: Risks and limitations in non-discrimination law. AI and Ethics, 2(4):815–826, 2022.
  • Wexler [2017] Rebecca Wexler. When a computer program keeps you in jail. The New York Times, 13:1, 2017.
  • McGough [2018] Michael McGough. How bad is sacramento’s air, exactly? google results appear at odds with reality, some say. Sacramento Bee, 7, 2018.
  • Ghassemi et al. [2021] Marzyeh Ghassemi, Luke Oakden-Rayner, and Andrew L Beam. The false hope of current approaches to explainable artificial intelligence in health care. The Lancet Digital Health, 3(11):e745–e750, 2021.
  • Lundberg and Lee [2017] Scott M Lundberg and Su-In Lee. A unified approach to interpreting model predictions. Advances in neural information processing systems, 30, 2017.
  • Ribeiro et al. [2016] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. "Why should I trust you?" Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining, pages 1135–1144, 2016.
  • Hooker et al. [2018] Sara Hooker, Dumitru Erhan, Pieter-Jan Kindermans, and Been Kim. Evaluating feature importance estimates. arXiv preprint arXiv:1806.10758, 2, 2018.
  • Friedman [2001] Jerome H Friedman. Greedy function approximation: a gradient boosting machine. Annals of statistics, pages 1189–1232, 2001.
  • Saeed and Omlin [2023] Waddah Saeed and Christian Omlin. Explainable ai (xai): A systematic meta-survey of current challenges and future opportunities. Knowledge-Based Systems, 263:110273, 2023.
  • Laberge et al. [2024] Gabriel Laberge, Yann Batiste Pequignot, Mario Marchand, and Foutse Khomh. Tackling the xai disagreement problem with regional explanations. In International Conference on Artificial Intelligence and Statistics, pages 2017–2025. PMLR, 2024.
  • Franzoni [2023] Valentina Franzoni. From black box to glass box: advancing transparency in artificial intelligence systems for ethical and trustworthy ai. In International Conference on Computational Science and Its Applications, pages 118–130. Springer, 2023.
  • Hastie and Tibshirani [1986] Trevor Hastie and Robert Tibshirani. Generalized Additive Models. Statistical Science, 1(3):297 – 310, 1986. doi: 10.1214/ss/1177013604. URL https://doi.org/10.1214/ss/1177013604.
  • Hastie and Tibshirani [1995] Trevor Hastie and Robert Tibshirani. Generalized additive models for medical research. Statistical methods in medical research, 4(3):187–196, 1995.
  • Hastie and Tibshirani [1987] Trevor Hastie and Robert Tibshirani. Generalized additive models: Some applications. Journal of the American Statistical Association, 82(398):371–386, 1987. doi: 10.1080/01621459.1987.10478440.
  • Yee and Mitchell [1991] Thomas W Yee and Neil D Mitchell. Generalized additive models in plant ecology. Journal of vegetation science, 2(5):587–602, 1991.
  • Rigby and Stasinopoulos [2005] Robert A Rigby and D Mikis Stasinopoulos. Generalized additive models for location, scale and shape. Journal of the Royal Statistical Society Series C: Applied Statistics, 54(3):507–554, 2005.
  • Caruana et al. [2015] Rich Caruana, Yin Lou, Johannes Gehrke, Paul Koch, Marc Sturm, and Noemie Elhadad. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD international conference on knowledge discovery and data mining, pages 1721–1730, 2015.
  • Alon and Yahav [2021] Uri Alon and Eran Yahav. On the bottleneck of graph neural networks and its practical implications, 2021.
  • Agarwal et al. [2021] Rishabh Agarwal, Levi Melnick, Nicholas Frosst, Xuezhou Zhang, Ben Lengerich, Rich Caruana, and Geoffrey Hinton. Neural additive models: Interpretable machine learning with neural nets, 2021.
  • Kipf and Welling [2017a] Thomas N. Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017a. URL https://openreview.net/forum?id=SJU4ayYgl.
  • Gilmer et al. [2017] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. Neural message passing for quantum chemistry, 2017.
  • Veličković et al. [2018] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph attention networks. In International Conference on Learning Representations, 2018.
  • Hamilton et al. [2018] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs, 2018.
  • Ying et al. [2021] Chengxuan Ying, Tianle Cai, Shengjie Luo, Shuxin Zheng, Guolin Ke, Di He, Yanming Shen, and Tie-Yan Liu. Do transformers really perform bad for graph representation?, 2021.
  • Wang et al. [2024] Chloe Wang, Oleksii Tsepa, Jun Ma, and Bo Wang. Graph-mamba: Towards long-range graph sequence modeling with selective state spaces, 2024.
  • Xu et al. [2019] Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=ryGs6iA5Km.
  • Baranwal et al. [2023] Aseem Baranwal, Kimon Fountoulakis, and Aukosh Jagannath. Optimality of message-passing architectures for sparse graphs. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=d1knqWjmNt.
  • Maurya et al. [2021] Sunil Kumar Maurya, Xin Liu, and Tsuyoshi Murata. Improving graph neural networks with simple architecture design, 2021.
  • Frasca et al. [2020] Fabrizio Frasca, Emanuele Rossi, Davide Eynard, Ben Chamberlain, Michael Bronstein, and Federico Monti. Sign: Scalable inception graph neural networks, 2020.
  • Bechler-Speicher et al. [2024a] Maya Bechler-Speicher, Ido Amos, Ran Gilad-Bachrach, and Amir Globerson. Graph neural networks use graphs when they shouldn’t, 2024a. URL https://arxiv.org/abs/2309.04332.
  • Chami et al. [2022] Ines Chami, Sami Abu-El-Haija, Bryan Perozzi, Christopher Ré, and Kevin Murphy. Machine learning on graphs: A model and comprehensive taxonomy. Journal of Machine Learning Research, 23(89):1–64, 2022.
  • Bechler-Speicher et al. [2024b] Maya Bechler-Speicher, Amir Globerson, and Ran Gilad-Bachrach. Tree-g: Decision trees contesting graph neural networks. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 38, pages 11032–11042, 2024b.
  • Luo et al. [2020] Dongsheng Luo, Wei Cheng, Dongkuan Xu, Wenchao Yu, Bo Zong, Haifeng Chen, and Xiang Zhang. Parameterized explainer for graph neural network, 2020.
  • Ying et al. [2019] Rex Ying, Dylan Bourgeois, Jiaxuan You, Marinka Zitnik, and Jure Leskovec. Gnnexplainer: Generating explanations for graph neural networks, 2019.
  • Amara et al. [2022] Kenza Amara, Rex Ying, Zitao Zhang, Zhihao Han, Yinan Shan, Ulrik Brandes, Sebastian Schemm, and Ce Zhang. Graphframex: Towards systematic evaluation of explainability methods for graph neural networks, 2022.
  • Yin et al. [2023] Jun Yin, Chaozhuo Li, Hao Yan, Jianxun Lian, and Senzhang Wang. Train once and explain everywhere: Pre-training interpretable graph neural networks. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=enfx8HM4Rp.
  • Yu et al. [2020] Junchi Yu, Tingyang Xu, Yu Rong, Yatao Bian, Junzhou Huang, and Ran He. Graph information bottleneck for subgraph recognition, 2020.
  • Wu et al. [2022] Ying-Xin Wu, Xiang Wang, An Zhang, Xiangnan He, and Tat-Seng Chua. Discovering invariant rationales for graph neural networks, 2022.
  • Liu et al. [2024] Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljačić, Thomas Y. Hou, and Max Tegmark. Kan: Kolmogorov-arnold networks, 2024.
  • Kazius et al. [2005] Jeroen Kazius, Ross McGuire, and Roberta Bursi. Derivation and validation of toxicophores for mutagenicity prediction. Journal of medicinal chemistry, 48(1):312–320, 2005.
  • Sen et al. [2008] Prithviraj Sen, Galileo Namata, Mustafa Bilgic, Lise Getoor, Brian Galligher, and Tina Eliassi-Rad. Collective classification in network data. AI magazine, 29(3):93–93, 2008.
  • Morris et al. [2021] Christopher Morris, Martin Ritzert, Matthias Fey, William L. Hamilton, Jan Eric Lenssen, Gaurav Rattan, and Martin Grohe. Weisfeiler and leman go neural: Higher-order graph neural networks, 2021.
  • Brody et al. [2022] Shaked Brody, Uri Alon, and Eran Yahav. How attentive are graph attention networks?, 2022.
  • Shi et al. [2021] Yunsheng Shi, Zhengjie Huang, Shikun Feng, Hui Zhong, Wenjin Wang, and Yu Sun. Masked label prediction: Unified message passing model for semi-supervised classification, 2021.
  • Yang et al. [2016] Zhilin Yang, William W. Cohen, and Ruslan Salakhutdinov. Revisiting semi-supervised learning with graph embeddings, 2016. URL https://arxiv.org/abs/1603.08861.
  • Hu et al. [2020] Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. Open graph benchmark: Datasets for machine learning on graphs, 2020. URL https://arxiv.org/abs/2005.00687.
  • Pei et al. [2020] Hongbin Pei, Bingzhe Wei, Kevin Chen-Chuan Chang, Yu Lei, and Bo Yang. Geom-gcn: Geometric graph convolutional networks. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=S1e2agrFvS.
  • Likhobaba et al. [2023] Daniil Likhobaba, Nikita Pavlichenko, and Dmitry Ustalov. Toloker Graph: Interaction of Crowd Annotators, 2023. URL https://github.com/Toloka/TolokerGraph.
  • Morris et al. [2020] Christopher Morris, Nils M. Kriege, Franka Bause, Kristian Kersting, Petra Mutzel, and Marion Neumann. Tudataset: A collection of benchmark datasets for learning with graphs. In ICML 2020 Workshop on Graph Representation Learning and Beyond (GRL+ 2020), 2020. URL www.graphlearning.io.
  • Ramakrishnan et al. [2014] Raghunathan Ramakrishnan, Pavlo Dral, Matthias Rupp, and Anatole von Lilienfeld. Quantum chemistry structures and properties of 134 kilo molecules. Scientific Data, 1, 08 2014. doi: 10.1038/sdata.2014.22.
  • Errica et al. [2022] Federico Errica, Marco Podda, Davide Bacciu, and Alessio Micheli. A fair comparison of graph neural networks for graph classification, 2022.
  • Wu et al. [2019] Felix Wu, Amauri Souza, Tianyi Zhang, Christopher Fifty, Tao Yu, and Kilian Weinberger. Simplifying graph convolutional networks. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 6861–6871. PMLR, 09–15 Jun 2019. URL https://proceedings.mlr.press/v97/wu19e.html.
  • Yang et al. [2023] Chenxiao Yang, Qitian Wu, Jiahua Wang, and Junchi Yan. Graph neural networks are inherently good generalizers: Insights by bridging gnns and mlps, 2023.
  • Paszke et al. [2019] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Pytorch: An imperative style, high-performance deep learning library, 2019.
  • Fey and Lenssen [2019] Matthias Fey and Jan Eric Lenssen. Fast graph representation learning with pytorch geometric, 2019.
  • Platonov et al. [2023] Oleg Platonov, Denis Kuznedelev, Michael Diskin, Artem Babenko, and Liudmila Prokhorenkova. A critical look at the evaluation of GNNs under heterophily: Are we really making progress? In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=tJbbQfw-5wv.
  • Kipf and Welling [2017b] Thomas N. Kipf and Max Welling. Semi-supervised classification with graph convolutional networks, 2017b.

Appendix A Efficient GNAN implementation with tensor products

We formulate GNAN for classification with $C$ classes; for regression, the same formulation holds with $C=1$. For notational convenience, we assume, without stating it explicitly, that every tensor whose last dimension is $C$ is permuted so that $C$ becomes its first dimension. This is necessary for the tensor multiplications to be valid.

GNAN linearly combines the outputs of at least $d+1$ neural networks: $m$ and $\{f_l\}_{l=1}^{d}$.

$f_l:\mathbb{R}\rightarrow\mathbb{R}^{1\times C}$ is a feature-transformation function that acts on the $l$-th feature of a given node $i$, i.e., $f_l(x_l^i)$.

$m:\mathbb{R}\rightarrow\mathbb{R}^{1\times C}$ is a distance-transformation function that acts on distances between nodes, i.e., $m(\mathrm{dist}(j,i))$.

Given the adjacency matrix $A$, we build the distance matrix $D$, where $D_{i,j}=\mathrm{dist}(j,i)$.
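For an unweighted graph, $D$ can be computed with one breadth-first search per node. A minimal sketch, assuming a dense 0/1 adjacency matrix of an undirected graph; marking disconnected pairs with infinity is an assumption, since the treatment of such pairs is not specified here, and `distance_matrix` is a hypothetical name:

```python
from collections import deque

import numpy as np

def distance_matrix(adj):
    """All-pairs shortest-path (hop) distances, one breadth-first search per node.

    adj -- dense (N, N) 0/1 adjacency matrix of an unweighted, undirected graph.
    Returns D with D[i, j] = dist(j, i) (symmetric for undirected graphs);
    pairs in different connected components get np.inf.
    """
    adj = np.asarray(adj)
    n = adj.shape[0]
    neighbors = [np.flatnonzero(adj[i]) for i in range(n)]
    D = np.full((n, n), np.inf)
    for src in range(n):
        D[src, src] = 0.0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in neighbors[u]:
                if np.isinf(D[src, v]):          # first visit in BFS = shortest distance
                    D[src, v] = D[src, u] + 1.0
                    queue.append(v)
    return D
```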

We denote by $M$ the matrix of transformed distances obtained by applying $m$ to each entry of $D$, i.e., $M_{i,j}=m(\mathrm{dist}(j,i))$, with $M\in\mathbb{R}^{C\times N\times N}$.

We denote by $F$ the matrix of transformed features obtained by applying the corresponding $f_l$ to feature $l$ of each node in the graph, i.e., $F_{il}=f_l(x_l^i)$, with $F\in\mathbb{R}^{C\times N\times d}$.

Any GNAN, for both node and graph tasks, first computes the matrix $M\cdot F\in\mathbb{R}^{C\times N\times d}$.

The next step depends on the task.

A.1 Node Tasks

We aggregate the transformed features, weighted by the transformed distances, by summing each row of $M\otimes F$ over the feature dimension:

$$\phi(i)=[M\otimes F\otimes\mathbb{1}_{C\times d\times 1}]_i=\sum_{l=1}^{d}\sum_{j=1}^{N}m(\mathrm{dist}(j,i))\otimes f_l(x_l^j)$$
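The node-level computation reduces to building $M$ and $F$ with the shape functions and one batched contraction. A NumPy sketch, assuming each shape function is vectorized and returns an array of shape (n, C); `np.einsum` stands in for the tensor product, and the helper names, placeholder distances, and stand-in functions are hypothetical:

```python
import numpy as np

def build_M(D, m):
    """M[c, i, j] = m(D[i, j])[c]; m maps an array of n distances to shape (n, C)."""
    n = D.shape[0]
    vals = m(D.reshape(-1))                         # (N * N, C), row-major over (i, j)
    return vals.T.reshape(-1, n, n)                 # (C, N, N)

def build_F(X, fs):
    """F[c, j, l] = f_l(X[j, l])[c]; each f_l maps n values to shape (n, C)."""
    cols = [f(X[:, l]) for l, f in enumerate(fs)]   # d arrays of shape (N, C)
    return np.stack(cols, axis=-1).transpose(1, 0, 2)  # (N, C, d) -> (C, N, d)

def node_outputs(M, F):
    """phi[c, i] = sum_l sum_j M[c, i, j] * F[c, j, l], i.e. row sums of M @ F per class."""
    return np.einsum("cij,cjl->ci", M, F)

# Hypothetical stand-ins for the learned shape functions (C = 3 classes, d = 4 features).
rng = np.random.default_rng(0)
N, d, C = 5, 4, 3
D = rng.integers(0, 4, size=(N, N)).astype(float)  # placeholder distances
X = rng.random((N, d))
m = lambda t: np.stack([np.exp(-t), 1.0 / (1.0 + t), -t], axis=1)
fs = [lambda x, a=a: np.stack([a * x, x ** 2, np.sin(x)], axis=1) for a in range(d)]

phi = node_outputs(build_M(D, m), build_F(X, fs))   # (C, N): one C-vector per node
```

The `"cij,cjl->ci"` contraction carries out the double sum over nodes $j$ and features $l$ in one call, with the class axis kept as a batch dimension.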

A.2 Graph Tasks

For graph classification, we set $m$ and $f$ to output a scalar and aggregate the transformed features over the nodes, i.e., over the rows of $M\cdot F$, to form a fixed-size vector of size $d$:

$$\bar{\Phi}(G)=\mathbb{1}_{C\times 1\times N}\otimes M\otimes F\in\mathbb{R}^{C\times 1\times d}$$

Then, we can apply another readout NAM:

$$\Phi(G)=\bar{F}(\bar{\Phi}(G))$$

where $\bar{F}$ denotes the features transformed by the functions $\{\bar{f}_l\}_{l=1}^{d}$, with $\bar{f}_l:\mathbb{R}\rightarrow\mathbb{R}^{1\times C}$.

We can also simply sum over the outputs. In that case, we set $f$ and $m$ to output a vector of dimension $C$:

$$\Phi(G)=\sum_{i=1}^{N}\phi(i)=\mathbb{1}_{C\times 1\times N}\otimes M\otimes F\otimes\mathbb{1}_{C\times d\times 1}$$
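Both graph-level variants can be written as further contractions of the same tensors. A NumPy sketch, assuming `M` and `F` are already built with shapes (C, N, N) and (C, N, d); the helper names and placeholder tensors are hypothetical, and summing the readout outputs over features is an assumption about how the readout NAM is aggregated:

```python
import numpy as np

def graph_embedding(M, F):
    """Phi_bar[c, l] = sum_i sum_j M[c, i, j] * F[c, j, l]: M @ F summed over nodes, (C, d)."""
    return np.einsum("cij,cjl->cl", M, F)

def graph_output_sum(M, F):
    """Phi[c] = sum_i phi(i)[c]: Phi_bar additionally summed over the d features, (C,)."""
    return graph_embedding(M, F).sum(axis=-1)

def graph_output_readout(M, F, readout_fns):
    """Readout NAM: apply f_bar_l to entry l of Phi_bar and sum the results.

    Assumes scalar-output m and f (leading dimension 1 in M and F); each
    hypothetical f_bar_l maps an array of n values to shape (n, C).
    """
    phi_bar = graph_embedding(M, F)[0]                  # (d,)
    return sum(fb(phi_bar[l:l + 1])[0] for l, fb in enumerate(readout_fns))

rng = np.random.default_rng(1)
M, F = rng.random((2, 4, 4)), rng.random((2, 4, 3))    # placeholder tensors: C=2, N=4, d=3
Phi = graph_output_sum(M, F)                           # equals the sum of per-node outputs
```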

Appendix B Extensions and ablations

In Section 3 we mentioned several possible extensions for GNAN. Here, we discuss them in detail.

Readout layer for graph tasks

In graph tasks, after aggregating the node representations, it is possible to apply another transformation before aggregating over the entries of the graph representation. There are many ways to do so, and we did not explore all of them. We explored applying another set of feature functions to each feature of the graph representation vector. This approach increases the capacity of the model at the cost of interpretability: the additional feature functions must be plotted separately, and the product of the feature functions and the distance function no longer affects the final output directly, but only through another feature function. Empirically, this approach did not improve performance relative to the results reported in Section 5.

Splines

It is possible to learn splines for the activations in each feature network to obtain smoother shape functions [46]. The Tolokers example presented in Appendix C shows that the learned feature shape functions can be smooth even though we use ReLU activations. However, in other cases, such as the PubMed example presented in the main paper, many of the learned feature functions are step functions. The model could therefore likely benefit from spline activations that smooth its shape functions.
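To illustrate the smoothness argument, the sketch below evaluates a one-dimensional cubic spline in the truncated power basis, one standard spline parameterization swapped in here for illustration. It is not necessarily the parameterization of [46]. In a learned setting, the coefficients would be trainable. A ReLU network with scalar input is piecewise linear, while this form is twice continuously differentiable. The function name and arguments are hypothetical.

```python
import numpy as np

def cubic_spline_shape(x, poly_coeffs, knots, knot_coeffs):
    """Evaluate f(x) = sum_p a_p x^p (p = 0..3) + sum_j b_j * max(x - t_j, 0)^3.

    poly_coeffs: [a_0, a_1, a_2, a_3]; knots: [t_j]; knot_coeffs: [b_j]
    (all hypothetical stand-ins for learned parameters).
    """
    x = np.asarray(x, dtype=float)
    out = np.polynomial.polynomial.polyval(x, poly_coeffs)  # cubic polynomial part
    for t, b in zip(knots, knot_coeffs):
        out = out + b * np.maximum(x - t, 0.0) ** 3  # C2-smooth bend at knot t
    return out
```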

Normalization

In GNAN, we normalize the weight of each node at distance $l$ by the number of nodes at distance $l$, so that the cumulative weight of all nodes at distance $l$ is $\rho(1/(1+l))$. We examined the effect of removing this normalization. Without normalization, the scale of the loss is drastically larger, so more epochs are required to fit the data. As a result, with the fixed number of epochs used in our experiments (1000), removing the normalization decreases accuracy.
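The normalization can be sketched as follows for a single target node. This assumes integer shortest-path distances to every node (a connected graph) and uses an arbitrary callable in place of the learned distance function $\rho$. The function name is hypothetical.

```python
import numpy as np

def normalized_distance_weights(dist_row, rho):
    """Per-node weights w.r.t. one target node, normalized per distance (sketch).

    dist_row: integer shortest-path distances from the target node to all nodes
    rho: stand-in for the learned distance function (any callable here)
    A node at distance l gets rho(1/(1+l)) / n_l, where n_l is the number of
    nodes at distance l, so the weights at distance l sum to rho(1/(1+l)).
    """
    dist_row = np.asarray(dist_row)
    weights = np.zeros(dist_row.shape, dtype=float)
    for l in np.unique(dist_row):
        mask = dist_row == l
        weights[mask] = rho(1.0 / (1.0 + l)) / mask.sum()
    return weights
```

For distances `[0, 1, 1, 1, 2]` and `rho = lambda z: z`, the three nodes at distance 1 share a total weight of 0.5. Without normalization, each of them would get 0.5, so the total would grow with the size of the neighborhood. This is consistent with the larger loss scale described above.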

Appendix C Additional explanations examples

In the main paper, we presented two examples of explanations over two datasets with different properties. In this section, we present additional explanations and examples we could not fit into the main text due to space limitations.

C.1 Additional PubMed heatmaps

In the main text, we presented the heatmaps for the 'children' feature. Here we provide heatmaps for two additional features, 'fat' and 'young', in Figures 6 and 7, respectively.

Figure 6: Visualization of the products between the outputs of the ’fat’ feature function over the input range [0, 1] and the outputs of the distance function, learned over the PubMed dataset
Figure 7: Visualization of the products between the outputs of the ’young’ feature function over the input range [0, 1] and the outputs of the distance function, learned over the PubMed dataset
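The heatmaps are grids of products between a feature function, evaluated over [0, 1], and the distance function. A minimal sketch of how such a grid can be computed follows. It assumes the distance function is evaluated at $1/(1+l)$, as in the normalization described in Appendix B. Both functions are hypothetical stand-ins for the learned shape functions, and the grid resolution is arbitrary.

```python
import numpy as np

def contribution_grid(feature_fn, distance_fn, max_distance, num_points=11):
    """Grid of products distance_fn(1/(1+l)) * feature_fn(x) (sketch).

    Rows index distances l = 0..max_distance; columns index feature values x
    on an evenly spaced grid over [0, 1].
    """
    xs = np.linspace(0.0, 1.0, num_points)
    ls = np.arange(max_distance + 1)
    feat = feature_fn(xs)                 # shape (num_points,)
    dist = distance_fn(1.0 / (1.0 + ls))  # shape (max_distance + 1,)
    return np.outer(dist, feat), xs, ls   # cell [l, x] = dist[l] * feat[x]
```

Each cell gives the signed contribution of a feature value at a given distance, which is what the color scale of the figures encodes.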

C.2 Tolokers - Binary classification with binary and continuous features

The Tolokers dataset is based on data from the Toloka crowdsourcing platform. The nodes represent tolokers (workers) who have participated in at least one of 13 selected projects. An edge connects two tolokers if they have worked on the same task. The goal is to predict which tolokers have been banned in one of the projects. Node features are based on the workers' profile information and task performance statistics. Each node in the graph is associated with 9 features: 4 continuous features in the range [0, 1] and 5 binary features. Figure 8 shows the feature shape functions learned by GNAN, and Figure 9 shows the learned distance shape function. Figures 10 and 11 present heatmaps of the products of the feature shape functions and the distance function.

Figure 8: Visualization of the feature functions learned over the Tolokers dataset.
Figure 9: Visualization of the distance shape function learned on the Tolokers dataset.
Figure 10: Visualization of the products between the outputs of the continuous feature functions over the input range [0, 1] and the outputs of the distance function, learned over the Tolokers dataset.
Figure 11: Visualization of products of the outputs of the distance function and the feature functions, trained on Tolokers. Each cell shows the exact contribution, positive or negative, of features at a certain distance to the prediction. Positive values (green) contribute to classifying a toloker as ’banned’, and negative values (red) contribute to classifying a toloker as ’not banned’.

Appendix D Additional experimental details

All our baselines are implemented using PyTorch [61] and PyTorch-Geometric [62].

D.1 Dataset information

Here we provide additional information about the datasets used in Section 5. The data statistics are given in Table 2.

Proteins [56] is a dataset of 1,113 graphs representing proteins. The goal is to predict whether a protein is an enzyme or not.

NCI1 [56] is a dataset consisting of 4,110 graphs representing chemical compounds. Vertices and edges represent atoms and the chemical bonds between them. The graphs are divided into two classes according to their ability to suppress or inhibit tumor growth.

Mutagenicity [56] is a dataset consisting of 4,337 chemical compounds of drugs, divided into two classes: mutagen and non-mutagen. A mutagen is a compound that changes genetic material, such as DNA, and increases mutation frequency.

PTC [56] is a dataset consisting of 344 chemical compounds divided into two classes according to their carcinogenicity for rats.

Cornell [54] is a heterophilic webpage dataset collected from the computer science department at Cornell University. Nodes represent web pages, and edges are hyperlinks between them. The task is to classify the nodes into one of five categories.

Table 2: Statistics of the real-world datasets used in our evaluation.
Dataset | # Graphs | Avg # Nodes | Avg # Edges | # Node Features | # Classes
Proteins [56] | 1,113 | 39.06 | 72.82 | 3 | 2
NCI1 [56] | 4,110 | 29.87 | 32.3 | 37 | 2
Mutagenicity [56] | 4,337 | 30.32 | 30.37 | 7 | 2
PTC [56] | 344 | 14 | 14 | 19 | 2
QM9 [57] | 130,831 | 18 | 37.3 | 11 | -
Cora [52] | 1 | 2,708 | 10,556 | 1,433 | 7
Citeseer [52] | 1 | 3,327 | 9,104 | 3,703 | 6
PubMed [52] | 1 | 19,717 | 88,648 | 500 | 3
ogb-arxiv [53] | 1 | 169,343 | 1,166,243 | 128 | 40
Cornell [54] | 1 | 183 | 295 | 1,703 | 5
Tolokers [55] | 1 | 11,758 | 519,000 | 10 | 2

D.2 Protocols

ogb-arxiv

OGB datasets are large-scale datasets provided in the Open Graph Benchmark (OGB) paper [53], each with pre-defined train and test splits and its own metric and protocol. As is common in the literature when evaluating on OGB datasets, we followed the pre-defined metric and protocol for ogb-arxiv. The metric is accuracy. We ran GNAN 10 times and report the mean accuracy and standard deviation over the runs.

Cornell

For the Cornell dataset, we used the splits and protocol from [54] and report the test accuracy averaged over 10 runs, using the best hyperparameters found on the validation set.

Tolokers

For the Tolokers dataset, we followed the protocol and pre-defined splits from [55, 63]. The reported result is an average over a 10-fold nested cross-validation.

Cora, Citeseer and PubMed

Following [64, 30, 29], for the Cora, Citeseer and PubMed datasets, we tuned the hyperparameters on Cora using the pre-defined splits from [64]. For all three datasets, we report the test accuracy averaged over 5 runs, using the hyperparameters that achieved the best validation accuracy on Cora.

Proteins, NCI1

We used 10-fold nested cross-validation with the splits and protocol of Errica et al. [58]. The final reported result on these datasets is an average of 30 runs (10 folds and 3 random seeds).
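The structure of this protocol (outer folds for assessment, an inner split for model selection, several seeds per fold) can be sketched as follows. This is a simplified illustration, not a reproduction of Errica et al. [58]. The inner hold-out fraction, the fold assignment, and the callback `train_and_eval(train_idx, eval_idx, config, seed)` are hypothetical. Averaging the returned list gives one number over 10 folds × 3 seeds = 30 runs.

```python
import random

def k_fold_indices(n, k, seed):
    """Shuffle range(n) with a fixed seed and split it into k disjoint folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[f::k] for f in range(k)]

def nested_cv(n, configs, train_and_eval, k=10, seeds=(0, 1, 2), val_frac=0.1):
    """Return one test score per (outer fold, seed) pair (simplified sketch)."""
    scores = []
    folds = k_fold_indices(n, k, seed=0)
    for f, test_idx in enumerate(folds):
        # outer training portion: every fold except the current test fold
        train_pool = [i for g, fold in enumerate(folds) if g != f for i in fold]
        # inner hold-out split, used only to select a configuration
        n_val = max(1, int(val_frac * len(train_pool)))
        val_idx, train_idx = train_pool[:n_val], train_pool[n_val:]
        best = max(configs, key=lambda c: train_and_eval(train_idx, val_idx, c, seeds[0]))
        # assess the selected configuration on the untouched test fold, once per seed
        for seed in seeds:
            scores.append(train_and_eval(train_pool, test_idx, best, seed))
    return scores
```

The test fold never influences which configuration is selected, which is the point of nesting the two levels.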

Mutagenicity, PTC

We use the splits and protocols from [39] with 10-fold nested cross-validation. The final reported test accuracies are averages over the 10 test sets of the outer 10 folds.

D.3 Hyperparameters

All GNNs (excluding GNAN) use ReLU activations with {3, 5} layers and 64 hidden channels. They were trained with the Adam optimizer for up to 1000 epochs, with early stopping on the validation loss with a patience of 100 steps, a weight decay of 1e-4, a learning rate in {1e-3, 1e-4}, a dropout rate in {0, 0.5}, and a training batch size of 32.
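A minimal sketch of patience-based early stopping, reading "a patience of 100 steps" as 100 consecutive epochs without improvement of the validation loss. The strict-improvement criterion (no minimum delta) and the function name are assumptions. `val_losses` holds the per-epoch validation losses, at most 1000 of them here.

```python
def early_stopping_epoch(val_losses, patience=100):
    """Return (best_epoch, stop_epoch) under patience-based early stopping.

    Training stops once the validation loss has not improved for `patience`
    consecutive epochs; the model from best_epoch is the one kept.
    """
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses) - 1
```

For example, if the loss bottoms out at epoch 2 and never improves again, training stops at epoch 102 and the epoch-2 model is kept.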

In GNAN, all the feature and distance networks use ReLU activations with {3, 5} layers and {64, 32} hidden channels. They were trained with the Adam optimizer for 1000 epochs, with a weight decay in {0, 5e-4}, a learning rate in {1e-2, 1e-3}, and a dropout rate in {0, 0.6}.
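The GNAN search space above can be written as a small configuration grid. The values are copied from the text. The key names are hypothetical, and enumerating the full Cartesian product (32 configurations) is an assumption, since the search procedure is not specified.

```python
from itertools import product

# GNAN search space as listed above (key names are hypothetical)
GNAN_GRID = {
    "num_layers": [3, 5],
    "hidden_channels": [64, 32],
    "weight_decay": [0.0, 5e-4],
    "lr": [1e-2, 1e-3],
    "dropout": [0.0, 0.6],
}

def expand_grid(grid):
    """Yield one dict per combination of grid values (full Cartesian product)."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))
```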

D.4 Compute resources

All experiments ran on an NVIDIA GeForce RTX 3090 GPU.