Fishnets: Information-Optimal, Scalable Aggregation for Sets and Graphs
Abstract
Set-based learning is an essential component of modern deep learning and network science. Graph Neural Networks (GNNs) and their edge-free counterparts, Deepsets, have proven remarkably useful on ragged and topologically challenging datasets. The key to learning informative embeddings for set members is a specified aggregation function, usually a sum, max, or mean. We propose Fishnets, an aggregation strategy for learning information-optimal embeddings for sets of data for both Bayesian inference and graph aggregation. We demonstrate that i) Fishnets neural summaries can be scaled optimally to an arbitrary number of data objects, ii) Fishnets aggregations are robust to changes in data distribution, unlike standard deepsets, iii) Fishnets saturate Bayesian information content and extend to regimes where Markov Chain Monte Carlo (MCMC) techniques fail, and iv) Fishnets can be used as a drop-in aggregation scheme within GNNs. We show that by adopting a Fishnets aggregation scheme for message passing, GNNs can achieve state-of-the-art performance relative to architecture size on benchmark datasets, using a fraction of the learnable parameters and with faster training than existing architectures.
1 Introduction
Aggregating information from independent data in an optimal way is a fundamental problem in statistics and machine learning. On one hand, frequentist analyses need optimal estimators for data compression, while on the other, Bayesian analyses need small informative summaries for simulation-based inference (SBI) schemes (Cranmer et al., 2020). In a deep learning context, graph neural networks (GNNs) rely on aggregation schemes to pool information over large data structures, where each feature might be only weakly informative but can contribute a lot of information at the graph level for predictive or regression tasks (Zhou et al., 2020).
Until now, graph aggregation schemes have relied on simple, fixed operations such as mean, max, sum (Corso et al., 2020b; Kipf & Welling, 2017; Hamilton et al., 2017; Xu et al., 2019), or variance, or on trainable variants of these aggregators (Battaglia et al., 2018; Li et al., 2020). These aggregators are susceptible to generalisation issues in heterogeneous data aggregation, and may contribute to GNN “bottlenecking” over large aggregation neighborhoods (Alon & Yahav, 2021; Giovanni et al., 2024). We introduce a new optimal aggregation scheme grounded in information-theoretic principles. By leveraging the additive structure of the log-likelihood for independent data and the underlying Fisher curvature, we construct a learned summary space that asymptotically contains maximal information (Vaart, 1998; Coulton & Wandelt, 2023). We show that this formalism captures relevant information both for Bayesian inference and for edge aggregation in graphs.
Contributions. In this work we establish Fishnets, a new, information-optimal aggregation scheme for graph and set-based data. By explicitly learning inverse-Fisher weights in addition to neural score embeddings, we achieve i) asymptotic optimality, ii) scalability, and iii) robustness to changes in the generative distribution of the data and its noise characteristics. We construct optimal summary statistics of independent data for SBI applications and, using the same formalism, outperform leading networks on key benchmark GNN learning tasks with far smaller architectures and faster training.
Summary of Results. In Fig. LABEL:fig:foo we show how incorporating our aggregation scheme improves model convergence on the benchmark (LABEL:fig:gnn_benchmark) and realistic, noisy (LABEL:fig:gnn_noise) ogbn-proteins tasks. We provide further GNN benchmark results and model specifications in Table 1, and demonstrate that using Fishnets aggregation as a drop-in replacement in an existing GCN framework enables faster convergence and better performance with much smaller model architectures. In Section 4 we carefully explore the information capture of our aggregation, explaining this improvement by demonstrating information saturation, robustness, and scalability in a Bayesian context for increasingly difficult problems, and highlighting where existing aggregators fall short. In Section 5 we detail the GNN benchmarks and improved aggregator performance.
Figure: summary of results (panels: ogb dataset performance; ogb-proteins comparison).
2 Method: Optimal Aggregation of independent (heterogeneous) data
2.1 Fisher Information and Optimality Definitions
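We first define the notion of information optimality using the likelihood principle. Data $\mathbf d$ is related to some parameters (or quantities of interest) $\boldsymbol\theta$ via a log-likelihood $\ln\mathcal L(\mathbf d\,|\,\boldsymbol\theta)$. We would like to obtain a compression map $\mathbf d \mapsto \mathbf t$ from data to numbers $\mathbf t$ which preserves as much information about the parameters as possible. To quantify how informative this map is, we use the information inequality (Lehmann & Casella, 1998):

$$\mathrm{Cov}_{\boldsymbol\theta}(\mathbf t) \succeq \mathbf B\,\mathbf F^{-1}\mathbf B^{\top}, \tag{1}$$

where $B_{ij} = \partial\langle t_i\rangle/\partial\theta_j$ and the Fisher information matrix is

$$F_{ij} = \left\langle \frac{\partial\ln\mathcal L}{\partial\theta_i}\,\frac{\partial\ln\mathcal L}{\partial\theta_j}\right\rangle = -\left\langle \frac{\partial^2\ln\mathcal L}{\partial\theta_i\,\partial\theta_j}\right\rangle, \tag{2}$$

where the last equality holds under mild regularity conditions (Alsing & Wandelt, 2018). The Fisher matrix is the (expected) curvature of the log-likelihood, and is in general a function of the parameters. In the case where the compressed numbers $\mathbf t$ are unbiased estimators of the parameters, $\langle\mathbf t\rangle = \boldsymbol\theta$, so that $\mathbf B = \mathbb I$ and the information inequality (1) reduces to the Cramér-Rao bound (Cramér, 1946):

$$\mathrm{Cov}_{\boldsymbol\theta}(\mathbf t) \succeq \mathbf F^{-1}. \tag{3}$$

Maximising the Fisher information therefore lowers the minimum achievable variance of estimates of the quantities of interest $\boldsymbol\theta$. Alsing & Wandelt (2018) show that the score function, $\mathbf t = \nabla_{\boldsymbol\theta}\ln\mathcal L(\mathbf d\,|\,\boldsymbol\theta)\big|_{\boldsymbol\theta_*}$, saturates the lower bound of (1) around a fiducial point $\boldsymbol\theta_*$. We reproduce this proof in the general case over parameter space and relate information saturation to parameter estimators in Appendix A.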
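As a quick numerical sanity check of Eqs. (2) and (3), the sketch below uses a toy model that is not from this paper: i.i.d. Gaussian data with unknown mean $\mu$ and known $\sigma$, for which the Fisher information is $F = n/\sigma^2$. It estimates the variance of the score by Monte Carlo, the curvature of the log-likelihood by a finite difference, and the variance of the (unbiased) sample mean; each should match $F$ or $F^{-1}$.

```python
import numpy as np

def score(d, mu, sigma):
    """Score d lnL/d mu for i.i.d. N(mu, sigma^2) data with sigma known."""
    return np.sum(d - mu, axis=-1) / sigma**2

rng = np.random.default_rng(0)
mu_true, sigma, n, n_sims = 1.5, 2.0, 20, 100_000
d = rng.normal(mu_true, sigma, size=(n_sims, n))  # many datasets at the true mu

F = n / sigma**2  # analytic Fisher information for mu

# Eq. (2), first form: the variance of the score equals F
var_score = score(d, mu_true, sigma).var()

# Eq. (2), second form: minus the curvature of lnL, via a finite difference of the score
h = 1e-3
neg_hessian = -(score(d[0], mu_true + h, sigma) - score(d[0], mu_true - h, sigma)) / (2 * h)

# Eq. (3): the sample mean is unbiased, and its variance saturates the bound F^{-1}
var_mean = d.mean(axis=1).var()

print(var_score / F, neg_hessian / F, var_mean * F)  # each ratio should be close to 1
```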
2.2 Set-like Data Likelihoods
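Many inference problems consist of a set of data vectors, $\{\mathbf d_i\}_{i=1}^{n}$, which obey a global model controlled by parameters $\boldsymbol\theta$ and a possibly arbitrarily deep hierarchy of latent values, $\boldsymbol\eta$. The full data likelihood for the parameters of interest is given by the integral over latents,

$$\mathcal L(\{\mathbf d_i\}\,|\,\boldsymbol\theta) = \int \mathrm d\boldsymbol\eta\; p(\{\mathbf d_i\}\,|\,\boldsymbol\eta, \boldsymbol\theta)\, p(\boldsymbol\eta\,|\,\boldsymbol\theta). \tag{4}$$

When the data are independently distributed, their log-likelihood takes the form

$$\ln\mathcal L(\{\mathbf d_i\}\,|\,\boldsymbol\theta) = \sum_{i=1}^{n}\ln\mathcal L(\mathbf d_i\,|\,\boldsymbol\theta). \tag{5}$$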
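A maximum likelihood estimator can then be formed (iteratively) by the Fisher scoring method (Alsing & Wandelt, 2018):

$$\hat{\boldsymbol\theta} = \boldsymbol\theta_* + \mathbf F^{-1}(\boldsymbol\theta_*)\,\mathbf t(\boldsymbol\theta_*), \tag{6}$$

which requires knowledge of a fiducial point $\boldsymbol\theta_*$, the score, $\mathbf t(\boldsymbol\theta_*) = \nabla_{\boldsymbol\theta}\ln\mathcal L\,\big|_{\boldsymbol\theta_*}$, and the Fisher matrix, $\mathbf F(\boldsymbol\theta_*)$. For problems like linear regression, where the analytic forms of $\mathbf F$ and $\mathbf t$ are known, Eq. (6) gives the exact MLE for the parameters in a single iteration in the Gaussian approximation, given the dataset. In the case of independent data, both the score and the Fisher information are additive.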
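To illustrate the iteration in Eq. (6) on a problem where it is genuinely iterative, the sketch below applies Fisher scoring to a hypothetical Poisson model with log-rate $\theta$ (our choice of example, not one used in this paper). The score is $\sum_i (k_i - e^{\theta})$ and the Fisher information is $n e^{\theta}$, so the iteration should converge to the analytic MLE, the log of the sample mean. The function names are illustrative.

```python
import numpy as np

def fisher_scoring(score_fn, fisher_fn, theta_star, n_iter=20):
    """Iterate Eq. (6): theta <- theta + F(theta)^{-1} t(theta), starting at theta_star."""
    theta = np.atleast_1d(np.asarray(theta_star, dtype=float))
    for _ in range(n_iter):
        t = np.atleast_1d(score_fn(theta))
        F = np.atleast_2d(fisher_fn(theta))
        theta = theta + np.linalg.solve(F, t)
    return theta

# Hypothetical example: i.i.d. Poisson counts k_i with log-rate theta
rng = np.random.default_rng(1)
k = rng.poisson(lam=3.0, size=100)

def poisson_score(th):
    return np.sum(k - np.exp(th[0]))     # d lnL / d theta

def poisson_fisher(th):
    return len(k) * np.exp(th[0])        # -E[d^2 lnL / d theta^2]

theta_hat = fisher_scoring(poisson_score, poisson_fisher, theta_star=0.0)
print(theta_hat[0], np.log(k.mean()))    # analytic MLE: log of the sample mean
```

For a linear-Gaussian model the same loop converges after a single step, because the score is linear in the parameters and the Fisher matrix is constant.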
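Taking the gradient of the log-likelihood with respect to the parameters, the score for the full dataset is the sum of the scores of the individual data points:

$$\mathbf t = \nabla_{\boldsymbol\theta}\ln\mathcal L(\{\mathbf d_i\}\,|\,\boldsymbol\theta) = \sum_{i=1}^{n}\nabla_{\boldsymbol\theta}\ln\mathcal L(\mathbf d_i\,|\,\boldsymbol\theta) \equiv \sum_{i=1}^{n}\mathbf t_i. \tag{7}$$

Taking the gradient again yields the Hessian, whose negative expectation is the Fisher information matrix (Amari, 2021; Vaart, 1998) for the dataset,

$$\mathbf F = -\left\langle\nabla_{\boldsymbol\theta}\nabla_{\boldsymbol\theta}^{\top}\ln\mathcal L(\{\mathbf d_i\}\,|\,\boldsymbol\theta)\right\rangle = \sum_{i=1}^{n}\mathbf F_i, \qquad \mathbf F_i = -\left\langle\nabla_{\boldsymbol\theta}\nabla_{\boldsymbol\theta}^{\top}\ln\mathcal L(\mathbf d_i\,|\,\boldsymbol\theta)\right\rangle, \tag{8}$$

which is likewise a sum of the Fisher matrices of the individual data. Once the score and Fisher matrix for a dataset are known, the two can be combined to form a pseudo-maximum likelihood estimate (MLE) for the target parameters following (6). Therefore, constructing optimal embeddings of independent data with respect to specific quantities of interest only requires aggregating the scores and Fishers, and combining them as in (6). In general, however, explicit forms for the per-datum likelihood may not be known. In this case, as we show in the following section, we can parameterize and learn the score and Fisher using neural networks.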
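The additivity in Eqs. (7) and (8) can be made concrete with a toy model of our choosing: independent measurements $d_i \sim \mathcal N(\mu, \sigma_i^2)$ of a common mean with heterogeneous, known noise levels. Per datum, $t_i = (d_i - \mu_*)/\sigma_i^2$ and $F_i = 1/\sigma_i^2$, so a single step of Eq. (6) reproduces the inverse-variance weighted mean, and summing over arbitrary chunks of the data gives the same aggregate. The function name is hypothetical.

```python
import numpy as np

def per_datum_score_fisher(d, sigma, mu_star):
    """Per-datum score and Fisher for d_i ~ N(mu, sigma_i^2), evaluated at mu_star."""
    t_i = (d - mu_star) / sigma**2
    F_i = 1.0 / sigma**2
    return t_i, F_i

rng = np.random.default_rng(2)
n, mu_true = 1000, 0.7
sigma = rng.uniform(0.1, 3.0, size=n)   # heterogeneous, known noise levels
d = rng.normal(mu_true, sigma)

mu_star = 0.0                           # arbitrary fiducial point
t_i, F_i = per_datum_score_fisher(d, sigma, mu_star)

# Eqs. (7)-(8): the dataset score and Fisher are sums over the data
t, F = t_i.sum(), F_i.sum()
mu_hat = mu_star + t / F                # one step of Eq. (6)

# Summing over disjoint chunks first gives the same aggregate
chunks = np.array_split(np.arange(n), 7)
t_chunked = sum(t_i[c].sum() for c in chunks)
F_chunked = sum(F_i[c].sum() for c in chunks)
mu_hat_chunked = mu_star + t_chunked / F_chunked

print(mu_hat, np.average(d, weights=1 / sigma**2), d.mean())
```

Here the per-datum Fisher acts as the weight: noisy points contribute little to the estimate, which a flat mean cannot express.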
2.3 Twin Fisher-Score Networks
For such problems, we propose learning the score and Fisher with neural networks. Due to the additive structure of (7) and (8), we can parameterize the per-datapoint score and Fisher with twin neural networks:

$$\mathbf t_i = \mathbf t(\mathbf d_i;\,\mathbf w_t), \tag{9}$$

$$\mathbf F_i = \mathbf F(\mathbf d_i;\,\mathbf w_F), \tag{10}$$

where the score and Fisher networks are parameterized by weights $\mathbf w_t$ and $\mathbf w_F$, respectively. The twin networks output a score and Fisher for each datapoint (see Appendix B for formalism), which are then each summed to obtain a global score and Fisher for the dataset. We can then compute parameter estimates from these aggregated embeddings following (6):

$$\hat{\boldsymbol\theta} = \boldsymbol\theta_* + \Big(\sum_{i=1}^{n}\mathbf F_i\Big)^{-1}\sum_{i=1}^{n}\mathbf t_i, \tag{11}$$

where the fiducial point $\boldsymbol\theta_*$ can be set to an arbitrary constant. Provided the embeddings $\mathbf t_i$ and $\mathbf F_i$ are descriptive enough, the summation formalism can be used to obtain Fisher and score estimates under an implicit likelihood for datasets with heterogeneous structure and arbitrary size. These summaries can be regarded as sufficient statistics, since the score as a function of the parameters could in principle be used to reconstruct the likelihood surface up to a constant (Alsing & Wandelt, 2018; Hoffmann & Onnela, 2022).
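A minimal sketch of the twin-network aggregation in Eqs. (9) to (11), assuming small untrained NumPy MLPs in place of trained networks and a Cholesky parameterization $\mathbf F_i = \mathbf L_i\mathbf L_i^{\top}$ with a softplus-positive diagonal to keep each per-datum Fisher positive definite. The exact parameterization in Appendix B may differ, and all names and sizes here are hypothetical. Because the aggregation is a sum, the output is invariant to permutations of the set, and the same weights apply to sets of any size.

```python
import numpy as np

rng = np.random.default_rng(3)
N_P, D_IN, HIDDEN = 2, 3, 16          # hypothetical: parameters, input dim, hidden width
N_TRIL = N_P * (N_P + 1) // 2         # free entries of a lower-triangular Cholesky factor

def init_mlp(d_in, d_out):
    """Random weights for a one-hidden-layer MLP (untrained, illustration only)."""
    return (rng.normal(size=(d_in, HIDDEN)) / np.sqrt(d_in), np.zeros(HIDDEN),
            rng.normal(size=(HIDDEN, d_out)) / np.sqrt(HIDDEN), np.zeros(d_out))

def mlp(params, x):
    W1, b1, W2, b2 = params
    return np.tanh(x @ W1 + b1) @ W2 + b2

def fisher_from_cholesky(raw):
    """Map raw (n, N_TRIL) outputs to positive-definite F_i = L_i L_i^T."""
    L = np.zeros((raw.shape[0], N_P, N_P))
    rows, cols = np.tril_indices(N_P)
    L[:, rows, cols] = raw
    diag = np.arange(N_P)
    L[:, diag, diag] = np.logaddexp(0.0, L[:, diag, diag]) + 1e-6  # softplus > 0
    return L @ np.transpose(L, (0, 2, 1))

score_net = init_mlp(D_IN, N_P)       # plays the role of t(d_i; w_t), Eq. (9)
fisher_net = init_mlp(D_IN, N_TRIL)   # plays the role of F(d_i; w_F), Eq. (10)

def fishnets_aggregate(data, theta_star=np.zeros(N_P)):
    """Eq. (11): sum per-datum scores and Fishers, then take one scoring step."""
    t = mlp(score_net, data).sum(axis=0)
    F = fisher_from_cholesky(mlp(fisher_net, data)).sum(axis=0)
    return theta_star + np.linalg.solve(F, t), F

data = rng.normal(size=(500, D_IN))                     # a set of 500 data vectors
theta_hat, F = fishnets_aggregate(data)
theta_perm, _ = fishnets_aggregate(data[rng.permutation(500)])
theta_big, F_big = fishnets_aggregate(rng.normal(size=(10_000, D_IN)))  # larger set, same weights
```

Since a sum of positive-definite matrices is positive definite, the solve in the final step stays well posed however many data are aggregated.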
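Loss Function. In a regression scenario, we draw data-parameter pairs $(\{\mathbf d_i\}, \boldsymbol\theta)$ from the joint distribution and compute $\hat{\boldsymbol\theta}$ from (11). The twin networks can then be trained jointly using a negative-log Gaussian loss:

$$\ell(\mathbf w_t, \mathbf w_F) = \frac{1}{2}\,(\boldsymbol\theta - \hat{\boldsymbol\theta})^{\top}\,\mathbf F\,(\boldsymbol\theta - \hat{\boldsymbol\theta}) - \frac{1}{2}\ln\det\mathbf F, \tag{12}$$

averaged over training pairs, where $\mathbf F = \sum_i\mathbf F_i$ is the aggregated Fisher. Minimizing this loss with respect to the network weights saturates information by maximising the aggregated Fisher (the $-\ln\det\mathbf F$ term), while the quadratic term forces the distance between the embedded MLE and the true parameters to be minimized relative to the Cramér-Rao bound (3) as a function of the data. This loss can also be interpreted as a maximum likelihood (MLE) loss for the quantities of interest $\boldsymbol\theta$, as opposed to typical mean-square error (MSE) regression losses (see Appendix E for Deepsets details).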
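The loss in Eq. (12) is straightforward to write for a batch. The sketch below (function name hypothetical) drops the constant $\tfrac{k}{2}\ln 2\pi$, so it equals the negative log-density of a Gaussian with mean $\hat{\boldsymbol\theta}$ and covariance $\mathbf F^{-1}$, evaluated at the true parameters, up to that constant. The quadratic term penalizes residuals in units of the predicted uncertainty, so an overconfident Fisher is punished, while the log-determinant term rewards a larger aggregated Fisher.

```python
import numpy as np

def fishnets_loss(theta, theta_hat, F):
    """Eq. (12) averaged over a batch, without the constant (k/2) ln(2 pi).

    theta, theta_hat: (batch, k) true and estimated parameters
    F: (batch, k, k) aggregated Fisher matrices (positive definite)
    """
    r = theta - theta_hat
    quad = np.einsum("bi,bij,bj->b", r, F, r)   # (theta - theta_hat)^T F (theta - theta_hat)
    _, logdet = np.linalg.slogdet(F)            # ln det F for each batch element
    return np.mean(0.5 * quad - 0.5 * logdet)

rng = np.random.default_rng(4)
A = rng.normal(size=(8, 2, 2))
F = A @ np.transpose(A, (0, 2, 1)) + np.eye(2)   # random positive-definite Fishers
theta = rng.normal(size=(8, 2))
theta_hat = theta + 0.1 * rng.normal(size=(8, 2))
print(fishnets_loss(theta, theta_hat, F))
```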
3 Related Work
Deepsets Mean Aggregation. A comparable method for learning over sets of data is regression using the Deepsets (DS) formalism (Zaheer et al., 2018). Here an embedding $\phi(\mathbf d_i)$ is learned for each datum, aggregated with a fixed permutation-invariant scheme $\bigoplus$, and fed to a global function $\rho$: $\hat{\boldsymbol\theta} = \rho\big(\bigoplus_i\phi(\mathbf d_i)\big)$. The networks are optimised by minimising a squared loss against the true parameters, $\|\hat{\boldsymbol\theta} - \boldsymbol\theta\|_2^2$. When the aggregation is chosen to be the mean, the Deepsets formalism is scalable to arbitrary data and becomes equivalent to the Fishnets aggregation formalism with flat weights across the aggregated data (see Appendix E for an in-depth treatment).
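To illustrate the equivalence claimed above, the sketch below (toy data, hypothetical names) feeds the same per-datum embeddings to Eq. (11) with a constant per-datum Fisher $c\,\mathbb I$. The result is the mean embedding divided by $c$, a rescaling that the global function could absorb.

```python
import numpy as np

def fishnets_estimate(t_i, F_i, theta_star):
    """Eq. (11): theta_star + (sum_i F_i)^{-1} sum_i t_i."""
    return theta_star + np.linalg.solve(F_i.sum(axis=0), t_i.sum(axis=0))

rng = np.random.default_rng(5)
n, k, c = 300, 3, 2.5
phi = rng.normal(size=(n, k))                         # per-datum embeddings, used as scores
flat_F = np.broadcast_to(c * np.eye(k), (n, k, k))    # the same weight for every datum

fish = fishnets_estimate(phi, flat_F, theta_star=np.zeros(k))
mean_pooled = phi.mean(axis=0) / c                    # Deepsets mean pooling, rescaled
print(np.allclose(fish, mean_pooled))
```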
Learned Softmax Aggregation. Li et al. (2020) present a learnable softmax counterpart to the DS aggregation scheme in the context of edge aggregation in GNNs. Their aggregation scheme reads:

$$\mathrm{SoftmaxAgg}_{\beta}\big(\{\phi(\mathbf d_i)\}_{i=1}^{n}\big) = \sum_{i=1}^{n}\frac{\exp\big(\beta\,\phi(\mathbf d_i)\big)}{\sum_{j=1}^{n}\exp\big(\beta\,\phi(\mathbf d_j)\big)}\,\phi(\mathbf d_i), \tag{13}$$

where $\beta$ is a learned scalar temperature parameter, $\phi$ is some embedding layer, and the softmax is applied elementwise over feature channels. They show that adopting this aggregation scheme allows more graph convolution (GCN) layers to be stacked efficiently to deepen GNN models. Many other aggregation frameworks have been studied, including Graph Attention (Veličković et al., 2018), LSTM units (Hamilton et al., 2017), recurrent aggregations (Soelch et al., 2019), and scaled multiple aggregators (Corso et al., 2020a).
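A NumPy sketch of the elementwise softmax aggregation in Eq. (13), written for illustration rather than taken from Li et al.'s code. Setting $\beta = 0$ recovers the mean, and large $\beta$ approaches the elementwise max. Because the weights are normalized within the set, the aggregate carries no notion of total information, unlike the summed Fisher in Fishnets.

```python
import numpy as np

def softmax_aggregate(m, beta):
    """Elementwise softmax aggregation of a set of embeddings m (n, features), Eq. (13)."""
    z = beta * m
    z = z - z.max(axis=0, keepdims=True)      # subtract the max for numerical stability
    w = np.exp(z)
    w = w / w.sum(axis=0, keepdims=True)      # weights sum to one per feature channel
    return (w * m).sum(axis=0)

rng = np.random.default_rng(6)
m = rng.normal(size=(50, 4))
print(softmax_aggregate(m, 0.0), m.mean(axis=0))   # beta = 0: uniform weights, i.e. the mean
print(softmax_aggregate(m, 50.0), m.max(axis=0))   # large beta: approaches the max
```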
4 Experiments: Bayesian Information Saturation
Bayesian Simulation Based Inference (SBI) provides a framework in which to perform inference with intractable likelihoods. There have been major developments in SBI, such as neural ratio estimation (Miller et al., 2021) and density estimation (Alsing et al., 2019; Papamakarios et al., 2019). Key to all of these methods is compressing a large number of data down to small summaries: typically one informative summary per parameter of interest, to preserve information (Alsing & Wandelt, 2018; Charnock et al., 2018; Makinen et al., 2021). ML methods like regression (Jeffrey & Wandelt, 2020) and information-maximising neural networks (Charnock et al., 2018; Makinen et al., 2022, 2021) are very good at learning embeddings for highly structured data like images, and can do so losslessly (Makinen et al., 2021). For unstructured datasets composed of many independent data, constructing optimal summaries amounts to an aggregation task (Zaheer et al., 2018; Hoffmann & Onnela, 2022; Wagstaff et al., 2019). The Fishnets formalism is an optimal version of this aggregation. What deepsets and “learned” aggregation functions lack is the explicit construction of inverse-Fisher weights per datapoint, and of the total Fisher information, which is required to turn summaries into unbiased estimators (Alsing & Wandelt, 2018). Explicitly learning the weights in addition to the score allows us to achieve 1) asymptotic optimality, 2) scalability, and 3) robustness to changes in information content among the data.
In this section we demonstrate the 1) information saturation, 2) robustness and 3) scalability of the Fishnets aggregation through two examples in the context of SBI, and highlight the shortcomings of existing aggregators. We first investigate a linear regression scaling problem and introduce a robustness test in which Fishnets outperforms deepset and learned softmax aggregation on test data. We then extend Fishnets to an inference problem with nuisance (latent) parameters and censorship to demonstrate the applicability of network scaling to a regime where MCMC becomes intractable.
4.1 Validation Case: Linear Regression
We use a toy linear regression model to validate our method and demonstrate network scalability. We consider the form $y_i = \theta_0 + \theta_1 x_i + \epsilon_i$, where $\epsilon_i \sim \mathcal N(0, \sigma_i^2)$, and the parameters of interest are the intercept $\theta_0$ and slope $\theta_1$. This likelihood has an analytically calculable score and Fisher matrix (see Appendix C.1), which can be used to calculate exact MLE estimates for the parameters via (6). We choose wide Gaussian priors for $\boldsymbol\theta$, and uniform distributions for $x_i$ and $\sigma_i$. For network training, we simulate a suite of fixed-size datasets; for testing, we generate additional datasets twenty times larger to demonstrate scalability. See Appendix C.2 for neural architecture details.
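A sketch of the analytic per-datum score and Fisher for this model, assuming $y_i = \theta_0 + \theta_1 x_i + \epsilon_i$ with $\epsilon_i \sim \mathcal N(0, \sigma_i^2)$ and known $\sigma_i$: with design row $\mathbf x_i = [1, x_i]^\top$, $\mathbf t_i = (y_i - \mathbf x_i^\top\boldsymbol\theta_*)\,\mathbf x_i/\sigma_i^2$ and $\mathbf F_i = \mathbf x_i\mathbf x_i^\top/\sigma_i^2$. One step of Eq. (6) from an arbitrary fiducial point reproduces weighted least squares exactly. The input ranges below are hypothetical stand-ins, not the paper's training distributions.

```python
import numpy as np

def linreg_score_fisher(x, y, sigma, theta_star):
    """Per-datum score and Fisher for y_i = theta0 + theta1 * x_i + N(0, sigma_i^2)."""
    X = np.stack([np.ones_like(x), x], axis=-1)                    # rows [1, x_i]
    resid = y - X @ theta_star
    t_i = (resid / sigma**2)[:, None] * X                           # (n, 2)
    F_i = X[:, :, None] * X[:, None, :] / sigma[:, None, None]**2   # (n, 2, 2)
    return t_i, F_i

rng = np.random.default_rng(7)
n = 500
theta_true = np.array([1.0, -0.5])
x = rng.uniform(0.0, 10.0, size=n)       # hypothetical covariate range
sigma = rng.uniform(1.0, 10.0, size=n)   # hypothetical noise-level range
y = theta_true[0] + theta_true[1] * x + rng.normal(0.0, sigma)

theta_star = np.array([3.0, 2.0])        # arbitrary fiducial point
t_i, F_i = linreg_score_fisher(x, y, sigma, theta_star)
F = F_i.sum(axis=0)
theta_hat = theta_star + np.linalg.solve(F, t_i.sum(axis=0))   # Eq. (6), one step
theta_err = np.sqrt(np.diag(np.linalg.inv(F)))                  # Cramér-Rao 1-sigma errors

# Reference: weighted least squares on the whitened design
X = np.stack([np.ones_like(x), x], axis=-1)
theta_wls, *_ = np.linalg.lstsq(X / sigma[:, None], y / sigma, rcond=None)
```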
Results. We display a comparison of test-set performance to the true MLE solution in Figure LABEL:fig:info_saturation, along with slices of the true and predicted score vectors as a function of the input data. The networks recover the exact score and Fisher information matrices (see Figure LABEL:fig:net_outputs), even when scaled up 20-fold. This test demonstrates that Fishnets can (1) saturate information on small training sets and (2) scale to far larger aggregations of data at prediction time.
4.2 Robustness to changes in the underlying data distributions
In real-world applications, actual data processed by a network might follow a different distribution than that seen in training. Here we compare three different network formalisms on changing shapes of target data distributions.
We train three networks on the same datasets as before: a sum-aggregated Fishnets network, a mean-aggregated deepset, and a learned softmax-aggregated deepset with mean-square error loss. To demonstrate the improvement in aggregation Fishnets offers, we adopt smaller networks for the regression task (see Table 2 and Appendix C.2 for architecture details).
We apply our trained networks to test data whose noise variances $\sigma_i^2$ and covariates $x_i$ are drawn from distributions that differ from those used in training (for example, re-centred or truncated distributions). The noise and covariate distributions have the same support as the training data but different shapes and expectation values, which can pose a problem for the mean aggregation used in the standard deepsets formalism. We display results in Figure LABEL:fig:stresstest. The heterogeneous Fishnets aggregation allows the network to correctly embed the noisy data drawn from the different distributions, while a significant loss of information is seen for flat mean aggregation. The learned softmax aggregation narrows the residual distribution, but it remains significantly wider than the Fishnets solution. We quote numeric results in Table 2.
These robustness tests show that Fishnets successfully learns per-object embeddings (score) and weights (Fisher) within sets, while being robust (3) to changes in the shapes of the distributions of these quantities relative to training. They also show that even in a very simple prediction scenario, common and learned aggregators can suffer from robustness issues.
test | network | # params | |
---|---|---|---|---
robustness test | fishnets | | |
 | deepset | | |
 | softmax | | |
4.3 Scalable Inference With Censorship and Nuisance Parameters
As a non-trivial information saturation example, we consider a censorship inference problem with latent parameters, inspired by epidemiological studies. Consider a serum which, when injected into a patient, decays over time, with a (heterogeneous) decay rate among people that is not well known. A population of patients is injected with the serum and then asked to come back to the lab within a fixed number of days for a measurement of the remaining serum level in their blood. We can cast this problem as a Bayesian hierarchical model, visualised in Figure LABEL:fig:plate_diagram (see Appendix C.3 for details), where the goal is to infer the mean and scale of the Gamma distribution of decay rates from the measured serum levels. In the censored case, measurements below a threshold are rejected, and measurements are collected until the target number of valid samples is reached. As a ground-truth comparison for the uncensored version of this problem, we sample the above hierarchical model using Hamiltonian Monte Carlo (HMC). We use the same Fishnets architecture and small-data training setup as before to predict the mean and scale from the measured serum levels. Once trained, we generate a new suite of simulations and pass the data through the network to learn a neural posterior from (summary, parameter) pairs. We then evaluate both the HMC and neural posteriors at the same target data. Finally, using the same network, we repeat the procedure with much larger simulated datasets, for which HMC becomes computationally prohibitive.
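The censorship mechanism (reject measurements below a threshold, keep sampling until enough valid ones are collected) can be sketched as a rejection-sampling simulator. Everything beyond that mechanism is an assumption for illustration only: we assume exponential decay of the serum level, Gamma-distributed decay rates parameterized by mean and scale, and uniformly distributed visit times. This is not the model of Appendix C.3, and the function name is hypothetical. The sketch also shows why censorship complicates likelihood-based inference: how many draws are rejected depends on the hyperparameters.

```python
import numpy as np

def simulate_serum(mean, scale, n, t_max=10.0, s_min=None, rng=None):
    """Hypothetical simulator for the serum-decay example (illustration only).

    ASSUMED model (not the one in Appendix C.3):
      decay rate   gamma_i ~ Gamma(shape = mean / scale, scale = scale)
      visit time   tau_i   ~ Uniform(0, t_max) days
      serum level  s_i     = exp(-gamma_i * tau_i)
    If s_min is given, levels below s_min are rejected (censored) and new
    patients are drawn until n valid measurements have been collected.
    """
    rng = np.random.default_rng() if rng is None else rng
    kept, n_kept = [], 0
    while n_kept < n:
        gamma = rng.gamma(shape=mean / scale, scale=scale, size=n)
        tau = rng.uniform(0.0, t_max, size=n)
        s = np.exp(-gamma * tau)
        if s_min is not None:
            s = s[s >= s_min]            # reject censored measurements
        kept.append(s)
        n_kept += s.size
    return np.concatenate(kept)[:n]

rng = np.random.default_rng(8)
s_uncensored = simulate_serum(0.3, 0.1, n=200, rng=rng)
s_censored = simulate_serum(0.3, 0.1, n=200, s_min=0.2, rng=rng)
```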
Results. We display inference results in Figure LABEL:fig:uncensored-corner. The summaries obtained from Fishnets compression of the small data (green) result in posteriors that hug the “true” MCMC contours (black), indicating information saturation. Applying the same network to the larger datasets yields tighter contours (blue), as expected. We emphasize that this larger data size is a regime where MCMC inference is no longer tractable on standard devices. Here Fishnets allows for 1) much faster posterior calculation and 2) optimal inference on larger datasets without any retraining.
As a final demonstration we solve the same problem, this time subject to censorship. In the censored case, the target joint posterior defined by the hierarchical model requires computing an integral for the selection probability as a function of the model hyper-parameters; in general, these selection terms make Bayesian problems with censorship computationally challenging, or intractable in some cases (Qi et al., 2022; Dickey et al., 1987).
We train Fishnets on the small data size, subject to censorship below the threshold. We obtain posteriors of similar shape for the censored case and, as a consistency check, perform a probability-integral transform (PIT) test for the neural posterior. For each parameter, the marginal PIT values should follow a uniform distribution if the learned posterior is well calibrated. We display these results in Figure 6. A Kolmogorov-Smirnov test (Massey Jr., 1951) yields p-values of 0.628 and 0.233 for the mean and scale parameters, respectively, so uniformity cannot be rejected, consistent with a well-parameterized and robust posterior.
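The PIT check can be sketched as follows, using a conjugate Gaussian toy model in place of the serum problem so that the exact posterior is known (an assumption made purely for illustration). For each test simulation, the PIT value is the fraction of posterior samples below the true parameter. A calibrated posterior gives uniform PIT values, while an overconfident one piles them up near 0 and 1, which `scipy.stats.kstest` against the standard uniform detects. The helper name `pit_values` is hypothetical.

```python
import numpy as np
from scipy import stats

def pit_values(posterior_samples, theta_true):
    """PIT per test simulation: fraction of posterior samples below the true value.

    posterior_samples: (n_sims, n_samples); theta_true: (n_sims,)
    """
    return np.mean(posterior_samples < theta_true[:, None], axis=1)

rng = np.random.default_rng(9)
n_sims, n_obs, n_samples = 500, 10, 2000

# Conjugate toy model (assumption): theta ~ N(0, 1), d_j ~ N(theta, 1)
theta_true = rng.normal(size=n_sims)
d = rng.normal(theta_true[:, None], 1.0, size=(n_sims, n_obs))
post_var = 1.0 / (1.0 + n_obs)
post_mean = post_var * d.sum(axis=1)

# Calibrated posterior samples vs. an overconfident (too narrow) posterior
good = rng.normal(post_mean[:, None], np.sqrt(post_var), size=(n_sims, n_samples))
bad = rng.normal(post_mean[:, None], 0.5 * np.sqrt(post_var), size=(n_sims, n_samples))

p_good = stats.kstest(pit_values(good, theta_true), "uniform").pvalue
p_bad = stats.kstest(pit_values(bad, theta_true), "uniform").pvalue
print(p_good, p_bad)
```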
5 Graph Neural Network Aggregation
Graphs can be thought of as tuples of sets within connected neighborhoods. Graph neural networks (GNNs) operate by message-passing along edges between nodes. For predicting node- and graph-level properties, an aggregation of these sets of edges or nodes is required to reduce features to fixed-size feature vectors. Whereas in the SBI setting, we are interested in finding optimal estimators for specific parameters of interest, in the GNN aggregation setting we are implicitly trying to find a compact latent (embedding) representation of the aggregated neighborhood data, and optimally estimate and propagate those latent features through the GNN architecture.
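Here we compare the Fishnets aggregation scheme as a drop-in replacement for learned softmax aggregators within the graph convolutional network (GCN) scheme presented by Li et al. (2020). We can rewrite our aggregations to occur within neighborhoods of nodes:

$$\mathbf t_v = \sum_{u\in\mathcal N(v)}\mathbf t(\mathbf m_{vu}), \qquad \mathbf F_v = \sum_{u\in\mathcal N(v)}\mathbf F(\mathbf m_{vu}), \tag{14}$$

$$\hat{\mathbf h}_v = \mathbf F_v^{-1}\,\mathbf t_v, \tag{15}$$

where the aggregation occurs over the neighborhood $\mathcal N(v)$ of a node $v$, $\mathbf m_{vu}$ denotes the message passed from neighbor $u$ to $v$, and the fiducial point is taken to be zero. The Fishnets aggregation requires a bottleneck hyperparameter, $n_p$, which controls the size of the score embedding $\mathbf t_v\in\mathbb R^{n_p}$ and of the Fisher Cholesky factors $\mathbf L$, with $\mathbf F = \mathbf L\mathbf L^{\top}$. We use a single linear layer before aggregation to obtain score and Fisher components from hidden layer embeddings.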
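A NumPy sketch of the neighbourhood aggregation in Eqs. (14) and (15), assuming single linear layers for the score and Cholesky components, a zero fiducial point, messages stored per edge with a receiving-node index array, and a small ridge term $\epsilon\mathbb I$ (our assumption) so that nodes without incoming edges still have an invertible Fisher. This is not the DeeperGCN or PyTorch Geometric implementation; in a real GNN the scatter sums would use the framework's own scatter or segment-sum operations, and the function name is hypothetical.

```python
import numpy as np

def fishnets_neighbourhood_aggregate(m, dst, num_nodes, W_t, W_L, eps=1e-3):
    """Sketch of Eqs. (14)-(15) with single linear layers and a zero fiducial point.

    m:   (E, H) message on each edge u -> v
    dst: (E,)   index of the receiving node v for each edge
    W_t: (H, n_p) score layer; W_L: (H, n_p*(n_p+1)//2) Cholesky layer
    eps: assumed ridge so nodes without incoming edges keep an invertible Fisher
    """
    n_p = W_t.shape[1]
    t_e = m @ W_t                                           # per-edge score (E, n_p)
    L = np.zeros((m.shape[0], n_p, n_p))
    rows, cols = np.tril_indices(n_p)
    L[:, rows, cols] = m @ W_L
    diag = np.arange(n_p)
    L[:, diag, diag] = np.logaddexp(0.0, L[:, diag, diag])  # softplus: positive diagonal
    F_e = L @ np.transpose(L, (0, 2, 1))                    # per-edge Fisher (E, n_p, n_p)

    t_v = np.zeros((num_nodes, n_p))
    F_v = np.tile(eps * np.eye(n_p), (num_nodes, 1, 1))
    np.add.at(t_v, dst, t_e)                                # sum over each neighbourhood
    np.add.at(F_v, dst, F_e)
    return np.linalg.solve(F_v, t_v[..., None])[..., 0]     # F_v^{-1} t_v for every node

rng = np.random.default_rng(10)
num_nodes, E, H, n_p = 6, 20, 8, 3
dst = rng.integers(0, num_nodes - 1, size=E)   # the last node receives no messages
m = rng.normal(size=(E, H))
W_t = rng.normal(size=(H, n_p)) / np.sqrt(H)
W_L = rng.normal(size=(H, n_p * (n_p + 1) // 2)) / np.sqrt(H)
h = fishnets_neighbourhood_aggregate(m, dst, num_nodes, W_t, W_L)
```

Stacking the Cholesky entries as n_p(n_p+1)/2 outputs per edge is what makes the bottleneck size the main cost driver: the Fisher head grows quadratically in n_p while the score head grows linearly.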
5.1 Drop-in replacement for Graph Benchmark Datasets
Here we replace the learned softmax aggregation with Fishnets aggregation in Li et al.’s publicly available best-performing models. We vary four hyperparameters in testing our new architectures: number of layers, bottleneck size, dropout, and learning rate. We study several graph datasets from the Open Graph Benchmark (OGB) (Hu et al., 2020, 2021), which require substantial aggregation steps to predict either node- or graph-level properties. The aim of this study is to investigate how well Fishnets aggregation performs within an existing architecture, with fewer layers and minimal hyperparameter optimisation.
Results. We display benchmark results in Table 1, and refer the reader to Appendix D for architecture and dataset details. This small drop-in study shows that by incorporating the more information-efficient Fishnets aggregation, we can achieve results better than or similar to SOTA GCNs with a fraction of the trainable parameters and training epochs.
5.2 Focus Study on ogbn-proteins Benchmark
In this section we study the proteins dataset in detail to highlight a scenario where the heterogeneous Fishnets aggregation drastically improves performance. Here we expect different node neighborhoods to have a heterogeneous “association score” edge-weighting structure across protein categories, making the Fishnets aggregation well suited to generalising beyond the training set, as in the linear regression case. The association scores can be modelled stochastically with added measurement noise, increasing the difficulty of the classification problem. We adopt a stripped-down version of the training routine presented in Li et al. (2020) (no subgraph or edge preprocessing) so that we can modify the raw data by adding noise. We first benchmark our training routine with a smaller GCN and with Fishnets aggregation on the noise-free data, and then add noise to the edges.
Noisefree Results. We display representative test ROC-AUC curves over training in Figure LABEL:fig:gnn_benchmark, and summarise results in Table 3. Fishnets-16 and Fishnets-20 clearly saturate information within 250 epochs.
5.2.1 Modelling Uncertain Protein Associations.
Here we incorporate uncertainties on the protein interaction strengths (edges) to demonstrate the robustness of the Fishnets approach to changes in the underlying noise distribution of the graph features. We model noisy “measurements” of the protein graph edge associations using a simple Binomial model: taking the dataset edges as the “true” association strengths $p_e$ (for each association channel), we simulate a noisy measurement of those quantities as $n_e$ weighted coin tosses per edge, where the number of tosses $n_e$ varies between measurements:

$$n_e \sim p(n), \tag{16}$$

$$k_e \sim \mathrm{Binomial}(n_e,\, p_e), \tag{17}$$

$$\mathbf e_e \leftarrow \big(\hat p_e,\; n_e\big), \qquad \hat p_e = k_e/n_e, \tag{18}$$

where $\mathbf e_e$ is the feature vector of edge $e$. Note that in the last step the new graph edge contains the (noisy) measured associations $\hat p_e$ as well as $n_e$, which provides a measure of uncertainty on those estimated interaction strengths. The GNN task is now to learn to re-weight predictions conditioned on the provided coin-toss information, much like feeding in $\sigma_i$ in the linear regression case. We train a 28-layer GCN and a 20-layer Fishnets network. For the test dataset, we alter the distribution of $n_e$ so that we sample the extremes of the training distribution support.
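The Binomial noise model can be simulated in a few lines. This sketch assumes one toss count per edge shared across all association channels, drawn uniformly from a hypothetical integer range; both are our illustrative choices rather than the settings used for ogbn-proteins, and the function name is hypothetical. The measured frequencies are unbiased for the true associations, with variance $p_e(1-p_e)/n_e$ that shrinks as the toss count grows, which is why the count itself is useful input for the aggregator.

```python
import numpy as np

def noisy_edge_measurements(p_true, n_low, n_high, rng):
    """Sketch of Eqs. (16)-(18): Binomial 'measurements' of edge association strengths.

    p_true: (E, c) true association scores in [0, 1] (the dataset's edge features).
    ASSUMPTIONS: one toss count per edge shared by all c channels, drawn uniformly
    from the integers n_low..n_high (hypothetical range).
    Returns (E, 2c) features: measured frequencies followed by the toss counts.
    """
    n = rng.integers(n_low, n_high + 1, size=(p_true.shape[0], 1))  # tosses per edge
    successes = rng.binomial(n, p_true)                              # (E, c)
    p_hat = successes / n                                            # noisy associations
    n_feat = np.broadcast_to(n, p_hat.shape).astype(float)           # uncertainty information
    return np.concatenate([p_hat, n_feat], axis=1)

rng = np.random.default_rng(11)
p_true = rng.uniform(size=(1000, 8))
edges = noisy_edge_measurements(p_true, n_low=5, n_high=50, rng=rng)
```

A test set probing the extremes of the training support could, for example, draw toss counts only near `n_low` or `n_high`.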
test | network | # params | test ROC-AUC |
---|---|---|---|
noisefree | fishnets-20 | | |
 | GCN-28 | | |
noisy edges | fishnets-20 | | |
 | GCN-28 | | |
Noisy Results. We display test ROC-AUC curves for both networks in Figure LABEL:fig:gnn_noise, subject to a patience setting of 250 epochs on the validation set. The GCN framework plateaus early, while Fishnets continues to improve and saturates at a higher test ROC-AUC. This stark difference in behaviour can be explained by the difference in formalism: the Fishnets aggregation explicitly learns a weighting scheme as a function of the measured edge probabilities and the conditional coin-toss information, much like the linear regression case where the noise level was passed as an input. This scheme helps the network handle noise artefacts at the extremes of the training distribution, as in the noisy-edge test set. Explicitly specifying the inverse-Fisher weighting formalism as an inductive bias (Battaglia et al., 2018) during aggregation can help explain the fast information saturation exhibited in both graph test settings.
6 Discussion & Future Work
In this paper we built up an information-theoretic approach to optimal aggregation in the form of Fishnets. Through progressively non-trivial examples, we demonstrated that explicitly parameterizing the score and inverse-Fisher weights of set members results in an aggregation scheme that saturates Bayesian information in non-trivial problems, and also serves as an optimal aggregator for sets and graph neural networks in heterogeneous data scenarios.
The stark improvement in information saturation on the proteins test dataset relative to architecture size and training efficiency indicates that the Fishnets aggregation acts as an information-level inductive bias for GNN aggregation. Follow-up study is warranted on optimizing hyperparameter choices for graph neural network architectures using Fishnets. We chose to demonstrate improved information capture through an ablation study of smaller models, but careful (and potentially bigger) network design would likely improve results here and could achieve SOTA accuracy on common benchmarks.
References
- Alon & Yahav (2021) Alon, U. and Yahav, E. On the bottleneck of graph neural networks and its practical implications, 2021.
- Alsing & Wandelt (2018) Alsing, J. and Wandelt, B. Generalized massive optimal data compression. Monthly Notices of the Royal Astronomical Society: Letters, 476(1):L60–L64, feb 2018. doi: 10.1093/mnrasl/sly029. URL https://doi.org/10.1093%2Fmnrasl%2Fsly029.
- Alsing et al. (2019) Alsing, J., Charnock, T., Feeney, S., and Wandelt, B. Fast likelihood-free cosmology with neural density estimators and active learning. Monthly Notices of the Royal Astronomical Society, Jul 2019. ISSN 1365-2966. doi: 10.1093/mnras/stz1960. URL http://dx.doi.org/10.1093/mnras/stz1960.
- Amari (2021) Amari, S.-i. Information geometry. Japanese Journal of Mathematics, 16(1):1–48, January 2021. ISSN 1861-3624. doi: 10.1007/s11537-020-1920-5. URL https://doi.org/10.1007/s11537-020-1920-5.
- Battaglia et al. (2018) Battaglia, P. W., Hamrick, J. B., Bapst, V., Sanchez-Gonzalez, A., Zambaldi, V., Malinowski, M., Tacchetti, A., Raposo, D., Santoro, A., Faulkner, R., Gulcehre, C., Song, F., Ballard, A., Gilmer, J., Dahl, G., Vaswani, A., Allen, K., Nash, C., Langston, V., Dyer, C., Heess, N., Wierstra, D., Kohli, P., Botvinick, M., Vinyals, O., Li, Y., and Pascanu, R. Relational inductive biases, deep learning, and graph networks, 2018.
- Charnock et al. (2018) Charnock, T., Lavaux, G., and Wandelt, B. D. Automatic physical inference with information maximizing neural networks. Physical Review D, 97(8), apr 2018. doi: 10.1103/physrevd.97.083004. URL https://doi.org/10.1103%2Fphysrevd.97.083004.
- Clevert et al. (2015) Clevert, D.-A., Unterthiner, T., and Hochreiter, S. Fast and accurate deep network learning by exponential linear units (elus), 2015. URL https://arxiv.org/abs/1511.07289.
- Corso et al. (2020a) Corso, G., Cavalleri, L., Beaini, D., Liò, P., and Velickovic, P. Principal neighbourhood aggregation for graph nets. CoRR, abs/2004.05718, 2020a. URL https://arxiv.org/abs/2004.05718.
- Corso et al. (2020b) Corso, G., Cavalleri, L., Beaini, D., Liò, P., and Veličković, P. Principal neighbourhood aggregation for graph nets, 2020b.
- Coulton & Wandelt (2023) Coulton, W. R. and Wandelt, B. D. How to estimate fisher information matrices from simulations, 2023.
- Cramér (1946) Cramér, H. Mathematical Methods of Statistics. The University Press, 1946.
- Cranmer et al. (2020) Cranmer, K., Brehmer, J., and Louppe, G. The frontier of simulation-based inference. Proceedings of the National Academy of Sciences, 117(48):30055–30062, may 2020. doi: 10.1073/pnas.1912789117. URL https://doi.org/10.1073%2Fpnas.1912789117.
- Dickey et al. (1987) Dickey, J. M., Jiang, J.-M., and Kadane, J. B. Bayesian methods for censored categorical data. Journal of the American Statistical Association, 82(399):773–781, 1987. ISSN 01621459. URL http://www.jstor.org/stable/2288786.
- Fey & Lenssen (2019) Fey, M. and Lenssen, J. E. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.
- Giovanni et al. (2024) Giovanni, F. D., Rusch, T. K., Bronstein, M. M., Deac, A., Lackenby, M., Mishra, S., and Veličković, P. How does over-squashing affect the power of gnns?, 2024.
- Hamilton et al. (2017) Hamilton, W., Ying, Z., and Leskovec, J. Inductive representation learning on large graphs. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf.
- Hoffmann & Onnela (2022) Hoffmann, T. and Onnela, J.-P. Minimizing the expected posterior entropy yields optimal summary statistics, 2022. URL https://arxiv.org/abs/2206.02340.
- Hu et al. (2020) Hu, W., Fey, M., Zitnik, M., Dong, Y., Ren, H., Liu, B., Catasta, M., and Leskovec, J. Open graph benchmark: Datasets for machine learning on graphs. arXiv preprint arXiv:2005.00687, 2020.
- Hu et al. (2021) Hu, W., Fey, M., Ren, H., Nakata, M., Dong, Y., and Leskovec, J. Ogb-lsc: A large-scale challenge for machine learning on graphs. arXiv preprint arXiv:2103.09430, 2021.
- Jeffrey & Wandelt (2020) Jeffrey, N. and Wandelt, B. D. Solving high-dimensional parameter inference: marginal posterior densities & moment networks, 2020. URL https://arxiv.org/abs/2011.05991.
- Kipf & Welling (2017) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks, 2017.
- Lehmann & Casella (1998) Lehmann, E. L. and Casella, G. Theory of Point Estimation. Springer-Verlag, New York, NY, USA, second edition, 1998.
- Li et al. (2020) Li, G., Xiong, C., Thabet, A., and Ghanem, B. Deepergcn: All you need to train deeper gcns, 2020.
- Makinen et al. (2021) Makinen, T. L., Charnock, T., Alsing, J., and Wandelt, B. D. Lossless, scalable implicit likelihood inference for cosmological fields. Journal of Cosmology and Astroparticle Physics, 2021(11):049, nov 2021. doi: 10.1088/1475-7516/2021/11/049. URL https://doi.org/10.1088%2F1475-7516%2F2021%2F11%2F049.
- Makinen et al. (2022) Makinen, T. L., Charnock, T., Lemos, P., Porqueres, N., Heavens, A. F., and Wandelt, B. D. The cosmic graph: Optimal information extraction from large-scale structure using catalogues. The Open Journal of Astrophysics, 5(1), dec 2022. doi: 10.21105/astro.2207.05202. URL https://doi.org/10.21105%2Fastro.2207.05202.
- Massey Jr. (1951) Massey Jr., F. J. The kolmogorov-smirnov test for goodness of fit. Journal of the American Statistical Association, 46(253):68–78, 1951. doi: 10.1080/01621459.1951.10500769. URL https://www.tandfonline.com/doi/abs/10.1080/01621459.1951.10500769.
- Miller et al. (2021) Miller, B. K., Cole, A., Forré, P., Louppe, G., and Weniger, C. Truncated marginal neural ratio estimation. In Ranzato, M., Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, volume 34, pp. 129–143. Curran Associates, Inc., 2021. URL https://proceedings.neurips.cc/paper_files/paper/2021/file/01632f7b7a127233fa1188bd6c2e42e1-Paper.pdf.
- Papamakarios et al. (2019) Papamakarios, G., Sterratt, D., and Murray, I. Sequential neural likelihood: Fast likelihood-free inference with autoregressive flows. In Chaudhuri, K. and Sugiyama, M. (eds.), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics, volume 89 of Proceedings of Machine Learning Research, pp. 837–848. PMLR, 16–18 Apr 2019. URL https://proceedings.mlr.press/v89/papamakarios19a.html.
- Phan et al. (2019) Phan, D., Pradhan, N., and Jankowiak, M. Composable effects for flexible and accelerated probabilistic programming in NumPyro. arXiv preprint arXiv:1912.11554, 2019.
- Qi et al. (2022) Qi, X., Zhou, S., and Plummer, M. On Bayesian modeling of censored data in JAGS. BMC Bioinformatics, 23(1):102, March 2022. ISSN 1471-2105. doi: 10.1186/s12859-021-04496-8. URL https://doi.org/10.1186/s12859-021-04496-8.
- Ramachandran et al. (2017) Ramachandran, P., Zoph, B., and Le, Q. V. Searching for activation functions, 2017.
- Soelch et al. (2019) Soelch, M., Akhundov, A., van der Smagt, P., and Bayer, J. On Deep Set Learning and the Choice of Aggregations, pp. 444–457. Springer International Publishing, 2019. ISBN 9783030304874. doi: 10.1007/978-3-030-30487-4_35. URL http://dx.doi.org/10.1007/978-3-030-30487-4_35.
- Vaart (1998) Vaart, A. W. v. d. Asymptotic Statistics. Cambridge Series in Statistical and Probabilistic Mathematics. Cambridge University Press, 1998. doi: 10.1017/CBO9780511802256.
- Veličković et al. (2018) Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., and Bengio, Y. Graph attention networks, 2018.
- Wagstaff et al. (2019) Wagstaff, E., Fuchs, F. B., Engelcke, M., Posner, I., and Osborne, M. On the limitations of representing functions on sets, 2019.
- Xu et al. (2019) Xu, K., Hu, W., Leskovec, J., and Jegelka, S. How powerful are graph neural networks?, 2019.
- Zaheer et al. (2018) Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R., and Smola, A. Deep sets, 2018.
- Zhou et al. (2020) Zhou, J., Cui, G., Hu, S., Zhang, Z., Yang, C., Liu, Z., Wang, L., Li, C., and Sun, M. Graph neural networks: A review of methods and applications. AI Open, 1:57–81, 2020. ISSN 2666-6510. doi: 10.1016/j.aiopen.2021.01.001. URL https://www.sciencedirect.com/science/article/pii/S2666651021000012.
Appendix A Saturating the Information Inequality over Parameter space
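Here we show that knowing the score function saturates the information inequality and provides a natural data compression function. We first consider the Taylor expansion of the log-likelihood around a fixed fiducial point in parameter space, $\theta_{\rm fid}$ (where $\Delta\theta \equiv \theta - \theta_{\rm fid}$):
$$\ln \mathcal{L}(\mathbf{d}\,|\,\theta) = \ln \mathcal{L}(\mathbf{d}\,|\,\theta_{\rm fid}) + \Delta\theta^{T}\, \nabla_{\theta} \ln \mathcal{L}(\mathbf{d}\,|\,\theta)\big|_{\theta_{\rm fid}} - \tfrac{1}{2}\,\Delta\theta^{T}\, \mathbf{J}\, \Delta\theta + \dots \tag{19}$$
where $\mathbf{J} \equiv -\nabla_{\theta}\nabla_{\theta}^{T} \ln \mathcal{L}(\mathbf{d}\,|\,\theta)\big|_{\theta_{\rm fid}}$ is the observed information matrix. To linear order in $\Delta\theta$, the data $\mathbf{d}$ couple to the parameters only through the score function $\mathbf{t} \equiv \nabla_{\theta} \ln \mathcal{L}(\mathbf{d}\,|\,\theta)\big|_{\theta_{\rm fid}}$. We can show that $\mathbf{t}$ saturates the information inequality via
$$\mathrm{Cov}(\mathbf{t}) = \langle \mathbf{t}\mathbf{t}^{T} \rangle - \langle \mathbf{t} \rangle \langle \mathbf{t} \rangle^{T} = \big\langle \nabla_{\theta} \ln \mathcal{L}\; \nabla_{\theta}^{T} \ln \mathcal{L} \big\rangle\big|_{\theta_{\rm fid}} = \mathbf{F}, \tag{20}$$
where we have used the fact that $\langle \mathbf{t} \rangle = 0$ at $\theta_{\rm fid}$. From this we observe that the covariance of the score function is the Fisher matrix. Using the fact that
$$\frac{\partial \langle \mathbf{t} \rangle}{\partial \theta}\bigg|_{\theta_{\rm fid}} = \mathbf{F}, \tag{21}$$
the right-hand side of the information inequality, $\mathrm{Cov}(\mathbf{t}) \geq (\partial \langle \mathbf{t} \rangle / \partial \theta)^{T}\, \mathbf{F}^{-1}\, (\partial \langle \mathbf{t} \rangle / \partial \theta)$, becomes $\mathbf{F}\mathbf{F}^{-1}\mathbf{F} = \mathbf{F} = \mathrm{Cov}(\mathbf{t})$, which shows that the score statistics $\mathbf{t}$ saturate the information inequality. Within this formalism, no statistic can provide more (Fisher) information about the parameters $\theta$.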
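As a quick numerical check of Eqs. (20) and (21), the sketch below uses a toy model that is our own assumption rather than one of the paper's experiments: $n$ independent draws from a Gaussian with parameters $\theta = (\mu, \sigma)$, for which the score and Fisher matrix are analytic. Simulating many datasets at the fiducial point, the sample mean of the score should approach zero and its sample covariance should approach $\mathbf{F}$. The helper names are illustrative only.

```python
import numpy as np

def gaussian_score(d, mu, sigma):
    """Score t = grad_theta ln L(d | theta) at theta = (mu, sigma) for iid N(mu, sigma^2) data.
    Toy model assumed for illustration; d has shape (n_sims, n)."""
    r = d - mu
    t_mu = np.sum(r, axis=-1) / sigma**2
    t_sigma = np.sum(-1.0 / sigma + r**2 / sigma**3, axis=-1)
    return np.stack([t_mu, t_sigma], axis=-1)

def gaussian_fisher(n, sigma):
    """Analytic Fisher matrix of n iid N(mu, sigma^2) draws in the (mu, sigma) parameterisation."""
    return np.diag([n / sigma**2, 2.0 * n / sigma**2])

rng = np.random.default_rng(0)
mu_fid, sigma_fid, n, n_sims = 1.0, 2.0, 50, 20_000

# Simulate many datasets at the fiducial point and evaluate the score there.
d = rng.normal(mu_fid, sigma_fid, size=(n_sims, n))
t = gaussian_score(d, mu_fid, sigma_fid)

mean_t = t.mean(axis=0)            # close to 0, as used in Eq. (20)
cov_t = np.cov(t, rowvar=False)    # close to F, as in Eq. (20)
F = gaussian_fisher(n, sigma_fid)
print(np.round(mean_t, 2))
print(np.round(cov_t, 1))
print(F)
```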
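We can relate this information saturation to an optimal, quasi-maximum-likelihood estimator whose covariance equals the inverse Fisher information from the above derivation. Maximising the Taylor expansion (19) with respect to the parameters yields
$$\hat{\theta} = \theta_{\rm fid} + \mathbf{J}^{-1}\mathbf{t}, \tag{22}$$
where both the score and the observed information depend on the observed data. In practice, we can exchange $\mathbf{J}$ with its expectation value, the Fisher information $\mathbf{F} = \langle \mathbf{J} \rangle$, which yields
$$\hat{\theta} = \theta_{\rm fid} + \mathbf{F}^{-1}\mathbf{t}. \tag{23}$$
Making this replacement means the MLE depends on the data only through the score statistics $\mathbf{t}$. The covariance of the MLE (at the expansion point $\theta_{\rm fid}$) is then
$$\mathrm{Cov}(\hat{\theta}) = \mathbf{F}^{-1}\langle \mathbf{t}\mathbf{t}^{T} \rangle\, \mathbf{F}^{-1} = \mathbf{F}^{-1}\mathbf{F}\mathbf{F}^{-1} = \mathbf{F}^{-1}, \tag{24}$$
where $\langle \mathbf{t}\mathbf{t}^{T} \rangle = \mathbf{F}$. Hence the covariance of the MLE equals the inverse Fisher information matrix at $\theta_{\rm fid}$, and the Cramér-Rao bound is saturated.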
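To illustrate Eqs. (23) and (24), here is a minimal sketch on an assumed toy Poisson model (not one of the paper's experiments). For $k_i \sim \mathrm{Poisson}(\lambda)$ the score at $\lambda_{\rm fid}$ is $\sum_i (k_i/\lambda_{\rm fid} - 1)$ and $F = n/\lambda_{\rm fid}$, so the one-step estimator reduces exactly to the sample mean, and its variance across simulations generated at $\lambda_{\rm fid}$ should match $F^{-1}$. `quasi_mle` is a hypothetical helper name.

```python
import numpy as np

def quasi_mle(theta_fid, t, F):
    """One-step quasi-MLE of Eq. (23): theta_hat = theta_fid + F^{-1} t (hypothetical helper).

    theta_fid: (p,) expansion point; t: (p,) or (m, p) score(s); F: (p, p) Fisher matrix.
    Returns an (m, p) array of estimates.
    """
    t = np.atleast_2d(t)
    return theta_fid + np.linalg.solve(F, t.T).T

# Assumed toy model: k_i ~ Poisson(lam), i = 1..n, expanded around lam_fid.
rng = np.random.default_rng(1)
lam_fid, n, n_sims = 4.0, 100, 20_000
k = rng.poisson(lam_fid, size=(n_sims, n))

t = np.sum(k / lam_fid - 1.0, axis=1, keepdims=True)  # d/dlam ln L at lam_fid, shape (n_sims, 1)
F = np.array([[n / lam_fid]])                           # Fisher information of n draws

theta_hat = quasi_mle(np.array([lam_fid]), t, F)        # shape (n_sims, 1)
print(np.var(theta_hat), 1.0 / F[0, 0])                 # empirical variance vs. F^{-1}
```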
The derivation above establishes information saturation around a fixed fiducial point. In general, however, by parameterising the score and Fisher functions with neural networks under the loss (12), we learn an embedding and a neighborhood weighting as functions of the data (and, implicitly, of the parameters).
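Because the log-likelihood of independent data is additive, per-member network outputs can be summed before the estimate is formed. The sketch below assumes that per-member score vectors $\mathbf{t}_i$ and Fisher matrices $\mathbf{F}_i$ are summed and then combined as in Eq. (23); the random arrays are placeholders for trained network outputs, and `fishnets_aggregate` is a hypothetical name, not the authors' code. It shows the two properties the aggregation relies on: invariance to the ordering of set members and additivity across subsets, which is what lets the same networks be applied to sets of any size.

```python
import numpy as np

def fishnets_aggregate(t_i, F_i, theta_fid):
    """Hypothetical aggregation: sum per-member scores t_i (n, p) and Fisher matrices
    F_i (n, p, p), then form theta_hat = theta_fid + F^{-1} t as in Eq. (23)."""
    t = t_i.sum(axis=0)   # additive score over independent set members
    F = F_i.sum(axis=0)   # additive Fisher information
    return theta_fid + np.linalg.solve(F, t), F

# Placeholder outputs standing in for trained score and Fisher networks (hypothetical).
rng = np.random.default_rng(0)
n, p = 500, 2
t_i = rng.normal(size=(n, p))
A = rng.normal(size=(n, p, p))
F_i = A @ np.swapaxes(A, 1, 2) + 1e-3 * np.eye(p)  # each F_i symmetric positive-definite
theta_fid = np.zeros(p)

theta_hat, F = fishnets_aggregate(t_i, F_i, theta_fid)

# Reordering the set members leaves the estimate unchanged.
perm = rng.permutation(n)
theta_perm, _ = fishnets_aggregate(t_i[perm], F_i[perm], theta_fid)
print(theta_hat, theta_perm)
```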
Appendix B Calculating the Fisher Matrix from Network Outputs
To ensure that our Fisher matrix is positive-definite, our Fisher-score networks output the lower-triangular entries of a Cholesky factor $\mathbf{L}$ of the Fisher matrix. To keep $\mathbf{L}$ a valid Cholesky factor (with a strictly positive diagonal), we apply a softplus activation to the diagonal entries of $\mathbf{L}$:
$$L_{ii} \leftarrow \mathrm{softplus}(L_{ii}) = \ln\left(1 + e^{L_{ii}}\right). \tag{25}$$
We then compute the Fisher matrix via
$$\mathbf{F} = \mathbf{L}\mathbf{L}^{T}. \tag{26}$$
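A minimal NumPy sketch of Eqs. (25) and (26), assuming the network's final layer emits $p(p+1)/2$ unconstrained numbers that fill the lower triangle row by row (this ordering and the helper name `fisher_from_outputs` are our assumptions). Because the softplus keeps every diagonal entry of $\mathbf{L}$ strictly positive, $\mathbf{F} = \mathbf{L}\mathbf{L}^{T}$ is symmetric positive-definite and $\mathbf{L}$ is exactly its Cholesky factor.

```python
import numpy as np

def softplus(x):
    # Numerically stable ln(1 + exp(x)).
    return np.logaddexp(0.0, x)

def fisher_from_outputs(raw, p):
    """Hypothetical helper: map p(p+1)/2 unconstrained outputs to a positive-definite Fisher."""
    L = np.zeros((p, p))
    L[np.tril_indices(p)] = raw            # fill the lower triangle row by row
    idx = np.arange(p)
    L[idx, idx] = softplus(L[idx, idx])    # strictly positive diagonal, Eq. (25)
    return L @ L.T, L                      # F = L L^T, Eq. (26)

raw = np.array([0.3, 0.5, -0.2, 1.0, 0.1, 0.4])  # arbitrary network outputs for p = 3
F, L = fisher_from_outputs(raw, p=3)
print(np.linalg.eigvalsh(F))                     # all eigenvalues are positive
```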
The negative log-likelihood loss in Equation 12 allows for explicit interrogation of the resulting Fisher matrix at the level of the predicted quantities (parameters), and ensures that the summary space in is convex. In the GNN regression formalism, Fishnets does not explicitly maximise the Fisher information as part of the loss; rather, the Fisher matrix weights are optimized as an inductive bias within a hidden layer of the GNN.
Appendix C Bayesian Information Experiment Details
C.1 Scalable Linear Regression
We use a toy linear regression model to validate our method and demonstrate network scalability. We consider the form , where , and the parameters of interest are the slope and intercept . This likelihood has an analytically calculable score and Fisher matrix,
t | (27) | |||
F | (28) |
where , with is the mean of the prior on the score, and is added to the Fisher matrix as a prior on the inverse-covariance of the spread of the summaries. With these two expressions we can calculate exact MLE estimates for the parameters via Eq. (6). We choose wide Gaussian priors for , and uniform priors for and . For network training, we simulate datasets of size datapoints. For testing, we generate an additional datasets of size datapoints to demonstrate scalability. We use fully-connected MLPs of size [256, 256, 256] with ELU activations (Clevert et al., 2015) for both score and Fisher networks. Both networks receive the input data . We train the networks for 2500 epochs with the Adam optimizer and a step learning-rate decay schedule. We train an ensemble of 10 networks in parallel on the same training data with different initializations.
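The analytic score-and-Fisher construction can be sketched for an assumed heteroscedastic model $y_i = \theta_0 + \theta_1 x_i + \epsilon_i$ with $\epsilon_i \sim \mathcal{N}(0, \sigma_i^2)$ and known $\sigma_i$. This parameterisation, the sampling ranges, and the omission of the prior terms added to $\mathbf{t}$ and $\mathbf{F}$ above are assumptions of the sketch, not the paper's exact setup. Because this log-likelihood is exactly quadratic in $\theta$, the one-step estimate $\theta_{\rm fid} + \mathbf{F}^{-1}\mathbf{t}$ coincides with weighted least squares for any expansion point, and since $\mathbf{t}$ and $\mathbf{F}$ are sums over data points, the same code runs unchanged for any number of data points.

```python
import numpy as np

def score_and_fisher(x, y, sigma, theta_fid):
    """Analytic score and Fisher for the assumed model y_i = theta_0 + theta_1 x_i + N(0, sigma_i^2).
    Prior terms are deliberately omitted in this sketch."""
    X = np.stack([np.ones_like(x), x], axis=1)  # design matrix, shape (n, 2)
    w = 1.0 / sigma**2                           # per-point inverse variances
    resid = y - X @ theta_fid
    t = X.T @ (w * resid)                        # sum_i w_i r_i [1, x_i]
    F = X.T @ (w[:, None] * X)                   # sum_i w_i [1, x_i][1, x_i]^T
    return t, F

rng = np.random.default_rng(42)
theta_true = np.array([0.5, -1.5])  # (intercept, slope), hypothetical values
for n in (10, 1000, 100_000):       # same code path for any dataset size
    x = rng.uniform(0.0, 10.0, size=n)
    sigma = rng.uniform(0.5, 2.0, size=n)
    y = theta_true[0] + theta_true[1] * x + rng.normal(0.0, sigma)
    theta_fid = np.zeros(2)
    t, F = score_and_fisher(x, y, sigma, theta_fid)
    theta_hat = theta_fid + np.linalg.solve(F, t)  # Eq. (23)
    err = np.sqrt(np.diag(np.linalg.inv(F)))       # Cramér-Rao standard errors
    print(n, theta_hat, err)
```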
C.2 Robustness Network Architecture Comparison
We train three networks on the same datasets as before: a sum-aggregated Fishnets network, a mean-aggregated deepset, and a learned softmax-aggregated deepset (no Fisher output and standard MSE loss against true parameters ). Here we initialise Fishnets with [50,50,50] hidden units for score and Fisher networks, and two embeddings of [128,128,128] hidden units for both deepset networks, all with swish (Ramachandran et al., 2017) nonlinearities for the data embedding (see Table 2). All networks are initialised with the same seed.
C.3 Gamma Population Model
Consider a serum which increases patients’ red blood cell counts, whose decay rate, , is not known. A population of patients are injected with the serum and then asked to come back to the lab within days for a measurement of their blood cell count, . We can cast this problem using the following hierarchical model
where the goal is to infer the mean and scale of the decay rate Gamma distribution from the data, . In the censored case, measurements are rejected if , and collected until samples are accepted. The model is visualised in a plate diagram in Figure LABEL:fig:plate_diagram. In the uncensored case, posterior estimation for this problem is readily performed with a high-dimensional Hamiltonian Monte Carlo (HMC) sampler. We implement this model in NumPyro (Phan et al., 2019) as a baseline MCMC comparison for our algorithm. For the Fishnets implementation, we generate simulations of size over a uniform prior for and . We then train the same Fishnets architecture used for the linear regression case with data inputs . Once the networks are trained, we pass a suite of simulations through the network to generate neural summaries with which to train a density estimation network. Following Alsing et al. (2019), we use Mixture Density Networks to learn an amortized posterior for with three hidden layers of size . We then evaluate this posterior at the same target data used for the HMC, shown in green in Figure LABEL:fig:uncensored-corner. The Fishnets compression results in slightly inflated contours, indicating a small leakage of information. To demonstrate scaling, we additionally generate another simulation at using the same random seed. We train another amortised posterior using 5000 simulations at and pass the data through the same trained Fishnet architecture. The resulting posterior is shown in blue for comparison.
Appendix D Graph Prediction Benchmark Experiment Details
All GNN models are implemented in PyTorch Geometric (Fey & Lenssen, 2019), and all experiments are performed on a single NVIDIA V100 32GB with the same random seed for initialisation and training. We first describe graph- and node-level prediction tasks, followed by the experimental details on the three benchmark datasets.
Node Property Prediction. This task consists of aggregating edge and node information within neighborhoods to predict properties at the node level. The ogbn-arxiv dataset is a directed citation graph of papers summarised as 128-dimensional vectors (nodes) and citations (edges), where each edge points from the citing paper to the cited paper. The task is to predict which of 40 classes each paper belongs to. The ogbn-proteins dataset consists of proteins encoded as 8-dimensional one-hot features indicating protein species (nodes) and undirected weighted edges indicating association scores between proteins. The task is multi-label binary classification over 112 labels from aggregated subgraph edges and nodes, evaluated with the ROC-AUC metric.
Graph Property Prediction. This task requires aggregating edges and nodes into global features of a graph. We consider the ogbg-molhiv dataset, which comprises molecules with atoms arranged as nodes and bonds as edges. The prediction task is binary classification.
ogb-molhiv. This dataset does not provide a node feature for each protein. We initialize the node features via a sum aggregation, i.e. $\mathbf{x}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{e}_{ji}$, where $\mathbf{x}_i$ denotes the initialized node features and $\mathbf{e}_{ji}$ denotes the input edge features. We train a 7-layer DyResGEN model with a softmax aggregator with a learnable softmax parameter. Batch normalization is used for each layer. We set the hidden channel size to 256. Dropout with a rate of 0.5 is used for each layer. An Adam optimizer with a learning rate of 0.0001 is used to train the model for 150 epochs. For the Fishnets comparison we train a 3-layer network with Fishnets aggregation with bottleneck in place of the softmax aggregation, and a learning rate of 0.00002.
ogb-arxiv. We train the 28-layer ResGEN model of Li et al. (2020) with softmax aggregation, where the softmax aggregation parameter is fixed at 0.1. Full-batch training and testing are applied. Batch normalization is used for each layer. The hidden channel size is 128. We apply dropout with a rate of 0.5 for each layer. An Adam optimizer with a learning rate of 0.01 is used to train the model for 500 epochs. For the Fishnets comparison we train a 3-layer version of the same ResGEN network with Fishnets aggregation with bottleneck in place of the softmax aggregation.
ogb-proteins. This dataset does not provide a node feature for each protein. We initialize the node features via a sum aggregation, i.e. $\mathbf{x}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{e}_{ji}$, where $\mathbf{x}_i$ denotes the initialized node features and $\mathbf{e}_{ji}$ denotes the input edge features. We train the 112-layer DyResGEN of Li et al. (2020) with a softmax aggregator. A hidden channel size of 64 is used. Layer normalization and dropout with a rate of 0.1 are used for each layer. We train the model for 1000 epochs with an Adam optimizer with a learning rate of 0.01. For the Fishnets comparison we train 8-layer and 16-layer versions of the same DyResGEN network with Fishnets aggregation with bottleneck in place of the softmax aggregation. Here we reduce the learning rate to 0.005 and set the dropout rate to 0.25.
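The sum-aggregation initialisation of node features can be written as a scatter-add over edges. The sketch below assumes a PyTorch Geometric-style `edge_index` of shape (2, E), with source nodes in the first row and target nodes in the second, and sums incoming edge features onto each target node; for an undirected graph stored with both edge directions this equals the sum over neighbours. `init_node_features` is a hypothetical helper written in plain NumPy, not the authors' implementation.

```python
import numpy as np

def init_node_features(edge_index, edge_attr, num_nodes):
    """Hypothetical helper: x_i = sum over incoming edges (j -> i) of e_ji.
    edge_index: (2, E) int array [sources; targets]; edge_attr: (E, F) edge features."""
    x = np.zeros((num_nodes, edge_attr.shape[1]))
    np.add.at(x, edge_index[1], edge_attr)  # unbuffered scatter-add handles repeated targets
    return x

# Toy graph with 3 nodes and 4 directed edges: 0->1, 1->2, 2->0, 2->1.
edge_index = np.array([[0, 1, 2, 2],
                       [1, 2, 0, 1]])
edge_attr = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0]])
x = init_node_features(edge_index, edge_attr, num_nodes=3)
print(x)
```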
D.1 Noisy Proteins Focus Study
Here we again initialize the node features via a sum aggregation.
We test five model architectures using the vanilla ogbn-proteins dataset (without the subgraph and edge preprocessing performed by Li et al. (2020)). This change allowed us to flexibly incorporate the added edge feature in the noisy-edge setting. To benchmark our training routine we adopt a 28-layer DyResGEN network with learned softmax aggregations and hidden size of 64, and a smaller version of this model with hidden size 14 and 28 layers. We construct two shallower Fishnets GNNs, with 16 and 20 layers, each with 64 hidden units, and one small model with 14 hidden units and 14 layers. For each graph convolution aggregation, we adopt a “score” bottleneck of for the large Fishnets models and for the small model. We train all networks with a cross-entropy loss over the same dataset and fixed random seed using an Adam optimizer with fixed learning rate . We incorporate an early-stopping criterion conditioned on the validation data, which ends training (saturation) when the validation ROC-AUC metric stops increasing for epochs.
In the noisy proteins setting we again control for stochasticity in training set loading and added edge noise by fixing the initial random seed before each training run.
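The early-stopping rule described above can be sketched as a small patience loop. The function name, the strict-improvement test, and returning both the stopping epoch and the best epoch are illustrative assumptions; the patience value used in the experiments is not reproduced here.

```python
def early_stopping_epoch(val_scores, patience):
    """Hypothetical helper: return (stop_epoch, best_epoch), stopping once the validation
    score (e.g. ROC-AUC) has not strictly improved for `patience` consecutive epochs."""
    best, best_epoch = float("-inf"), -1
    for epoch, score in enumerate(val_scores):
        if score > best:
            best, best_epoch = score, epoch       # new best: reset the patience counter
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch              # saturated: stop training here
    return len(val_scores) - 1, best_epoch        # ran out of epochs without saturating

print(early_stopping_epoch([0.70, 0.75, 0.78, 0.77, 0.78, 0.76], patience=3))
```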
| setting | network | # params | test ROC-AUC |
|---|---|---|---|
| noise-free | fishnets-20 | | |
| | fishnets-16 | | |
| | fishnets small | | |
| | GCN-28 | | |
| | GCN small | | |
| noisy edges | fishnets-20 | | |
| | GCN-28 | | |
Appendix E Deepsets Formalism
Summary. The deepsets method presented by Zaheer et al. (2018) shows in Theorem 9 that any function over a countable set can be decomposed in the form $f(X) = \rho\left(\sum_{x \in X} \phi(x)\right)$. They then extend this to the universality of deepsets, since $\rho$ and $\phi$ can be parameterized as neural networks, which can be universal function approximators. The deepsets formalism allows point estimates for regression parameters to be obtained following an aggregation of features in a potentially variably-sized set of data. Incorporating our formalism, each set member $x_i$ is first passed to a neural network $\phi$, and subsequently aggregated using some permutation-invariant scheme, $\bigoplus$:
$$\hat{\theta} = \rho\left(\bigoplus_{i=1}^{n} \phi(x_i)\right), \tag{29}$$
where $\phi$ is the embedding network and $\rho$ is the “global” function that maps aggregated features to predicted parameters. When the aggregation is chosen to be the mean, the deepsets formalism scales to datasets of arbitrary size and becomes equivalent to the Fishnets aggregation formalism with flat weights across the aggregated data. The loss takes the form of a convex squared loss, e.g. the mean squared error
$$\mathcal{L}_{\rm MSE} = \frac{1}{B}\sum_{b=1}^{B} \left\| \theta_b - \hat{\theta}_b \right\|_2^2, \tag{30}$$
where $B$ is the number of full simulations in a batch, each of size $n$.
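A minimal sketch of a deepset with mean aggregation, assuming one-layer stand-ins for $\phi$ and $\rho$ with random, untrained weights (hypothetical), together with a batch squared-error loss of the form given for Eq. (30). The output is unchanged under permutations of the set and under duplicating every member, which is why mean aggregation lets the same network accept sets of any cardinality.

```python
import numpy as np

rng = np.random.default_rng(0)
W_phi = rng.normal(size=(1, 16))  # hypothetical embedding weights for phi
W_rho = rng.normal(size=(16, 2))  # hypothetical weights for the global function rho

def deepset(x):
    """theta_hat = rho(mean_i phi(x_i)) for a set x of shape (n, 1); returns shape (2,)."""
    h = np.tanh(x @ W_phi)   # phi: per-member embedding, shape (n, 16)
    pooled = h.mean(axis=0)  # mean aggregation: permutation-invariant, any n
    return pooled @ W_rho    # rho: linear map to 2 predicted parameters

def mse_loss(theta_true, theta_pred):
    """Squared error summed over parameters, averaged over the batch of B simulations."""
    return np.mean(np.sum((theta_true - theta_pred) ** 2, axis=-1))

x = rng.normal(size=(37, 1))
print(deepset(x), deepset(x[::-1]), deepset(np.tile(x, (3, 1))))
```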
Training and Generalization. In practice, Deepsets requires a fixed aggregation scheme from which to learn its global function. Most often this is a summary of embedding layers . For networks to scale to arbitrary dataset cardinality, aggregations like max, mean, and variance need to be used. In a scenario where the test data follow a different distribution from the training data, these aggregations might pose an issue. Concretely, consider , with the target quantity . Next consider a deepset with the identity embedding layer and mean aggregation:
$$\hat{\theta} = \rho\left(\frac{1}{n}\sum_{i=1}^{n} x_i\right). \tag{31}$$
If test data were drawn from the same distribution as the training data, $\rho$ would act on the mean value of the set of data, in this case , and would converge to a learned function of the joint prior-data distribution . However, if a test set of data were drawn from a different distribution, e.g. , then the expectation would take on a different value, and $\rho$ would return an incorrect result for the deterministic aggregation. Here it is important to emphasize that and overlap along the same support, meaning the network will have seen examples of data drawn from this prior in the limit of an infinite training set. However, the fixed aggregation makes use of a training-data distribution-dependent quantity for its mapping, which can be skewed under covariate shift or different noise settings.
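This failure mode can be made concrete with a toy example of our own choosing (not from the paper): training sets are drawn as $x_i \sim U(0, \theta)$, so with identity embedding and mean aggregation the ideal global function is $\rho(m) = 2m$. A shifted test set $x_i = \theta u_i^2$ with $u_i \sim U(0, 1)$ has the same support $[0, \theta]$ but mean $\theta/3$, so the same $\rho$ returns roughly $2\theta/3$. The prior range, set size, and linear form of $\rho$ are assumptions of the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_set = 5000, 200

# Assumed training simulations: theta ~ U(1, 5), set members x ~ U(0, theta).
theta_train = rng.uniform(1.0, 5.0, size=n_train)
x_train = theta_train[:, None] * rng.uniform(size=(n_train, n_set))
m_train = x_train.mean(axis=1)  # identity embedding + mean aggregation

# Fit a linear global function rho(m) = a * m by least squares (close to a = 2).
a = np.sum(m_train * theta_train) / np.sum(m_train**2)

# Test sets with theta = 3: one in-distribution, one shifted on the same support [0, theta].
theta_test = 3.0
x_in = theta_test * rng.uniform(size=n_set)
x_shift = theta_test * rng.uniform(size=n_set) ** 2
print(a * x_in.mean(), a * x_shift.mean())  # roughly theta vs. roughly 2 * theta / 3
```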