Fishnets: Information-Optimal, Scalable Aggregation for Sets and Graphs

T. Lucas Makinen    Justin Alsing    Benjamin D. Wandelt
Abstract

Set-based learning is an essential component of modern deep learning and network science. Graph Neural Networks (GNNs) and their edge-free counterparts Deepsets have proven remarkably useful on ragged and topologically challenging datasets. The key to learning informative embeddings for set members is a specified aggregation function, usually a sum, max, or mean. We propose Fishnets, an aggregation strategy for learning information-optimal embeddings for sets of data for both Bayesian inference and graph aggregation. We demonstrate that i) Fishnets neural summaries can be scaled optimally to an arbitrary number of data objects, ii) Fishnets aggregations are robust to changes in data distribution, unlike standard deepsets, iii) Fishnets saturate Bayesian information content and extend to regimes where Markov Chain Monte Carlo (MCMC) techniques fail and iv) Fishnets can be used as a drop-in aggregation scheme within GNNs. We show that by adopting a Fishnets aggregation scheme for message passing, GNNs can achieve state-of-the-art performance versus architecture size on benchmark datasets over existing architectures with a fraction of learnable parameters and faster training time.

Machine Learning, ICML

1 Introduction

Aggregating information from independent data in an optimal way is a fundamental problem in statistics and machine learning. On one hand, frequentist analyses need optimal estimators for data compression, while on the other, Bayesian analyses need small informative summaries for simulation-based inference (SBI) schemes (Cranmer et al., 2020). In a deep learning context, graph neural networks (GNNs) rely on aggregation schemes to pool information over large data structures, where each feature might be only weakly informative on its own but might contribute substantial information at the graph level for predictive or regression tasks (Zhou et al., 2020).

Up until now, graph aggregation schemes have relied on simple, fixed operations such as mean, max, sum (Corso et al., 2020b; Kipf & Welling, 2017; Hamilton et al., 2017; Xu et al., 2019), variance, or trainable variants of these aggregators (Battaglia et al., 2018; Li et al., 2020), which are susceptible to generalisation issues in heterogeneous data aggregation, and may contribute to GNN “bottlenecking” over large aggregation neighborhoods (Alon & Yahav, 2021; Giovanni et al., 2024). We introduce a new optimal aggregation scheme grounded in information-theoretic principles. By leveraging the additive structure of the log-likelihood for independent data and the underlying Fisher curvature, we can construct a learned summary space that asymptotically contains maximal information (Vaart, 1998; Coulton & Wandelt, 2023). We show that this formalism captures relevant information both in a Bayesian inference context and for edge aggregation in graphs.

Contributions. In this work we establish Fishnets, a new, information-optimal aggregation scheme for graph and set-based data. By explicitly learning inverse-Fisher weights in addition to neural score embeddings, we are able to achieve i) asymptotic optimality, ii) scalability, and iii) robustness to changes in the generative distribution for the data and its noise characteristics. We are able to construct optimal summary statistics for independent data for SBI applications, and using the same formalism are able to beat key benchmark GNN learning tasks with far smaller architectures in faster training time than leading networks.

Summary of Results. In Fig. LABEL:fig:foo we show how incorporating our aggregation scheme improves model convergence for benchmark (LABEL:fig:gnn_benchmark) and realistic noisy ogb-proteins (LABEL:fig:gnn_noise). We provide other GNN benchmark and model performance specifications in Table 1 and demonstrate that incorporating Fishnets aggregation as a drop-in replacement in an existing GCN framework enables faster convergence and better performance with much smaller model architectures. We carefully explore the information capture of our aggregation in Section 4. We explain this improvement by demonstrating information saturation, robustness, and scalability in a Bayesian context for increasingly difficult problems, and highlight where existing aggregators fall short. In Section 5 we detail the GNN benchmarks and improved aggregator performance.

(left) ogb dataset performance

dataset        GCN      fishnets
ogb-arxiv      0.7100   0.7062
ogb-molhiv     0.7600   0.8000
ogb-proteins   0.8425   0.8444

(right) ogb-proteins comparison

model         # params    test ROC-AUC
GCN-112       1,887,144   0.8425 ± 0.0018
fishnets-8    146,596     0.8410 ± 0.0013
fishnets-16   280,740     0.8444 ± 0.0018

Table 1: (left) Summary of benchmark improvement within GCN framework with Fishnets aggregation. (right) Model size comparison for ogb-proteins benchmark. Fishnets aggregation improves performance with $\sim 15\%$ of the learnable parameters.

2 Method: Optimal Aggregation of independent (heterogeneous) data

2.1 Fisher Information and Optimality Definitions

We first define the notion of information optimality using the likelihood principle. Data $\textbf{d}$ is related to some parameters (or quantities of interest) via a log-likelihood $\mathcal{L} = \ln p(\textbf{d}|\bm{\theta})$. We would like to obtain a compression mapping $f: \textbf{d} \mapsto \textbf{t}$ from $N$ data to $n_p$ numbers $\textbf{t}$ which preserves as much information about the parameters $\bm{\theta}$ as possible. We use the information inequality to quantify how informative this mapping is (Lehmann & Casella, 1998):

$$\mathrm{Var}_{\bm{\theta}}\left[t_{\alpha}\right] \geq \left(\textbf{A}^{T}\textbf{F}^{-1}\textbf{A}\right), \tag{1}$$

where $\textbf{A} = \nabla \mathbb{E}_{\bm{\theta}}\left[\textbf{t}^{T}\right]$ and the Fisher Information matrix is

$$\textbf{F} = -\mathbb{E}_{\bm{\theta}}\left[\nabla\nabla^{T}\mathcal{L}\right] = \mathbb{E}_{\bm{\theta}}\left[\nabla\mathcal{L}\,\nabla^{T}\mathcal{L}\right], \tag{2}$$

where the last equality holds under mild regularity conditions (Alsing & Wandelt, 2018). The Fisher matrix is the curvature of the log-likelihood, and is in general a function of the parameters. In the case where the compressed numbers $\textbf{t}$ are unbiased estimators of the parameters, $\mathbb{E}_{\bm{\theta}}\left[\textbf{t}\right] = \bm{\theta}$, we have $\textbf{A} = \mathbb{I}$ and the information inequality (1) reduces to the Cramér-Rao bound (Cramér, 1946):

$$\mathrm{Var}_{\bm{\theta}}\left[t_{\alpha}\right] \geq \textbf{F}_{\alpha\alpha}^{-1}. \tag{3}$$

Maximising the Fisher information over parameter space decreases the variance on the estimates of the quantities of interest $\bm{\theta}$. Alsing & Wandelt (2018) show that the score function, $\textbf{t} = \nabla\mathcal{L}$, saturates the lower bound of (1) around a fiducial point, $\bm{\theta}_{*}$. We reproduce this proof in the general case over a space of $\bm{\theta}$ and relate information saturation to parameter estimators in Appendix A.

Maximum likelihood estimators (MLEs) are the asymptotically optimal estimators for predictive tasks. When they are available, they provide an optimally informative embedding of the data with respect to the parameters of interest, $\bm{\theta}$ (see Alsing & Wandelt, 2018, and Appendix A).
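As a concrete illustration of the bound in Eq. (3), the following sketch uses a toy model that is not taken from the paper: scalar measurements $d_i \sim \mathcal{N}(\theta, \sigma_i^2)$ with known, heterogeneous noise levels. For this model the Fisher information is $F = \sum_i 1/\sigma_i^2$. The inverse-variance weighted mean is unbiased and should attain variance $1/F$, whereas the unweighted mean is unbiased but noisier. All names and numerical values are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

theta_true = 1.5                           # parameter of interest (assumed value)
sigma = rng.uniform(1.0, 10.0, size=200)   # known per-datum noise levels
n_trials = 20000

# Fisher information for the mean of independent Gaussians: F = sum_i 1/sigma_i^2
fisher = np.sum(1.0 / sigma**2)
cramer_rao_bound = 1.0 / fisher

# Simulate many datasets and apply both estimators to each one
data = theta_true + sigma * rng.standard_normal((n_trials, sigma.size))
weights = (1.0 / sigma**2) / fisher        # inverse-variance weights, sum to 1
weighted_mean = data @ weights
flat_mean = data.mean(axis=1)

var_weighted = weighted_mean.var()
var_flat = flat_mean.var()

print(f"Cramer-Rao bound   : {cramer_rao_bound:.4f}")
print(f"Var[weighted mean] : {var_weighted:.4f}")
print(f"Var[flat mean]     : {var_flat:.4f}")
```

The weighted estimator sits on the bound up to Monte Carlo noise, while the flat mean does not. This is the gap that per-datum Fisher weights are meant to close.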

2.2 Set-like Data Likelihoods

Many inference problems consist of a set of $n_{\rm data}$ data vectors, $\{\textbf{d}_i\}_{i=1}^{n_{\rm data}}$, $\textbf{d}_i \in \mathbb{R}^{N}$, which obey a global model controlled by parameters $\bm{\theta} \in \mathbb{R}^{n_p}$ and a possibly arbitrarily deep hierarchy of latent values, $\eta$. The full data likelihood for the parameters of interest $\bm{\theta}$ is given by the integral over latents,

$$p(\{\textbf{d}_i\}|\bm{\theta}) = \int p(\{\textbf{d}_i\}|\bm{\theta}, \eta)\, p(\eta|\bm{\theta})\, d\eta. \tag{4}$$

When the data are independently distributed, their log-likelihood takes the form

$$\ln p(\{\textbf{d}_i\}|\bm{\theta}) = \sum_{i=1}^{n_{\rm data}} \ln p(\textbf{d}_i|\bm{\theta}). \tag{5}$$

A maximum likelihood estimator can then be formed (iteratively) by the Fisher scoring method (Alsing & Wandelt, 2018):

$$\hat{\bm{\theta}}^{\rm MLE} = \bm{\theta}_{*} + \textbf{F}^{-1}\textbf{t}, \tag{6}$$

which requires knowledge of a fiducial point $\bm{\theta}_{*}$, the score, $\textbf{t} \in \mathbb{R}^{n_p}$, and the Fisher matrix, $\textbf{F} \in \mathbb{R}^{n_p \times n_p}$. For problems like linear regression where the analytic forms of $\textbf{F}$ and $\textbf{t}$ are known, Eq. (6) gives the exact MLE for the parameters in a single iteration in the Gaussian approximation, given the dataset. In the case of independent data, both the score and the Fisher information are additive.

Taking the gradient of the log-likelihood with respect to the parameters, the score $\textbf{t} = \bm{\nabla}_{\bm{\theta}} \ln p(\{\textbf{d}_i\}|\bm{\theta})$ for the full dataset is the sum of the scores of the individual data points:

$$\textbf{t} = \sum_{i=1}^{n_{\rm data}} \bm{\nabla}_{\bm{\theta}} \ln p(\textbf{d}_i|\bm{\theta}) = \sum_{i=1}^{n_{\rm data}} \textbf{t}_i(\textbf{d}_i). \tag{7}$$

Taking the gradient again yields the Hessian, or Fisher information matrix (Amari, 2021; Vaart, 1998) for the dataset,

$$\textbf{F} = \sum_{i=1}^{n_{\rm data}} \bm{\nabla}_{\bm{\theta}} \bm{\nabla}_{\bm{\theta}}^{T} \ln p(\textbf{d}_i|\bm{\theta}) = \sum_{i=1}^{n_{\rm data}} \textbf{F}_i(\textbf{d}_i), \tag{8}$$

which is also a sum of the Fisher matrices of the individual data. Once the score and Fisher matrix for a dataset are known, the two can be combined to form a pseudo-maximum likelihood estimate (MLE) for the target parameters following (6). Therefore, constructing optimal embeddings of independent data with respect to specific quantities of interest just requires aggregating the scores and Fishers, and combining them as in (6). However, in general explicit forms for the likelihood (per data vector) may not be known. In this general case, as we will show in the following section, we can parameterize and learn the score and Fisher using neural networks.
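To make the aggregation in Eqs. (6)-(8) concrete, here is a minimal sketch for a linear model $y_i = m x_i + b + \epsilon_i$ with known per-datum noise $\sigma_i$, a case in which the score and Fisher are available in closed form. The per-datum scores and Fisher matrices are summed and combined as in Eq. (6), and the result is cross-checked against weighted least squares. The Fisher matrix is taken as the expected negative Hessian, following Eq. (2). Parameter values, the fiducial point, and all variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Simulate one heteroscedastic dataset y_i = m x_i + b + eps_i (illustrative values)
m_true, b_true = 2.0, -1.0
n_data = 500
x = rng.uniform(0.0, 10.0, n_data)
sigma = rng.uniform(1.0, 10.0, n_data)
y = m_true * x + b_true + sigma * rng.standard_normal(n_data)

theta_fid = np.zeros(2)                           # fiducial point theta_* = (m_*, b_*)
design = np.stack([x, np.ones(n_data)], axis=1)   # rows are d(mean_i)/d(theta) = [x_i, 1]

# Per-datum scores t_i and Fisher matrices F_i, evaluated at the fiducial point
residual = y - design @ theta_fid
t_i = design * (residual / sigma**2)[:, None]                            # (n_data, 2)
F_i = design[:, :, None] * design[:, None, :] / sigma[:, None, None]**2  # (n_data, 2, 2)

# Aggregate by summation, then combine as theta_* + F^{-1} t
t = t_i.sum(axis=0)
F = F_i.sum(axis=0)
theta_mle = theta_fid + np.linalg.solve(F, t)

# Cross-check against weighted least squares on the same data
theta_wls, *_ = np.linalg.lstsq(design / sigma[:, None], y / sigma, rcond=None)
print(theta_mle, theta_wls)
```

Because the model is linear in the parameters, a single scoring step from any fiducial point lands on the weighted least-squares solution. Noisy points contribute small $\textbf{F}_i$ and are down-weighted automatically.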

2.3 Twin Fisher-Score Networks

For many problems, however, the analytic forms of the Fisher and score are not known. Here we propose learning these functions with neural networks. Due to the additive structure of (7) and (8), we can parameterize the per-datapoint score and Fisher with twin neural networks:

$$\hat{\textbf{t}}_i = \textbf{t}(\textbf{d}_i; w_t) \in \mathbb{R}^{n_p}; \qquad \textbf{t}_{\rm NN} = \sum_{i=1}^{n_{\rm data}} \hat{\textbf{t}}_i \tag{9}$$

$$\hat{\textbf{F}}_i = \textbf{F}(\textbf{d}_i; w_F) \in \mathbb{R}^{n_p \times n_p}; \qquad \textbf{F}_{\rm NN} = \sum_{i=1}^{n_{\rm data}} \hat{\textbf{F}}_i \tag{10}$$

where the score and Fisher networks are parameterized by weights $w_t$ and $w_F$, respectively. The twin networks output a score and Fisher for each datapoint (see Appendix B for formalism), which are then each summed to obtain a global score and Fisher for the dataset. We can then compute parameter estimates using these aggregated embeddings following (6):

$$\hat{\bm{\theta}}_{\rm NN}\left(\{\textbf{d}_i\}; w_t, w_F\right) = \textbf{F}^{-1}_{\rm NN}\textbf{t}_{\rm NN} + c, \tag{11}$$

where the fiducial point $\bm{\theta}_{*} = c = 0$ can be set to an arbitrary constant. Provided the embeddings $\hat{\textbf{t}}_i$ and $\hat{\textbf{F}}_i$ are descriptive enough, the summation formalism can be used to obtain Fisher and score estimates under an implicit likelihood for datasets with heterogeneous structure and arbitrary size. These summaries can be regarded as sufficient statistics, since the score as a function of parameters could in principle be used to reconstruct the likelihood surface up to a constant (Alsing & Wandelt, 2018; Hoffmann & Onnela, 2022).
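The forward pass of Eqs. (9)-(11) can be sketched in a few lines of NumPy. This is an untrained illustration under stated assumptions, not the architecture used in the paper: the two networks are small random MLPs, and each per-datum Fisher is built from a lower-triangular factor with a softplus diagonal so that it is positive definite. That parameterisation, the layer sizes, and the helper names `init_mlp`, `mlp`, and `fishnets_aggregate` are all hypothetical choices made for the example.

```python
import numpy as np

def init_mlp(rng, sizes):
    """Random weights for a small fully connected network (illustrative)."""
    return [(rng.standard_normal((n_in, n_out)) / np.sqrt(n_in), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    """Apply the network row-wise to x with tanh hidden activations."""
    for w, b in params[:-1]:
        x = np.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def fishnets_aggregate(data, w_t, w_F, n_p):
    """Sum per-datum scores and Fishers, then return F_NN^{-1} t_NN (Eq. 11, c = 0)."""
    t_hat = mlp(w_t, data)                                  # (n_data, n_p)
    raw = mlp(w_F, data)                                    # (n_data, n_p (n_p + 1) / 2)
    # Assumed parameterisation: fill a lower-triangular factor L_i and set
    # F_i = L_i L_i^T, with a softplus on the diagonal to keep F_i positive definite.
    rows, cols = np.tril_indices(n_p)
    L = np.zeros((data.shape[0], n_p, n_p))
    L[:, rows, cols] = raw
    diag = np.arange(n_p)
    L[:, diag, diag] = np.logaddexp(0.0, L[:, diag, diag])  # softplus
    F_hat = L @ np.swapaxes(L, 1, 2)                        # (n_data, n_p, n_p)
    t_nn, F_nn = t_hat.sum(axis=0), F_hat.sum(axis=0)
    return np.linalg.solve(F_nn, t_nn), t_nn, F_nn

rng = np.random.default_rng(2)
n_p, n_features = 2, 3
w_t = init_mlp(rng, [n_features, 16, n_p])
w_F = init_mlp(rng, [n_features, 16, n_p * (n_p + 1) // 2])

# The same weights apply to sets of any size, in any order
small_set = rng.standard_normal((500, n_features))
large_set = rng.standard_normal((10_000, n_features))
theta_small, _, _ = fishnets_aggregate(small_set, w_t, w_F, n_p)
theta_large, _, F_large = fishnets_aggregate(large_set, w_t, w_F, n_p)
```

Since both aggregations are plain sums, the output is permutation invariant and the set size never enters the network definition. The aggregated Fisher grows with the number of data, which is what lets the estimate tighten as more data are added.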

Loss Function. In a regression scenario, we draw data-parameter pairs from the joint distribution $\bm{\theta}, \{\textbf{d}_i\} \curvearrowleft p(\{\textbf{d}_i\}, \bm{\theta})$ and compute $\hat{\bm{\theta}}_{\rm NN}$ from (11). The twin networks can then be trained jointly using a negative-log Gaussian loss:

$$\mathcal{L}(\bm{\theta}, \hat{\bm{\theta}}_{\rm NN};\ w_t, w_F) = \frac{1}{2}(\bm{\theta} - \hat{\bm{\theta}}_{\rm NN})^{T}\textbf{F}_{\rm NN}(\bm{\theta} - \hat{\bm{\theta}}_{\rm NN}) - \frac{1}{2}\ln\det\textbf{F}_{\rm NN}, \tag{12}$$

where $\bm{w} = (w_t, w_F)$. Minimizing this loss with respect to the neural network weights ensures that information is saturated via maximising the aggregated Fisher, and forces the distance between embedding MLE and parameters to be minimized with respect to the Cramér-Rao bound (3) as a function of the data. This loss can also be interpreted as a maximum likelihood (MLE) loss for the quantities of interest $\bm{\theta}$, as opposed to typical mean-square error (MSE) regression losses (see Appendix E for deepsets details).
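The loss in Eq. (12) is straightforward to evaluate for a single dataset. The sketch below is a plain NumPy transcription with hand-picked inputs. The function name `fishnets_loss` and the example values are assumptions for illustration. In practice the loss would be averaged over a batch and differentiated with an autodiff framework.

```python
import numpy as np

def fishnets_loss(theta, theta_hat, F_nn):
    """Negative-log Gaussian loss of Eq. (12) for a single dataset."""
    delta = theta - theta_hat
    sign, logdet = np.linalg.slogdet(F_nn)
    if sign <= 0:
        raise ValueError("F_nn must be positive definite")
    return 0.5 * delta @ F_nn @ delta - 0.5 * logdet

theta = np.array([1.0, -0.5])        # true parameters for this simulation
theta_hat = np.array([1.2, -0.4])    # network estimate from Eq. (11)
F_nn = np.array([[4.0, 1.0],
                 [1.0, 3.0]])        # aggregated Fisher matrix
loss = fishnets_loss(theta, theta_hat, F_nn)
print(loss)
```

The two terms pull in opposite directions. The log-determinant rewards a large aggregated Fisher, while the quadratic term penalises residuals more heavily as the claimed precision grows, so the network cannot lower the loss by reporting more information than its estimates support.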

3 Related Work

Deepsets Mean Aggregation. A comparable method for learning over sets of data is regression using the Deepsets (DS) formalism (Zaheer et al., 2018). Here an embedding $f(\textbf{d}_i; w_1)$ is learned for each datum, and then aggregated with a fixed permutation-invariant scheme and fed to a global function $g$: $\hat{\bm{\theta}} = g\left(\bigoplus_{i=1}^{n_{\rm data}} f(\textbf{d}_i; w_1);\ w_2\right)$. The networks are optimised by minimising a squared loss against the true parameters, $\mathrm{MSE}(\hat{\bm{\theta}}, \bm{\theta})$. When the aggregation is chosen to be the mean, the deepsets formalism is scalable to arbitrary data and becomes equivalent to the Fishnets aggregation formalism with flat weights across the aggregated data (see Appendix E for in-depth treatment).
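The flat-weight equivalence can be checked numerically. In the sketch below, random vectors stand in for the learned per-datum embeddings, which is an assumption made purely for illustration. Giving every datum the same identity Fisher matrix reduces $\textbf{F}^{-1}_{\rm NN}\textbf{t}_{\rm NN}$ to the plain mean, whereas heterogeneous scalar weights $w_i$ (also hypothetical) produce a precision-weighted mean.

```python
import numpy as np

rng = np.random.default_rng(3)
n_data, n_p = 500, 2
embeddings = rng.standard_normal((n_data, n_p))   # stand-in for per-datum embeddings

# Flat weights: every datum gets the same Fisher matrix (identity here)
F_flat = np.tile(np.eye(n_p), (n_data, 1, 1))
fishnets_flat = np.linalg.solve(F_flat.sum(axis=0), embeddings.sum(axis=0))
mean_agg = embeddings.mean(axis=0)

# Heterogeneous scalar precisions w_i give a precision-weighted mean instead
# (assumes each score is the embedding scaled by its precision, t_i = w_i * e_i)
w = rng.uniform(0.01, 1.0, n_data)
F_het = w[:, None, None] * np.eye(n_p)
t_het = w[:, None] * embeddings
fishnets_het = np.linalg.solve(F_het.sum(axis=0), t_het.sum(axis=0))
weighted_mean = np.average(embeddings, axis=0, weights=w)

print(fishnets_flat, mean_agg)
print(fishnets_het, weighted_mean)
```

Mean aggregation is therefore the special case in which every set member is assumed to carry the same amount of information.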

Learned Softmax Aggregation. Li et al. (2020) present a learnable softmax counterpart to the DS aggregation scheme in the context of edge aggregation in GNNs. Using the above notation, their aggregation scheme reads:

$$\text{SoftmaxAgg}(\cdot) = \sum_{i=1}^{n_{\rm data}} \frac{\exp\left(\beta f(\textbf{d}_i; w_1)\right)}{\sum_{l}\exp\left(\beta f(\textbf{d}_l; w_1)\right)} \cdot f(\textbf{d}_i; w_1) \tag{13}$$

where $\beta$ is a learned scalar temperature parameter and $f(\cdot\,; w_1)$ is some embedding layer. They show that adopting this aggregation scheme allows more graph convolution (GCN) layers to be stacked efficiently to deepen GNN models. Many other aggregation frameworks have been studied, including Graph Attention (Veličković et al., 2018), LSTM units (Hamilton et al., 2017), Recurrent aggregations (Soelch et al., 2019), and scaled multiple aggregators (Corso et al., 2020a).
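A minimal NumPy sketch of Eq. (13) follows. It applies the softmax feature-wise over the set axis, which is an assumption about how vector-valued embeddings are handled. The function name `softmax_agg` and the example embeddings are illustrative, and $\beta$ is fixed by hand rather than learned.

```python
import numpy as np

def softmax_agg(embeddings, beta):
    """Eq. (13): feature-wise softmax-weighted sum over the set axis (axis 0)."""
    logits = beta * embeddings
    logits = logits - logits.max(axis=0, keepdims=True)   # stabilise the exponentials
    weights = np.exp(logits)
    weights /= weights.sum(axis=0, keepdims=True)
    return (weights * embeddings).sum(axis=0)

embeddings = np.array([[0.0, 1.0],
                       [1.0, 3.0],
                       [2.0, -1.0]])

agg_mean_like = softmax_agg(embeddings, beta=0.0)     # uniform weights: the mean
agg_max_like = softmax_agg(embeddings, beta=100.0)    # sharply peaked: close to the max
agg_min_like = softmax_agg(embeddings, beta=-100.0)   # negative beta: close to the min
print(agg_mean_like, agg_max_like, agg_min_like)
```

The single temperature interpolates between mean-like and max-like pooling. The weights depend only on the embedding values themselves, so there is no separate per-datum estimate of information content as in the Fisher-weighted scheme.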

4 Experiments: Bayesian Information Saturation

Bayesian Simulation Based Inference (SBI) provides a framework in which to perform inference with intractable likelihoods. There have been massive developments in SBI, such as neural ratio estimation (Miller et al., 2021) and density estimation (Alsing et al., 2019; Papamakarios et al., 2019). Key to all of these methods is compressing a large number of data down to small summaries, typically one informative summary per parameter of interest, to preserve information (Alsing & Wandelt, 2018; Charnock et al., 2018; Makinen et al., 2021). ML methods like regression (Jeffrey & Wandelt, 2020) and information-maximising neural networks (Charnock et al., 2018; Makinen et al., 2022, 2021) are very good at learning embeddings for highly structured data like images, and can do so losslessly (Makinen et al., 2021). For unstructured datasets composed of many independent data, the task of constructing optimal summaries amounts to an aggregation task (Zaheer et al., 2018; Hoffmann & Onnela, 2022; Wagstaff et al., 2019). The Fishnets formalism is an optimal version of this aggregation. What deepsets and “learned” aggregation functions are missing is the explicit construction of the inverse-Fisher weights per datapoint, as well as the ability to construct the total Fisher information, which is required to turn summaries into unbiased estimators (Alsing & Wandelt, 2018). Explicitly learning the $\textbf{F}^{-1}$ weights in addition to the score allows us to achieve 1) asymptotic optimality, 2) scalability, and 3) robustness to changes in information content among the data.

In this section we demonstrate the 1) information saturation, 2) robustness and 3) scalability of the Fishnets aggregation through two examples in the context of SBI, and highlight the shortcomings of existing aggregators. We first investigate a linear regression scaling problem and introduce a robustness test in which Fishnets outperforms deepset and learned softmax aggregation on test data. We then extend Fishnets to an inference problem with nuisance (latent) parameters and censorship to demonstrate the applicability of network scaling to a regime where MCMC becomes intractable.

4.1 Validation Case: Linear Regression

We use a toy linear regression model to validate our method and demonstrate network scalability. We consider the form $y = mx + b + \epsilon$, where $\epsilon \sim \mathcal{N}(0, \sigma)$ and the parameters of interest are the slope and intercept $\bm{\theta} = (m, b)$. This likelihood has an analytically calculable score and Fisher matrix (see Appendix C.1), which can be used to calculate exact MLE estimates for the parameters $\bm{\theta} = (m, b)$ via (6). We choose wide Gaussian priors for $\bm{\theta}$, and uniform distributions for $x \in [0, 10]$ and $\sigma \in [1, 10]$. For network training, we simulate $10^4$ datasets of size $n_{\rm data} = 500$ datapoints. For testing, we generate an additional $10^4$ datasets of size $n_{\rm data} = 10^4$ datapoints to demonstrate scalability. See Appendix C.2 for neural architecture details.
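A simulator for this setup might look like the following sketch. Several details are assumptions because the text does not pin them down: the prior width `prior_std` is a hypothetical value, $\sigma$ is treated as a per-datum standard deviation, and each set member is packed as $(x_i, y_i, \sigma_i)$. The number of simulated sets is reduced here to keep the example light. The text uses $10^4$ sets for both training and testing.

```python
import numpy as np

def simulate_linear_sets(rng, n_sets, n_data, prior_std=5.0):
    """Draw theta = (m, b) from a Gaussian prior and simulate n_sets datasets.

    prior_std is a hypothetical choice; the text only says the priors are wide.
    """
    theta = prior_std * rng.standard_normal((n_sets, 2))
    x = rng.uniform(0.0, 10.0, (n_sets, n_data))
    sigma = rng.uniform(1.0, 10.0, (n_sets, n_data))   # treated as a standard deviation
    noise = sigma * rng.standard_normal((n_sets, n_data))
    y = theta[:, :1] * x + theta[:, 1:] + noise
    return theta, np.stack([x, y, sigma], axis=-1)     # (n_sets, n_data, 3)

rng = np.random.default_rng(4)
theta_train, d_train = simulate_linear_sets(rng, n_sets=100, n_data=500)
theta_test, d_test = simulate_linear_sets(rng, n_sets=10, n_data=10_000)
print(d_train.shape, d_test.shape)
```

Training and test sets differ only in the length of the set axis, so a sum-aggregated network trained on 500 points can be applied unchanged to $10^4$ points.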

Results. We display a comparison of test set performance to the true MLE solution in Figure LABEL:fig:info_saturation, and slices of the true and predicted score vectors as a function of input data. The networks are able to recover the exact score and Fisher information matrices (see Figure LABEL:fig:net_outputs), even when scaled up 20-fold. This test demonstrates that Fishnets can (1) saturate information on small training sets to enable scalable predictions on far larger aggregations of data (2).

4.2 Robustness to changes in the underlying data distributions

In real-world applications, actual data processed by a network might follow a different distribution than that seen in training. Here we compare three different network formalisms on changing shapes of target data distributions.

We train three networks on the same $n_{\rm data} = 500$ datasets as before: a sum-aggregated Fishnets network, a mean-aggregated deepset, and a learned softmax-aggregated deepset with mean-square error loss. To demonstrate the improvement in aggregation Fishnets offers, we adopt smaller networks for the regression task (see Table 2 and Appendix C.2 for architecture details).

We apply our trained networks to test data of size $n_{\rm data} = 850$ with noise variances and $x$ values drawn from distributions different from those of the training data: $\sigma \curvearrowleft \textrm{Exp}(\lambda = 1.0)$ centred at $\sigma = 3.5$, truncated at $\sigma = 10.0$, and $x \curvearrowleft \mathcal{U}(0, 3)$. The noise and covariate distributions have the same support as the training data, but have different expectation values and distributions, which can pose a problem for the mean-aggregation used in the standard deepsets formalism. We display results in Figure LABEL:fig:stresstest. The heterogeneous Fishnets aggregation allows the network to correctly embed the noisy data drawn from the different distributions, while a significant loss in information can be seen for flat mean aggregation. The learned softmax aggregation narrows the residual distribution, but it remains significantly wider than the Fishnets solution. We quote numeric results in Table 2.

These robustness tests show that Fishnets successfully learns per-object embeddings (score) and weights (Fisher) within sets, while being robust to changing shapes of the training distributions of these quantities (3). This test also shows that even in a very simple prediction scenario, common and learned aggregators can suffer from robustness issues.
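To make the contrast between Fisher-weighted and flat aggregation concrete, the following minimal sketch replaces the learned networks with the analytic per-object score and Fisher matrix of a linear model $y_i = m x_i + c + \epsilon_i$, $\epsilon_i \sim \mathcal{N}(0, \sigma_i^2)$. The model form, the function names, and the noise range are assumptions made for illustration; this is the closed-form aggregation that a Fishnets network is meant to approximate, not the trained network itself.

```python
import numpy as np

def per_object_score_and_fisher(x, y, sigma, theta_fid):
    """Analytic score t_i and Fisher F_i for y_i = m*x_i + c + N(0, sigma_i^2)."""
    design = np.stack([x, np.ones_like(x)], axis=1)   # rows [x_i, 1]
    resid = y - design @ theta_fid                    # residual at the fiducial point
    t = design * (resid / sigma**2)[:, None]          # (n, 2) per-object scores
    F = design[:, :, None] * design[:, None, :] / sigma[:, None, None]**2  # (n, 2, 2)
    return t, F

def fisher_aggregate(x, y, sigma, theta_fid=(0.0, 0.0)):
    """Sum scores and Fishers over the set, then apply theta_fid + F^{-1} t."""
    theta_fid = np.asarray(theta_fid, dtype=float)
    t, F = per_object_score_and_fisher(x, y, sigma, theta_fid)
    return theta_fid + np.linalg.solve(F.sum(axis=0), t.sum(axis=0))

def flat_aggregate(x, y):
    """Unweighted least squares: every set member counts equally, sigma_i is ignored."""
    design = np.stack([x, np.ones_like(x)], axis=1)
    return np.linalg.lstsq(design, y, rcond=None)[0]

rng = np.random.default_rng(0)
m_true, c_true = 2.0, -1.0
x = rng.uniform(0.0, 3.0, size=850)
sigma = rng.uniform(1.0, 10.0, size=850)   # hypothetical heteroscedastic noise levels
y = m_true * x + c_true + sigma * rng.normal(size=850)

theta_fisher = fisher_aggregate(x, y, sigma)   # down-weights the noisiest points
theta_flat = flat_aggregate(x, y)              # treats all points alike
```

For this linear-Gaussian toy model the Fisher-weighted estimate coincides with weighted least squares and does not depend on the choice of fiducial point, which is why the summed score and Fisher remain valid when the set size or the noise distribution changes.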

| | network | # params | $\textrm{MSE}(\hat{m}, m_{\rm true})$ | $\textrm{MSE}(\hat{c}, c_{\rm true})$ |
|---|---|---|---|---|
| robustness test | fishnets | 10,855 | $\bm{0.007 \pm 0.017}$ | $\bm{0.046 \pm 0.078}$ |
| | deepset | 87,810 | $0.120 \pm 0.178$ | $0.285 \pm 0.406$ |
| | softmax | 87,811 | $0.042 \pm 0.069$ | $0.482 \pm 0.347$ |

Table 2: Summary of robustness testing for different set-based networks. Fishnets' Fisher aggregation has an advantage over mean- and learned softmax deepsets aggregation when test data follows a different distribution than the training suite, and does so with an eighth of the number of learnable parameters.

4.3 Scalable Inference With Censorship and Nuisance Parameters

As a non-trivial information saturation example we consider a censorship inference problem with latent parameters inspired by epidemiological studies. Consider a serum which, when injected into a patient, decays over time, and the (heterogeneous) decay rate among people is not well known. A population of patients is injected with the serum and then asked to come back to the lab within $t_{\rm max}=10$ days for a measurement of the remaining serum level in their blood, $s$. We can cast this problem as a Bayesian hierarchical model visualised in Figure LABEL:fig:plate_diagram (see Appendix C.3 for details), where the goal is to infer the mean $\mu$ and scale $\Theta$ of the decay rate Gamma distribution from the data, $\{\tau_i, s_i\}$. In the censored case, measurements are rejected if $s_i < s_{\rm min}$, and sampling continues until $n_{\rm data}$ valid samples are collected. As a ground-truth comparison for the uncensored version of this problem, we sample the above hierarchical model using Hamiltonian Monte Carlo (HMC). For comparison, we utilize the same Fishnets architecture and small-data training setup as before to predict $(\mu,\Theta)$ from data inputs $[\tau_i, s_i]^T$.
Once trained, we generate a new suite of $n_{\rm data}=500$ simulations and pass the data through the network to learn a neural posterior from $(\hat{\theta}_{\rm NN}, \bm{\theta})$ pairs. We then evaluate both the HMC and neural posteriors at the same target data. Finally, using the same network we perform the same procedure, this time with simulations of size $n_{\rm data}=10^{4}$, where HMC becomes computationally prohibitive.

Results. We display inference results in Figure LABEL:fig:uncensored-corner. The summaries obtained from Fishnets compression of the small data (green) result in posteriors that hug the "true" MCMC contours (black), indicating information saturation. Applying the same network to the larger data results in intuitively smaller contours (blue). It should be emphasized that $n_{\rm data}=10^{4}$ is a regime where MCMC inference is no longer tractable on standard devices. Fishnets here allows for 1) much faster posterior calculation and 2) optimal inference on larger dataset sizes without any retraining.

As a final demonstration we solve the same problem, this time subject to censorship. In the censored case, the target joint posterior defined by the hierarchical model requires computing an integral for the selection probability as a function of the model hyper-parameters; in general, these selection terms make Bayesian problems with censorship computationally challenging, or intractable in some cases (Qi et al., 2022; Dickey et al., 1987).

We train Fishnets on the small data size, subject to censorship below $s_{\rm min}$. We obtain posteriors of similar shape to the uncensored case, and as a consistency check perform a probability-integral transform (PIT) test for the neural posterior. For each parameter, the marginal PIT values should follow a uniform distribution if the learned posterior is well calibrated. We display these results in Figure 6. We obtain Kolmogorov-Smirnov test (Massey Jr., 1951) p-values of 0.628 and 0.233 for parameters $\mu$ and $\Theta$, respectively, so uniformity cannot be rejected, indicating that our posterior is well-parameterized and robust.
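The PIT consistency check can be sketched as follows. This is a minimal NumPy illustration under stated assumptions: the toy Gaussian prior and likelihood, the sample sizes, and the helper names `pit_values` and `ks_statistic_uniform` are hypothetical and stand in for the neural posterior used in the paper.

```python
import numpy as np

def pit_values(posterior_samples, true_params):
    """PIT value per simulation: fraction of posterior samples below the true value.

    posterior_samples: (n_sims, n_samples) marginal samples for one parameter.
    true_params:       (n_sims,) parameter values that generated each simulation.
    """
    posterior_samples = np.asarray(posterior_samples, dtype=float)
    true_params = np.asarray(true_params, dtype=float)
    return np.mean(posterior_samples <= true_params[:, None], axis=1)

def ks_statistic_uniform(u):
    """Two-sided Kolmogorov-Smirnov distance between the empirical CDF of u and U(0, 1)."""
    u = np.sort(np.asarray(u, dtype=float))
    n = u.size
    upper = np.arange(1, n + 1) / n          # empirical CDF just after each point
    d_plus = np.max(upper - u)
    d_minus = np.max(u - (upper - 1.0 / n))  # empirical CDF just before each point
    return max(d_plus, d_minus)

# Toy calibrated posterior: theta ~ N(0, 1), d | theta ~ N(theta, 1),
# so the exact posterior is N(d / 2, 1 / 2).
rng = np.random.default_rng(1)
n_sims, n_samples = 500, 1000
truth = rng.normal(size=n_sims)
obs = truth + rng.normal(size=n_sims)
samples = obs[:, None] / 2 + np.sqrt(0.5) * rng.normal(size=(n_sims, n_samples))

pit = pit_values(samples, truth)
D = ks_statistic_uniform(pit)   # small D is consistent with uniform PIT values
```

A p-value for the statistic can then be obtained from a standard routine, for example `scipy.stats.kstest(pit, "uniform")` if SciPy is available; an overconfident or biased posterior would show up as a PIT histogram that departs from uniform and a correspondingly small p-value.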

5 Graph Neural Network Aggregation

Graphs can be thought of as tuples of sets within connected neighborhoods. Graph neural networks (GNNs) operate by message-passing along edges between nodes. For predicting node- and graph-level properties, an aggregation of these sets of edges $\{\textbf{e}_{ij}\}$ or nodes $\{\textbf{v}_i\}$ is required to reduce features to fixed-size feature vectors. Whereas in the SBI setting we are interested in finding optimal estimators for specific parameters of interest, in the GNN aggregation setting we are implicitly trying to find a compact latent (embedding) representation of the aggregated neighborhood data, and optimally estimate and propagate those latent features through the GNN architecture.

Here we compare the Fishnets aggregation scheme as a drop-in replacement for learned softmax aggregators within the graph convolutional network (GCN) scheme presented by Li et al. (2020). We can rewrite our aggregations to occur within neighborhoods of nodes:

\begin{align}
\text{SoftmaxAgg}(\cdot) &= \sum_{i\in\mathcal{N}(v)} \frac{\exp\left(\beta \textbf{e}_{iv}\right)}{\sum_{l\in\mathcal{N}(v)} \exp\left(\beta \textbf{e}_{lv}\right)} \cdot \textbf{e}_{iv}, \tag{14}\\
\text{FishnetsAgg}(\cdot) &= \left(\sum_{i\in\mathcal{N}(v)} \textbf{F}(\textbf{e}_{iv})\right)^{-1} \left(\sum_{i\in\mathcal{N}(v)} \textbf{t}(\textbf{e}_{iv})\right), \tag{15}
\end{align}

where the aggregation occurs in a neighborhood $\mathcal{N}(v)$ of a node $v$. The Fishnets aggregation requires a bottleneck hyperparameter, $n_p$, which controls the size of the score embedding $\textbf{t}(\textbf{e}_{iv}) \in \mathbb{R}^{n_p}$ and Fisher Cholesky factors $\textbf{F}_{\rm chol} \in \mathbb{R}^{n_p(n_p+1)/2}$. We use a single linear layer before aggregation to obtain score and Fisher components from hidden layer embeddings.
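The two aggregators in Eqs. (14) and (15) can be sketched for a single node neighborhood in NumPy. This is an illustrative sketch rather than the authors' implementation: the function names are hypothetical, and the softplus applied to the Cholesky diagonal is an assumed way of keeping each per-edge Fisher matrix positive definite.

```python
import numpy as np

def softmax_agg(messages, beta=1.0):
    """Eq. (14): feature-wise softmax-weighted sum over neighbours; messages is (k, d)."""
    z = beta * messages
    z = z - z.max(axis=0, keepdims=True)       # subtract max for numerical stability
    w = np.exp(z)
    w = w / w.sum(axis=0, keepdims=True)
    return (w * messages).sum(axis=0)

def softplus(a):
    return np.logaddexp(0.0, a)

def chol_to_fisher(chol_vec, n_p):
    """Map (k, n_p(n_p+1)/2) Cholesky vectors to (k, n_p, n_p) positive-definite matrices."""
    k = chol_vec.shape[0]
    rows, cols = np.tril_indices(n_p)
    L = np.zeros((k, n_p, n_p))
    L[:, rows, cols] = chol_vec                        # fill the lower triangle
    diag = np.arange(n_p)
    L[:, diag, diag] = softplus(L[:, diag, diag])      # assumed positivity transform
    return L @ np.swapaxes(L, -1, -2)                  # F_i = L_i L_i^T

def fishnets_agg(scores, chol_vec):
    """Eq. (15): (sum_i F_i)^{-1} (sum_i t_i) over one neighbourhood."""
    n_p = scores.shape[1]
    F = chol_to_fisher(chol_vec, n_p)
    return np.linalg.solve(F.sum(axis=0), scores.sum(axis=0))

# Example neighbourhood with k = 5 incoming edges and bottleneck n_p = 3.
rng = np.random.default_rng(0)
k, n_p = 5, 3
scores = rng.normal(size=(k, n_p))                       # stand-in for t(e_iv)
chol_vec = rng.normal(size=(k, n_p * (n_p + 1) // 2))    # stand-in for F_chol(e_iv)
h_v = fishnets_agg(scores, chol_vec)                     # aggregated embedding, shape (3,)
```

When every neighbour reports the same Fisher matrix, the Fishnets aggregation reduces to a plain mean of the scores; the learned per-edge weighting only matters when neighbours carry different amounts of information.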

5.1 Drop-in replacement for Graph Benchmark Datasets

Here we replace the learned softmax aggregation with Fishnets aggregation in Li et al.'s publicly-available best-performing models. We change four hyperparameters in testing our new architectures: number of layers, $n_p$, dropout, and learning rate. We study several graph datasets from the Open Graph Benchmark (OGB) (Hu et al., 2020, 2021), which require substantial aggregation steps to predict either node or graph-level properties. The object of this study is to investigate how well Fishnets aggregation can perform within an existing architecture, with fewer layers and minimal hyperoptimisation.

Results. We display benchmark results in Table 1, and refer the reader to Appendix D for architecture and dataset details. This small drop-in study shows that by incorporating the more information-efficient Fishnets aggregation, we can achieve results better than or similar to SOTA GCNs with a fraction of the trainable parameters and training epochs.

5.2 Focus Study on ogbn-proteins Benchmark

In this section we study the proteins dataset in detail to highlight a scenario where the heterogeneous Fishnets aggregation drastically improves performance. Here we expect different node neighborhoods to have a heterogeneous edge weighting "association score" structure across protein categories, making the Fishnets aggregation ideal for applicability beyond the training set, as in the linear regression case. The association scores can be stochastically modelled with added measurement noise, increasing the difficulty of the classification problem. We adopt a stripped-down version of the training routine presented in Li et al. (2020) (no subgraph and edge preprocessing) so that we can modify the raw data by adding noise. We first benchmark our training routine with smaller GCN and Fishnets aggregation on the noise-free data, and then proceed to adding noise to the edges.

Noisefree Results. We display representative test ROC-AUC curves over training in Figure LABEL:fig:gnn_benchmark, and summarise results in Table 3. Fishnets-16 and Fishnets-20 clearly saturate information within 250 epochs, reaching test ROC-AUC values of $79.63\%$ and $81.10\%$, respectively.

5.2.1 Modelling Uncertain Protein Associations.

Here we incorporate uncertainties on the protein interaction strengths (edges), in order to demonstrate the robustness of the Fishnets approach to changes in the underlying data (noise) distribution on the graph features. We model noisy "measurements" of the protein graph edge associations using a simple Binomial model: taking the dataset edges $\textbf{p}_{ij} = \textbf{e}_{ij} \in [\textbf{0}, \textbf{1})$ as the "true" association strengths, we can simulate a noisy measurement of those quantities as $N$ weighted coin tosses per edge, where $N$ varies between measurements:

\begin{align}
N &\curvearrowleft \mathcal{U}(20, 200) \tag{16}\\
\textbf{n}_{\rm success} &\curvearrowleft \text{Binomial}\left(n=N, p=\textbf{p}_{ij}\right) \tag{17}\\
\textbf{e}_{ij} &\leftarrow \left[\hat{\textbf{p}}_{ij} = \textbf{n}_{\rm success}/N,\ N\right]. \tag{18}
\end{align}

Note that in the last step the new graph edge now contains the (noisy) measured associations, as well as $N$ (which provides a measure of uncertainty on those estimated interaction strengths). The GNN task is now to learn to re-weight predictions conditioned on the provided $N$ coin toss information, much like feeding in $\sigma$ in the linear regression case. We train a 28-layer GCN and 20-layer Fishnets. For the test dataset, we alter the distribution for $N$ to be $\mathcal{U}(20,50)+\mathcal{U}(170,200)$ such that we sample the extremes of the training distribution support.
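The noise model of Eqs. (16) to (18) can be sketched in NumPy as below. Several choices here are assumptions for illustration only: $N$ is drawn as an integer with inclusive bounds, the test-time distribution is read as an equal-weight mixture of the two uniform ranges, and the random `edges` array is a stand-in for the real edge features.

```python
import numpy as np

def add_binomial_edge_noise(edges, n_tosses, rng):
    """Replace true edge strengths p_ij with noisy estimates and append N.

    edges:    (E, d) array with entries in [0, 1), the "true" association strengths.
    n_tosses: (E,) integer array, number of coin tosses N for each edge.
    Returns an (E, d + 1) array [p_hat, N], as in Eq. (18).
    """
    n_success = rng.binomial(n_tosses[:, None], edges)   # Eq. (17), one draw per feature
    p_hat = n_success / n_tosses[:, None]
    return np.concatenate([p_hat, n_tosses[:, None].astype(float)], axis=1)

def sample_train_tosses(n_edges, rng):
    """Eq. (16): N ~ U(20, 200), taken here as integers with inclusive bounds."""
    return rng.integers(20, 200, size=n_edges, endpoint=True)

def sample_test_tosses(n_edges, rng):
    """Test-time N from the extremes of the training support (assumed 50/50 mixture)."""
    low = rng.integers(20, 50, size=n_edges, endpoint=True)
    high = rng.integers(170, 200, size=n_edges, endpoint=True)
    pick_high = rng.random(n_edges) < 0.5
    return np.where(pick_high, high, low)

rng = np.random.default_rng(0)
edges = rng.uniform(0.0, 1.0, size=(1000, 8))   # hypothetical stand-in edge features
noisy_train = add_binomial_edge_noise(edges, sample_train_tosses(1000, rng), rng)
noisy_test = add_binomial_edge_noise(edges, sample_test_tosses(1000, rng), rng)
```

Appending $N$ as an extra edge feature is what lets the aggregation condition on measurement quality: edges with few tosses have a binomial variance $p(1-p)/N$ that is several times larger than edges with many tosses.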

| test | network | # params | test ROC-AUC |
|---|---|---|---|
| noisefree | fishnets-20 | 442,372 | $\bm{0.8110 \pm 0.0021}$ |
| | GCN-28 | 477,964 | $0.7951 \pm 0.0059$ |
| noisy edges | fishnets-20 | 442,500 | $\bm{0.7198 \pm 0.0109}$ |
| | GCN-28 | 478,092 | $0.6471 \pm 0.0090$ |

Table 3: Summary of performance on benchmark and noisy edge variants of the proteins dataset. Errorbars denote standard deviation of test ROC-AUC in the last ten epochs of training.

Noisy Results. We display test ROC-AUC curves for both networks in Figure LABEL:fig:gnn_noise, subject to a patience setting of 250 epochs on the validation set. The GCN framework exhibits an early plateau at a test ROC-AUC of $64.71\%$, while Fishnets saturates at $71.98\%$. This stark difference in behaviour can be explained by the difference in formalism: the Fishnets aggregation explicitly learns a weighting scheme as a function of measured edge probabilities and the conditional information $N$, much like the linear regression case where $\sigma$ was passed as an input. This scheme helps the network learn how to deal with edge-case noise artefacts like the noisy edge test case. Explicitly specifying the inverse-Fisher weighting formalism as an inductive bias (Battaglia et al., 2018) during aggregation can help explain the fast information saturation exhibited in both graph test settings.

6 Discussion & Future Work

In this paper we built up an information-theoretic approach to optimal aggregation in the form of Fishnets. Through progressively non-trivial examples, we demonstrated that explicitly parameterizing the score and inverse-Fisher weights of set members results in an aggregation scheme that saturates Bayesian information in non-trivial problems, and also serves as an optimal aggregator for sets and graph neural networks in heterogeneous data scenarios.

The stark improvement in information saturation on the proteins test dataset relative to architecture size and training efficiency indicates that the Fishnets aggregation acts as an information-level inductive bias for GNN aggregation. Follow-up study is warranted on optimizing hyperparameter choices for graph neural network architectures using Fishnets. We chose to demonstrate improved information capture by using an ablation study of smaller models, but careful (and potentially bigger) network design would almost certainly improve results here and potentially achieve SOTA accuracy on common benchmarks.

References

  • Alon & Yahav (2021) Alon, U. and Yahav, E. On the bottleneck of graph neural networks and its practical implications, 2021.
  • Alsing & Wandelt (2018) Alsing, J. and Wandelt, B. Generalized massive optimal data compression. Monthly Notices of the Royal Astronomical Society: Letters, 476(1):L60–L64, feb 2018. doi: 10.1093/mnrasl/sly029. URL https://doi.org/10.1093%2Fmnrasl%2Fsly029.
  • Alsing et al. (2019) Alsing, J., Charnock, T., Feeney, S., and Wandelt, B. Fast likelihood-free cosmology with neural density estimators and active learning. Monthly Notices of the Royal Astronomical Society, Jul 2019. ISSN 1365-2966. doi: 10.1093/mnras/stz1960. URL http://dx.doi.org/10.1093/mnras/stz1960.
  • Amari (2021) Amari, S.-i. Information geometry. Japanese Journal of Mathematics, 16(1):1–48, January 2021. ISSN 1861-3624. doi: 10.1007/s11537-020-1920-5. URL https://doi.org/10.1007/s11537-020-1920-5.
  • Battaglia et al. (2018) Battaglia, P. W., Hamrick, J. B., Bapst, V., Sanchez-Gonzalez, A., Zambaldi, V., Malinowski, M., Tacchetti, A., Raposo, D., Santoro, A., Faulkner, R., Gulcehre, C., Song, F., Ballard, A., Gilmer, J., Dahl, G., Vaswani, A., Allen, K., Nash, C., Langston, V., Dyer, C., Heess, N., Wierstra, D., Kohli, P., Botvinick, M., Vinyals, O., Li, Y., and Pascanu, R. Relational inductive biases, deep learning, and graph networks, 2018.
  • Charnock et al. (2018) Charnock, T., Lavaux, G., and Wandelt, B. D. Automatic physical inference with information maximizing neural networks. Physical Review D, 97(8), apr 2018. doi: 10.1103/physrevd.97.083004. URL https://doi.org/10.1103%2Fphysrevd.97.083004.
  • Clevert et al. (2015) Clevert, D.-A., Unterthiner, T., and Hochreiter, S. Fast and accurate deep network learning by exponential linear units (elus), 2015. URL https://arxiv.longhoe.net/abs/1511.07289.
  • Corso et al. (2020a) Corso, G., Cavalleri, L., Beaini, D., Liò, P., and Velickovic, P. Principal neighbourhood aggregation for graph nets. CoRR, abs/2004.05718, 2020a. URL https://arxiv.longhoe.net/abs/2004.05718.
  • Corso et al. (2020b) Corso, G., Cavalleri, L., Beaini, D., Liò, P., and Veličković, P. Principal neighbourhood aggregation for graph nets, 2020b.
  • Coulton & Wandelt (2023) Coulton, W. R. and Wandelt, B. D. How to estimate fisher information matrices from simulations, 2023.
  • Cramér (1946) Cramér, H. Mathematical Methods of Statistics. The University Press, 1946.
  • Cranmer et al. (2020) Cranmer, K., Brehmer, J., and Louppe, G. The frontier of simulation-based inference. Proceedings of the National Academy of Sciences, 117(48):30055–30062, may 2020. doi: 10.1073/pnas.1912789117. URL https://doi.org/10.1073%2Fpnas.1912789117.
  • Dickey et al. (1987) Dickey, J. M., Jiang, J.-M., and Kadane, J. B. Bayesian methods for censored categorical data. Journal of the American Statistical Association, 82(399):773–781, 1987. ISSN 01621459. URL http://www.jstor.org/stable/2288786.
  • Fey & Lenssen (2019) Fey, M. and Lenssen, J. E. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.
  • Giovanni et al. (2024) Giovanni, F. D., Rusch, T. K., Bronstein, M. M., Deac, A., Lackenby, M., Mishra, S., and Veličković, P. How does over-squashing affect the power of gnns?, 2024.
  • Hamilton et al. (2017) Hamilton, W., Ying, Z., and Leskovec, J. Inductive representation learning on large graphs. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf.
  • Hoffmann & Onnela (2022) Hoffmann, T. and Onnela, J.-P. Minimizing the expected posterior entropy yields optimal summary statistics, 2022. URL https://arxiv.longhoe.net/abs/2206.02340.
  • Hu et al. (2020) Hu, W., Fey, M., Zitnik, M., Dong, Y., Ren, H., Liu, B., Catasta, M., and Leskovec, J. Open graph benchmark: Datasets for machine learning on graphs. arXiv preprint arXiv:2005.00687, 2020.
  • Hu et al. (2021) Hu, W., Fey, M., Ren, H., Nakata, M., Dong, Y., and Leskovec, J. Ogb-lsc: A large-scale challenge for machine learning on graphs. arXiv preprint arXiv:2103.09430, 2021.
  • Jeffrey & Wandelt (2020) Jeffrey, N. and Wandelt, B. D. Solving high-dimensional parameter inference: marginal posterior densities & moment networks, 2020. URL https://arxiv.longhoe.net/abs/2011.05991.
  • Kipf & Welling (2017) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks, 2017.
  • Lehmann & Casella (1998) Lehmann, E. L. and Casella, G. Theory of Point Estimation. Springer-Verlag, New York, NY, USA, second edition, 1998.
  • Li et al. (2020) Li, G., Xiong, C., Thabet, A., and Ghanem, B. Deepergcn: All you need to train deeper gcns, 2020.
  • Makinen et al. (2021) Makinen, T. L., Charnock, T., Alsing, J., and Wandelt, B. D. Lossless, scalable implicit likelihood inference for cosmological fields. Journal of Cosmology and Astroparticle Physics, 2021(11):049, nov 2021. doi: 10.1088/1475-7516/2021/11/049. URL https://doi.org/10.1088%2F1475-7516%2F2021%2F11%2F049.
  • Makinen et al. (2022) Makinen, T. L., Charnock, T., Lemos, P., Porqueres, N., Heavens, A. F., and Wandelt, B. D. The cosmic graph: Optimal information extraction from large-scale structure using catalogues. The Open Journal of Astrophysics, 5(1), dec 2022. doi: 10.21105/astro.2207.05202. URL https://doi.org/10.21105%2Fastro.2207.05202.
  • Massey Jr. (1951) Massey Jr., F. J. The kolmogorov-smirnov test for goodness of fit. Journal of the American Statistical Association, 46(253):68–78, 1951. doi: 10.1080/01621459.1951.10500769. URL https://www.tandfonline.com/doi/abs/10.1080/01621459.1951.10500769.
  • Miller et al. (2021) Miller, B. K., Cole, A., Forré, P., Louppe, G., and Weniger, C. Truncated marginal neural ratio estimation. In Ranzato, M., Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, volume 34, pp.  129–143. Curran Associates, Inc., 2021. URL https://proceedings.neurips.cc/paper_files/paper/2021/file/01632f7b7a127233fa1188bd6c2e42e1-Paper.pdf.
  • Papamakarios et al. (2019) Papamakarios, G., Sterratt, D., and Murray, I. Sequential neural likelihood: Fast likelihood-free inference with autoregressive flows. In Chaudhuri, K. and Sugiyama, M. (eds.), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics, volume 89 of Proceedings of Machine Learning Research, pp.  837–848. PMLR, 16–18 Apr 2019. URL https://proceedings.mlr.press/v89/papamakarios19a.html.
  • Phan et al. (2019) Phan, D., Pradhan, N., and Jankowiak, M. Composable effects for flexible and accelerated probabilistic programming in numpyro. arXiv preprint arXiv:1912.11554, 2019.
  • Qi et al. (2022) Qi, X., Zhou, S., and Plummer, M. On Bayesian modeling of censored data in JAGS. BMC Bioinformatics, 23(1):102, March 2022. ISSN 1471-2105. doi: 10.1186/s12859-021-04496-8. URL https://doi.org/10.1186/s12859-021-04496-8.
  • Ramachandran et al. (2017) Ramachandran, P., Zoph, B., and Le, Q. V. Searching for activation functions, 2017.
  • Soelch et al. (2019) Soelch, M., Akhundov, A., van der Smagt, P., and Bayer, J. On Deep Set Learning and the Choice of Aggregations, pp. 444–457. Springer International Publishing, 2019. ISBN 9783030304874. doi: 10.1007/978-3-030-30487-4_35. URL http://dx.doi.org/10.1007/978-3-030-30487-4_35.
  • Vaart (1998) Vaart, A. W. v. d. Asymptotic Statistics. Cambridge Series in Statistical and Probabilistic Mathematics. Cambridge University Press, 1998. doi: 10.1017/CBO9780511802256.
  • Veličković et al. (2018) Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., and Bengio, Y. Graph attention networks, 2018.
  • Wagstaff et al. (2019) Wagstaff, E., Fuchs, F. B., Engelcke, M., Posner, I., and Osborne, M. On the limitations of representing functions on sets, 2019.
  • Xu et al. (2019) Xu, K., Hu, W., Leskovec, J., and Jegelka, S. How powerful are graph neural networks?, 2019.
  • Zaheer et al. (2018) Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R., and Smola, A. Deep sets, 2018.
  • Zhou et al. (2020) Zhou, J., Cui, G., Hu, S., Zhang, Z., Yang, C., Liu, Z., Wang, L., Li, C., and Sun, M. Graph neural networks: A review of methods and applications. AI Open, 1:57–81, 2020. ISSN 2666-6510. doi: 10.1016/j.aiopen.2021.01.001. URL https://www.sciencedirect.com/science/article/pii/S2666651021000012.

Appendix A Saturating the Information Inequality over Parameter space

Here we show that knowing the score function $\textbf{t} = \nabla\mathcal{L}$ saturates the information inequality and provides a natural data compression function. We first consider the second-order Taylor expansion of the log-likelihood around a fixed fiducial point in parameter space, $\bm{\theta}_*$ (where $g_* \equiv g(\bm{\theta}=\bm{\theta}_*)$):

$$\mathcal{L} = \mathcal{L}_* + \delta\bm{\theta}^T \nabla\mathcal{L}_* - \frac{1}{2}\delta\bm{\theta}^T \textbf{J}_* \delta\bm{\theta} \tag{19}$$

where $\textbf{J} = -\nabla\nabla^T\mathcal{L}$ is the observed information matrix. To linear order in $\bm{\theta}$, the data $\textbf{d}$ couples to the parameters through the score function $\textbf{t} \in \mathbb{R}^{n_p}$. We can show that $\textbf{t}$ saturates the information inequality via

$$\mathrm{Cov}_{\bm{\theta}}\left[\textbf{t}, \textbf{t}\right] = \mathbb{E}_{\bm{\theta}}\left[\nabla\mathcal{L}_* \nabla^T\mathcal{L}_*\right] = \textbf{F}_*, \tag{20}$$

where we have used the fact that $\mathbb{E}_{\bm{\theta}}\left[\nabla\mathcal{L}_*\right] = 0$. From this we observe that the covariance of the score function is the Fisher matrix. Using the fact that

$$\textbf{A} = \nabla\mathbb{E}_{\bm{\theta}}\left[\nabla^T\mathcal{L}\right] = \mathbb{E}_{\bm{\theta}}\left[\nabla\nabla^T\mathcal{L}\right] = -\textbf{F}_*, \tag{21}$$

the right-hand side of the information inequality becomes $\textbf{A}_*^T \textbf{F}_*^{-1} \textbf{A}_* = \textbf{F}_*$, which shows that the score statistics $\textbf{t}$ saturate the information inequality. Within this formalism, no statistics can provide more (Fisher) information about the parameters $\bm{\theta}$.
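The identity $\mathrm{Cov}_{\bm{\theta}}[\textbf{t},\textbf{t}] = \textbf{F}_*$ can be checked numerically on a toy problem. The sketch below assumes a linear-Gaussian model $y_i = m x_i + c + \epsilon_i$ with known heteroscedastic noise, for which the score and Fisher matrix are available in closed form; the model, sample sizes, and parameter values are illustrative choices and not taken from the experiments.

```python
import numpy as np

rng = np.random.default_rng(0)
n_data, n_sims = 50, 20000
theta_star = np.array([1.0, 0.5])                  # hypothetical fiducial (m, c)
x = rng.uniform(0.0, 10.0, size=n_data)
sigma = rng.uniform(1.0, 3.0, size=n_data)
design = np.stack([x, np.ones_like(x)], axis=1)    # rows [x_i, 1]

# Analytic Fisher matrix: F = sum_i [x_i, 1]^T [x_i, 1] / sigma_i^2.
fisher = (design / sigma[:, None] ** 2).T @ design

# Simulate datasets at theta_star and evaluate the score at the same point.
noise = sigma * rng.normal(size=(n_sims, n_data))
y = design @ theta_star + noise                    # (n_sims, n_data) simulated data
resid = y - design @ theta_star                    # equals the noise realisation
scores = (resid / sigma ** 2) @ design             # (n_sims, 2), one score per dataset

score_mean = scores.mean(axis=0)                   # should be close to zero
score_cov = np.cov(scores, rowvar=False)           # should be close to `fisher`
```

With $2 \times 10^4$ simulated datasets the sample covariance of the score typically agrees with the analytic Fisher matrix to within about a percent, and the sample mean of the score is consistent with zero.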

We can relate this information saturation to an optimal, quasi maximum-likelihood estimator whose covariance is equal to the inverse Fisher information from the above derivation. Maximising the Taylor expansion (19) with respect to the parameters yields

$$\hat{\bm{\theta}} = \bm{\theta}_* + \textbf{J}_*^{-1} \nabla\mathcal{L}_* \tag{22}$$

where both the score $\textbf{t}_* = \nabla\mathcal{L}_*$ and the inverse observed information $\textbf{J}^{-1}_*$ depend on the observed data. In practice, we can exchange $\textbf{J}$ with its expectation value, the Fisher information $\textbf{F}_* \equiv \mathbb{E}_{\bm{\theta}}\left[\textbf{J}_*\right]$, which yields

$$\hat{\bm{\theta}} = \bm{\theta}_* + \textbf{F}_*^{-1} \nabla\mathcal{L}_*. \tag{23}$$

Making this replacement means the MLE estimator only depends on the data through the score function statistics t=tsubscript\textbf{t}=\nabla\mathcal{L}_{*}t = ∇ caligraphic_L start_POSTSUBSCRIPT ∗ end_POSTSUBSCRIPT. The covariance of the MLE estimator (at the expansion point 𝜽subscript𝜽\bm{\theta}_{*}bold_italic_θ start_POSTSUBSCRIPT ∗ end_POSTSUBSCRIPT) is then:

$$\mathrm{Cov}_{\bm{\theta}_{*}}\left[\hat{\bm{\theta}},\hat{\bm{\theta}}\right]=\textbf{F}_{*}^{-1}\,\mathbb{E}_{\bm{\theta}_{*}}\left[\nabla\mathcal{L}_{*}\nabla^{T}\mathcal{L}_{*}\right]\textbf{F}_{*}^{-1}=\textbf{F}_{*}^{-1}, \tag{24}$$

where $\mathbb{E}_{\bm{\theta}_{*}}\left[\nabla\mathcal{L}_{*}\nabla^{T}\mathcal{L}_{*}\right]\equiv\textbf{F}_{*}$. Hence the covariance of the MLE is equal to the inverse of the Fisher information matrix at $\bm{\theta}_{*}$, and the Cramér-Rao bound is saturated.

The above proof establishes information saturation around a fixed fiducial point. In general, however, by parameterising the score and Fisher functions with neural networks under the loss (12), we are learning an embedding $\textbf{t}(\textbf{d}_{i};w_{t})$ and a weighting neighborhood $\textbf{F}(\textbf{d}_{i};w_{F})$ as a function of data (and implicitly parameters).
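The saturation argument can be checked numerically in a case where the Taylor expansion is exact. The sketch below is illustrative only: it assumes a hypothetical two-parameter linear-Gaussian model (the names `X`, `theta_star`, `n_sims` and all numerical values are choices made here, not taken from the paper), simulates many datasets at the expansion point, forms the estimator of Eq. (23) for each, and compares the empirical covariance with $\textbf{F}_{*}^{-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear-Gaussian model: d_i = m * x_i + b + noise, known noise sigma.
n_data, n_sims = 50, 20_000
x = np.linspace(0.0, 10.0, n_data)
sigma = 2.0
X = np.stack([x, np.ones(n_data)], axis=1)      # design matrix, shape (n_data, 2)
theta_star = np.array([1.0, -0.5])              # expansion (fiducial) point

# Fisher information at theta_star (exact for this model).
F = X.T @ X / sigma**2
F_inv = np.linalg.inv(F)

# Simulate many datasets at theta_star and form the score of each one.
d = X @ theta_star + sigma * rng.standard_normal((n_sims, n_data))
score = (d - X @ theta_star) @ X / sigma**2     # shape (n_sims, 2)

# Quasi-MLE of Eq. (23): theta_hat = theta_star + F^{-1} score.
theta_hat = theta_star + score @ F_inv          # valid because F_inv is symmetric

emp_cov = np.cov(theta_hat, rowvar=False)
print("empirical covariance:\n", emp_cov)
print("inverse Fisher:\n", F_inv)
```

Up to Monte Carlo noise, the two printed matrices should agree, which is the statement of Eq. (24) for this toy model.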

Appendix B Calculating the Fisher Matrix from Network Outputs

To ensure that our Fisher matrix is positive-definite, our Fisher-score networks output $n_{\rm params}+\frac{n_{\rm params}(n_{\rm params}+1)}{2}$ numbers: $n_{\rm params}$ for the score and $\frac{n_{\rm params}(n_{\rm params}+1)}{2}$ as the lower triangular entries of a Cholesky factor of the Fisher matrix, $\textbf{L}$. To ensure that the diagonal of $\textbf{L}$ is strictly positive, and hence that the resulting Fisher matrix is positive-definite, we apply a softplus activation to the diagonal entries of $\textbf{L}$:

$$\mathrm{diag}(\textbf{L})\leftarrow\mathrm{softplus}(\mathrm{diag}(\textbf{L})) \tag{25}$$

We then compute the Fisher via:

$$\textbf{F}=\textbf{L}\textbf{L}^{T} \tag{26}$$
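The construction in Eqs. (25) and (26) can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function name `fisher_from_outputs` and the row-by-row ordering used to fill the lower triangle are assumptions made here.

```python
import numpy as np

def softplus(z):
    # Numerically stable softplus: log(1 + exp(z)).
    return np.logaddexp(0.0, z)

def fisher_from_outputs(raw, n_params):
    """Map n_params*(n_params+1)/2 unconstrained numbers to a positive-definite F."""
    n_tril = n_params * (n_params + 1) // 2
    raw = np.asarray(raw, dtype=float)
    if raw.shape != (n_tril,):
        raise ValueError(f"expected {n_tril} entries, got shape {raw.shape}")
    L = np.zeros((n_params, n_params))
    L[np.tril_indices(n_params)] = raw          # fill the lower triangle row by row (assumed order)
    diag = np.diag_indices(n_params)
    L[diag] = softplus(L[diag])                 # Eq. (25): strictly positive diagonal
    return L @ L.T                              # Eq. (26)

raw = np.array([-3.0, 0.7, 0.2, -1.5, 4.0, -0.1])   # arbitrary example outputs for n_params = 3
F = fisher_from_outputs(raw, n_params=3)
print(np.linalg.eigvalsh(F))                    # eigenvalues of the resulting Fisher matrix
```

Because the softplus keeps every diagonal entry of $\textbf{L}$ strictly positive, $\textbf{L}$ is non-singular and $\textbf{L}\textbf{L}^{T}$ is positive-definite for any real network output.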

The negative-log likelihood loss in Equation 12 allows for explicit interrogation of the resulting Fisher matrix at the level of the predicted quantities (parameters), and ensures that the summary space in $\hat{\bm{\theta}}$ is convex. In the GNN regression formalism, Fishnets does not explicitly maximise the Fisher information as a part of the loss; rather, the Fisher matrix weights are optimized as an inductive bias as a hidden layer in the GNN scheme.

Appendix C Bayesian Information Experiment Details

C.1 Scalable Linear Regression

We use a toy linear regression model to validate our method and demonstrate network scalability. We consider the form $y=mx+b+\epsilon$, where $\epsilon\sim\mathcal{N}(0,\sigma)$ and the parameters of interest are the slope and intercept $\bm{\theta}=(m,b)$. This likelihood has an analytically-calculable score and Fisher matrix,

$$\textbf{t}=\sum_{i=1}^{n_{\rm data}}\frac{1}{\sigma_{i}^{2}}\begin{bmatrix}x_{i}\left(y_{i}-(m_{\rm fid}x_{i}+b_{\rm fid})\right)\\ y_{i}-(m_{\rm fid}x_{i}+b_{\rm fid})\end{bmatrix}+\textbf{t}_{0}, \tag{27}$$
$$\textbf{F}=\sum_{i=1}^{n_{\rm data}}\frac{1}{\sigma_{i}^{2}}\begin{bmatrix}x_{i}^{2}&x_{i}\\ x_{i}&1\end{bmatrix}+\mathbf{C}^{-1}_{\rm p}, \tag{28}$$

where $\textbf{t}_{0}=\mathbf{C}^{-1}_{\rm p}(\theta_{\rm fid}-\mu_{\rm p})$, with $\mu_{\rm p}=\textbf{0}$ the mean of the prior on the score, and $\mathbf{C}^{-1}_{\rm p}=\textbf{I}$ is added to the Fisher matrix as a prior on the inverse-covariance of the spread of the summaries. With these two expressions we can calculate exact maximum-likelihood estimates for the parameters $\bm{\theta}=(m,b)$ via Eq. (6). We choose wide Gaussian priors for $\theta$, and uniform priors for $x\in[0,10]$ and $\sigma\in[1,10]$. For network training, we simulate $10^{4}$ datasets of size $n_{\rm data}=500$ datapoints. For testing, we generate an additional $10^{4}$ datasets of size $n_{\rm data}=10^{4}$ datapoints to demonstrate scalability. We use fully-connected MLPs of size [256, 256, 256] with ELU activations (Clevert et al., 2015) for both score and Fisher networks. 
Both networks receive the input data $[y_{i},x_{i},\sigma_{i}^{2}]^{T}$. We train networks for 2500 epochs with an Adam optimizer using a step learning rate decay schedule. We train an ensemble of 10 networks in parallel on the same training data with different initializations.
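For reference, the analytic score and Fisher matrix of Eqs. (27) and (28) can be evaluated directly on one simulated dataset. The following is a sketch under stated assumptions: the true parameter values are hypothetical, the fiducial point is taken to be $\theta_{\rm fid}=\textbf{0}$ (so that $\textbf{t}_{0}$ vanishes with $\mu_{\rm p}=\textbf{0}$), and the estimate is formed with the update $\hat{\bm{\theta}}=\bm{\theta}_{\rm fid}+\textbf{F}^{-1}\textbf{t}$ in the form of Eq. (23).

```python
import numpy as np

rng = np.random.default_rng(1)

# Simulate one dataset from y = m x + b + eps with heteroscedastic noise.
n_data = 500
m_true, b_true = 2.0, 1.0                       # hypothetical true parameters
x = rng.uniform(0.0, 10.0, n_data)
sigma = rng.uniform(1.0, 10.0, n_data)
y = m_true * x + b_true + sigma * rng.standard_normal(n_data)

theta_fid = np.zeros(2)                         # fiducial point (assumed zero here)
C_p_inv = np.eye(2)                             # prior inverse covariance, as in the text
mu_p = np.zeros(2)                              # prior mean
t0 = C_p_inv @ (theta_fid - mu_p)               # vanishes for theta_fid = mu_p = 0

resid = y - (theta_fid[0] * x + theta_fid[1])
w = 1.0 / sigma**2

# Eq. (27): score; Eq. (28): Fisher matrix.
t = np.array([np.sum(w * x * resid), np.sum(w * resid)]) + t0
F = np.array([[np.sum(w * x**2), np.sum(w * x)],
              [np.sum(w * x),    np.sum(w)]]) + C_p_inv

theta_hat = theta_fid + np.linalg.solve(F, t)
print("estimate:", theta_hat)
print("1-sigma errors:", np.sqrt(np.diag(np.linalg.inv(F))))
```

Since the model is linear in $(m,b)$, this single step from the fiducial point is exact, and the sums over datapoints show why the same expressions apply unchanged to datasets of any size.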

C.2 Robustness Network Architecture Comparison

We train three networks on the same $n_{\rm data}=500$ datasets as before: a sum-aggregated Fishnets network, a mean-aggregated deepset, and a learned softmax-aggregated deepset (no Fisher output and standard MSE loss against true parameters, $\mathrm{MSE}(\hat{\bm{\theta}},\bm{\theta})$). Here we initialise Fishnets with [50,50,50] hidden units for score and Fisher networks, and two embeddings of [128,128,128] hidden units for both deepset networks, all with swish (Ramachandran et al., 2017) nonlinearities for the data embedding (see Table 2). All networks are initialised with the same seed.

C.3 Gamma Population Model

Consider a serum which increases patients’ red blood cell counts, whose decay rate, $\tau$, is not known. A population of patients is injected with the serum and then asked to come back to the lab within $t_{\rm max}=10$ days for a measurement of their blood cell count, $s$. We can cast this problem using the following hierarchical model

$$\begin{aligned}
\mu &\curvearrowleft \mathcal{U}(0.5,10)\\
\Theta &\curvearrowleft \mathcal{U}(0.1,1.5)\\
\gamma_{i} &\curvearrowleft \textrm{Gamma}(\alpha=\mu/\Theta,\ \beta=1/\Theta)\\
\tau_{i} &\curvearrowleft \mathcal{U}(0,10)\\
\lambda_{i} &= A\exp(-\tau_{i}/\gamma_{i})\\
s_{i} &\curvearrowleft \textrm{Pois}(\lambda_{i}),
\end{aligned}$$

where the goal is to infer the mean $\mu$ and scale $\Theta$ of the decay rate Gamma distribution from the data, $\{\tau_{i},s_{i}\}$. In the censored case, measurements are rejected if $s_{i}<s_{\rm min}$, and collected until $n_{\rm data}$ samples are accepted. The model is visualised in a plate diagram in Figure LABEL:fig:plate_diagram. In the uncensored case, the posterior estimation for this problem is readily solved using a high-dimensional Hamiltonian Monte-Carlo (HMC) sampler. We implement this model in NumPyro (Phan et al., 2019) as a baseline MCMC comparison for our algorithm. For the Fishnets implementation, we generate $10^{4}$ simulations of size $n_{\rm data}=500$ over a uniform prior for $\mu$ and $\Theta$. We then train the same Fishnets architecture used for the linear regression case with data inputs $[\tau_{i},s_{i}]^{T}$. Once the networks are trained, we pass a suite of $n_{\rm data}=5000$ simulations through the network to generate neural summaries with which to train a density estimation network. 
Following Alsing et al. (2019), we use Mixture Density Networks to learn an amortized posterior for $p(\hat{\theta}_{\rm NN}|\theta)$ with three hidden layers of size $[50,50,50]$. We then evaluate this posterior at the same target data used for the HMC, shown in green in Figure LABEL:fig:uncensored-corner. The Fishnets compression results in slightly inflated contours, indicating a small leakage of information. To demonstrate scaling, we additionally generate another simulation at $n_{\rm data}=10^{4}$ using the same random seed. We train another amortised posterior using 5000 simulations at $n_{\rm data}=10^{4}$ and pass the data through the same trained Fishnets architecture. The resulting posterior is shown in blue for comparison.
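The hierarchical model and its censored variant can be simulated with a short rejection loop. The sketch below is only an illustration of the generative process as written above: the amplitude `A`, the threshold `s_min`, and the function name `simulate_population` are hypothetical choices, since the text does not specify their values. Note that NumPy parameterises the Gamma distribution by shape and scale, so a rate $\beta=1/\Theta$ corresponds to `scale=Theta`.

```python
import numpy as np

def simulate_population(mu, Theta, n_data, rng, A=100.0, t_max=10.0, s_min=None):
    """Draw n_data pairs (tau_i, s_i); if s_min is set, reject counts below it."""
    taus, counts = [], []
    n_accepted = 0
    while n_accepted < n_data:
        # Gamma(alpha = mu/Theta, rate beta = 1/Theta)  <=>  shape mu/Theta, scale Theta.
        gamma = rng.gamma(shape=mu / Theta, scale=Theta, size=n_data)
        gamma = np.maximum(gamma, 1e-12)        # guard against division by zero
        tau = rng.uniform(0.0, t_max, size=n_data)
        s = rng.poisson(A * np.exp(-tau / gamma))
        keep = np.ones(n_data, dtype=bool) if s_min is None else (s >= s_min)
        taus.append(tau[keep])
        counts.append(s[keep])
        n_accepted += int(keep.sum())
    tau = np.concatenate(taus)[:n_data]
    s = np.concatenate(counts)[:n_data]
    return np.stack([tau, s], axis=1)           # shape (n_data, 2), rows [tau_i, s_i]

rng = np.random.default_rng(2)
mu = rng.uniform(0.5, 10.0)                     # population-level draws from the priors
Theta = rng.uniform(0.1, 1.5)
uncensored = simulate_population(mu, Theta, 500, rng)
censored = simulate_population(mu, Theta, 500, rng, s_min=5)
print(uncensored.shape, censored.shape)
```

The censored branch keeps drawing patients until $n_{\rm data}$ measurements pass the threshold, which matches the description of collecting samples until enough are accepted.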

Figure 6: Density estimation posteriors obtained from parameter-Fishnets summary pairs are robust over training data. Each parameter’s PIT test is close to uniform, which shows that the Fishnets summary posterior has successfully captured the underlying Bayesian information from the data.

Appendix D Graph Prediction Benchmark Experiment Details

All GNN models are implemented in PyTorch Geometric (Fey & Lenssen, 2019), and all experiments are performed on a single NVIDIA V100 32GB with the same random seed for initialisation and training. We first describe graph- and node-level prediction tasks, followed by the experimental details on the three benchmark datasets.

Node Property Prediction. This task consists of aggregating edge and node information within neighborhoods to predict properties at the node level. The ogbn-arxiv dataset is a directed citation graph of 169,343 papers summarised as 128-dimensional vectors (nodes) and 1,166,243 citations (edges), where the direction indicates the citation direction. The task is to predict which of 40 classes each paper belongs to. The ogbn-proteins dataset consists of 132,534 proteins encoded as 8-dimensional one-hot features indicating protein species (nodes) and 39,561,252 undirected weighted edges indicating association scores between proteins. The task is a 112-class classification from aggregated subgraph edges and nodes, evaluated using a ROC-AUC metric.

Graph Property Prediction. This task requires the aggregation of edges and nodes to global features of a graph. We consider the ogbg-molhiv dataset, which comprises 41,127 molecules with atoms arranged as nodes and bonds as edges. The prediction task is binary classification.

ogb-molhiv. This dataset does not provide a node feature for each protein. We initialize the node features via a sum aggregation, e.g. $x_{i}=\sum_{j\in\mathcal{N}}e_{ij}$, where $x_{i}$ denotes the initialized node features and $e_{ij}$ denotes the input edge features. We train a 7-layer DyResGEN model with a softmax aggregator with learnable $\beta$ parameter. A batch normalization is used for each layer. We set the hidden channel size to 256. A dropout with a rate of 0.5 is used for each layer. An Adam optimizer with a learning rate of 0.0001 is used to train the model for 150 epochs. For the Fishnets comparison we train a 3-layer network with Fishnets Aggregation with bottleneck $n_{p}=8$ in place of the softmax aggregation, and a learning rate of 0.00002.

ogb-arxiv. We train the 28-layer ResGEN model of Li et al. (2020) with softmax aggregation where $\beta$ is fixed at 0.1. Full batch training and testing are applied. A batch normalization is used for each layer. The hidden channel size is 128. We apply a dropout with a rate of 0.5 for each layer. An Adam optimizer with a learning rate of 0.01 is used to train the model for 500 epochs. For the Fishnets comparison we train a 3-layer version of the same ResGEN network with Fishnets Aggregation with bottleneck $n_{p}=9$ in place of the softmax aggregation.

ogb-proteins. This dataset does not provide a node feature for each protein. We initialize the node features via a sum aggregation, e.g. $x_{i}=\sum_{j\in\mathcal{N}}e_{ij}$, where $x_{i}$ denotes the initialized node features and $e_{ij}$ denotes the input edge features. We train the 112-layer DyResGEN of Li et al. (2020) with a softmax aggregator. A hidden channel size of 64 is used. A layer normalization and a dropout with a rate of 0.1 are used for each layer. We train the model for 1000 epochs with an Adam optimizer with a learning rate of 0.01. For the Fishnets comparison we train an 8-layer and a 16-layer version of the same DyResGEN network with Fishnets Aggregation with bottleneck $n_{p}=8$ in place of the softmax aggregation. Here we temper the learning rate to 0.005 and set the dropout rate to 0.25.
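The sum-aggregation initialisation $x_{i}=\sum_{j\in\mathcal{N}}e_{ij}$ amounts to a scatter-add of edge features onto their source nodes. The sketch below shows this with NumPy on a toy graph; the function name `init_node_features` is hypothetical, and the `(2, n_edges)` layout of `edge_index` (with undirected edges stored once per direction) is an assumption modelled on the PyTorch Geometric convention rather than something stated in the text.

```python
import numpy as np

def init_node_features(edge_index, edge_attr, n_nodes):
    """x_i = sum of e_ij over all edge entries (i, j) listed in edge_index."""
    x = np.zeros((n_nodes, edge_attr.shape[1]))
    np.add.at(x, edge_index[0], edge_attr)      # unbuffered scatter-add, safe for repeated indices
    return x

# Toy graph: 3 nodes, directed edge entries (0,1), (0,2), (1,2), (2,0).
edge_index = np.array([[0, 0, 1, 2],
                       [1, 2, 2, 0]])
edge_attr = np.array([[1.0, 0.0],
                      [0.5, 0.5],
                      [0.0, 2.0],
                      [3.0, 1.0]])
x = init_node_features(edge_index, edge_attr, n_nodes=3)
print(x)
```

`np.add.at` is used instead of fancy-index assignment because a node usually appears in several edges, and plain `x[idx] += vals` would silently drop repeated contributions.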

D.1 Noisy Proteins Focus Study

Here we again initialize the node features via a sum aggregation.

Figure 7: Zoomed-in test ROC-AUC training trajectories for models considered in benchmark ablation study on ogbn-proteins.

We test five model architectures using the vanilla ogbn-proteins dataset (no subgraph and edge preprocessing as performed by Li et al. (2020)). This change allowed us to flexibly incorporate the added edge feature in the noisy edge setting. To benchmark our training routine we adopt a 28-layer DyResGEN network with learned softmax aggregations and hidden size of 64, and a smaller version of this model with hidden size 14 and 28 layers. We construct two shallower Fishnets GNNs, with 16 and 20 layers, each with 64 hidden units, and one small model with 14 hidden units and 14 layers. For each graph convolution aggregation, we adopt a “score” bottleneck of $n_{p}=10$ for the large Fishnets models and $n_{p}=8$ for the small model. We train all networks with a cross-entropy loss over the same dataset and fixed random seed using an Adam optimizer with fixed learning rate $0.001$. We incorporate an early stopping criterion conditioned on the validation data, which dictates an end to training (saturation) when the validation ROC-AUC metric stops increasing for $\texttt{patience}=250$ epochs.
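The early stopping rule described above can be expressed as a small helper. This is a generic sketch of a patience criterion, not the authors' training loop; the function name `early_stopping_epoch` and the toy validation curve are illustrative, and the short `patience=3` in the usage line is chosen only to keep the example small (the experiments use 250).

```python
def early_stopping_epoch(val_scores, patience=250):
    """Return the epoch at which training stops, or None if it never triggers."""
    best, best_epoch = float("-inf"), -1
    for epoch, score in enumerate(val_scores):
        if score > best:
            # Validation ROC-AUC improved: remember it and reset the counter.
            best, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` consecutive epochs: stop here.
            return epoch
    return None

# Toy validation curve: improves for 5 epochs, then plateaus.
scores = [0.60, 0.65, 0.70, 0.72, 0.73] + [0.73] * 10
print(early_stopping_epoch(scores, patience=3))
```

A strict inequality is used for "improvement", so a flat validation curve counts towards the patience budget.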

In the noisy proteins setting we again control for stochasticity in training set loading and added edge noise by fixing the initial random seed before each training run.

| test | network | # params | test ROC-AUC |
|---|---|---|---|
| noisefree | fishnets-20 | 442,372 | $\bm{0.8110\pm 0.0021}$ |
| | fishnets-16 | 355,584 | $0.7963\pm 0.0059$ |
| | fishnets small | 30,360 | $0.7929\pm 0.0045$ |
| | GCN-28 | 477,964 | $0.7951\pm 0.0059$ |
| | GCN small | 33,580 | $0.7731\pm 0.0052$ |
| noisy edges | fishnets-20 | 442,500 | $\bm{0.7198\pm 0.0109}$ |
| | GCN-28 | 478,092 | $0.6471\pm 0.0090$ |
Table 4: Full summary of performance on benchmark and noisy variants of the proteins dataset. Error bars denote the standard deviation of the test ROC-AUC over the last ten epochs of training.

Appendix E Deepsets Formalism

Summary. The deepsets method presented by Zaheer et al. shows in Theorem 9 that any function over a countable set can be decomposed in the form $f(X)=\rho\left(\sum_{x\in X}\phi(x)\right)$. They then extend this to the universality of deepsets since $\rho$ and $\phi$ can be parameterized as neural networks, which can be universal function approximators. The deepsets formalism allows point-estimates for regression parameters to be obtained following an aggregation of features in a potentially variably-sized set of data. Incorporating our formalism, each set member $\textbf{d}_{i}$ is first passed to a neural network $f(\textbf{d}_{i};w_{1})$, and subsequently aggregated using some permutation-invariant scheme, $\bigoplus_{i}$:

$$\hat{\bm{\theta}}=g\left(\bigoplus_{i=1}^{n_{\rm data}}f(\textbf{d}_{i};w_{1});\ w_{2}\right), \tag{29}$$

where $f$ is the embedding network and $g$ is the “global” function that maps aggregated features to predicted parameters. When the aggregation is chosen to be the mean, the deepsets formalism is scalable to datasets of arbitrary size and becomes equivalent to the Fishnets aggregation formalism with flat weights across the aggregated data. The loss takes the form of a convex squared loss, e.g. the mean square error

$$\mathcal{L}=\frac{1}{n_{\rm batch}}\sum_{i}^{n_{\rm batch}}(\hat{\bm{\theta}}_{i}-\bm{\theta}_{i})^{2} \tag{30}$$

where $n_{\rm batch}$ is the number of full simulations in a batch, each of size $n_{\rm data}$.
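A mean-aggregated deepset in the form of Eq. (29), together with the loss of Eq. (30), can be sketched with two small untrained networks. Everything specific here is an assumption for illustration: the layer sizes, the tanh activations, the random weights, and the helper names `init_mlp`, `mlp`, `deepset` and `mse_loss`. The point is only to show the structure, namely a per-member embedding, a permutation-invariant pooling, and a global map.

```python
import numpy as np

rng = np.random.default_rng(3)

def mlp(params, h):
    # Small fully connected network with tanh hidden activations.
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b

def init_mlp(sizes):
    return [(rng.normal(0.0, 0.5, (n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

w1 = init_mlp([3, 16, 8])    # embedding network f(d_i; w1)
w2 = init_mlp([8, 16, 2])    # global network g(.; w2), two target parameters

def deepset(d):
    """d has shape (n_data, 3); returns theta_hat with shape (2,)."""
    pooled = mlp(w1, d).mean(axis=0)   # permutation-invariant mean aggregation
    return mlp(w2, pooled)

def mse_loss(batch, thetas):
    # Eq. (30): squared error averaged over the simulations in the batch.
    preds = np.stack([deepset(d) for d in batch])
    return np.mean(np.sum((preds - thetas) ** 2, axis=1))

d = rng.normal(size=(500, 3))
print(deepset(d), deepset(d[rng.permutation(500)]))   # agree up to round-off
big = rng.normal(size=(10_000, 3))
print(deepset(big).shape)                             # same network, larger set
```

Because the pooled vector has a fixed length regardless of $n_{\rm data}$, the same weights can be applied to sets of 500 or $10^{4}$ members without modification.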

Training and Generalization. In practice, Deepsets requires a fixed aggregation scheme from which to learn its global function. Most often this is a summary of embedding layers $\phi(x)$. For networks to scale to arbitrary dataset cardinality, aggregations like max, mean, and variance need to be used. In a scenario where the test data follow a different distribution from the training data distribution $p(x,\theta)$, these aggregations might pose an issue. Concretely, consider $x\sim\mathcal{N}(\mu,1)$, with the target quantity $\theta=\mu\sim\mathcal{U}(0,2)$. Next consider a deepset with the identity embedding layer $\phi(x)=x$ and mean-aggregation:

$$\hat{\mu}=\rho\left(\frac{1}{n}\sum_{i}x_{i}\right) \tag{31}$$

If test data $x_{i}$ were drawn from the same distribution as the training data, $\rho$ would act on the mean value of the set of data, in this case $\rho(\mathbb{E}_{p(x,\theta)}[x_{i}])$, and would converge to a learned function of the joint prior-data distribution $p(x,\theta)$. However, if a test set of data were drawn from a different distribution, e.g. $\mu_{\rm test}\sim\mathcal{N}(0.5,0.1)$, then the expectation $\mathbb{E}_{p(x,\theta)}$ would take on a different value, and $\rho$ would return an incorrect result for the deterministic aggregation. Here it is important to emphasize that $p^{\rm test}(x,\theta)$ and $p(x,\theta)$ overlap along the same support, meaning the network will have seen examples of data drawn from this prior in the limit of an infinite training set. However, the fixed aggregation makes use of a training-data distribution-dependent quantity for its mapping, which can be skewed under covariate shift or different noise settings.
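The effect of this shift can be illustrated with a small Monte Carlo experiment. The sketch below is a simplification made for illustration, not the experiment reported in the paper: $\rho$ is replaced by an affine map fitted by least squares (standing in for a trained network), the set size `n` is kept small so that the prior matters, and the second argument of $\mathcal{N}(0.5,0.1)$ is interpreted as a standard deviation.

```python
import numpy as np

rng = np.random.default_rng(4)
n, n_sets = 4, 50_000                            # set size and number of sets (assumed)

def sample_means(mu):
    # Mean of n draws x_i ~ N(mu, 1), sampled directly as N(mu, 1/n).
    return mu + rng.standard_normal(mu.shape) / np.sqrt(n)

# Training prior mu ~ U(0, 2); shifted test prior mu ~ N(0.5, 0.1) (0.1 taken as std).
mu_train = rng.uniform(0.0, 2.0, n_sets)
mu_test = rng.normal(0.5, 0.1, n_sets)
xbar_train, xbar_test = sample_means(mu_train), sample_means(mu_test)

# Affine stand-in for rho, fitted by least squares (i.e. minimising the MSE).
a_tr, c_tr = np.polyfit(xbar_train, mu_train, 1)
a_te, c_te = np.polyfit(xbar_test, mu_test, 1)

pred_shifted = a_tr * xbar_test + c_tr           # rho trained on the U(0, 2) prior
pred_matched = a_te * xbar_test + c_te           # rho fitted to the test prior

print("bias under shift:", np.mean(pred_shifted - mu_test))
print("MSE under shift :", np.mean((pred_shifted - mu_test) ** 2))
print("MSE when matched:", np.mean((pred_matched - mu_test) ** 2))
```

The map fitted on the training prior shrinks the sample mean towards the training-prior mean of 1, so on the shifted test sets it is systematically biased upwards and has a markedly larger error than a map fitted to the test prior, even though the test values of $\mu$ lie inside the training support.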