Belief Propagation Neural Networks

Authors

  • J. Kuck
  • Shuvam Chakraborty
  • Hao Tang
  • R. Luo
  • Jiaming Song
  • Ashish Sabharwal
  • S. Ermon
  • NeurIPS
  • 2020

Abstract

Learned neural solvers have successfully been used to solve combinatorial optimization and decision problems. More general counting variants of these problems, however, are still largely solved with hand-crafted solvers. To bridge this gap, we introduce belief propagation neural networks (BPNNs), a class of parameterized operators that operate on factor graphs and generalize Belief Propagation (BP). In its strictest form, a BPNN layer (BPNN-D) is a learned iterative operator that provably maintains many of the desirable properties of BP for any choice of the parameters. Empirically, we show that a trained BPNN-D performs the task better than the original BP: it converges 1.7x faster on Ising models while providing tighter bounds. On challenging model counting problems, BPNNs compute estimates hundreds of times faster than state-of-the-art handcrafted methods, while returning an estimate of comparable quality.

1 Introduction

Probabilistic inference problems arise in many domains, from statistical physics to machine learning. There is little hope that efficient, exact solutions to these problems exist as they are at least as hard as NP-complete decision problems. Significant research has been devoted across the fields of machine learning, statistics, and statistical physics to develop variational and sampling based methods to approximate these challenging problems [13, 34, 48, 6, 38] . Variational methods such as Belief Propagation (BP) [31] have been particularly successful at providing principled approximations due to extensive theoretical analysis.

We introduce belief propagation neural networks (BPNNs), a flexible neural architecture designed to estimate the partition function of a factor graph. BPNNs generalize BP and can thus provide more accurate estimates than BP when trained on a small number of factor graphs with known partition functions. At the same time, BPNNs retain many of BP's properties, which results in more accurate estimates compared to general neural architectures. BPNNs are composed of iterative layers (BPNN-D) and an optional Bethe free energy layer (BPNN-B), both of which maintain the symmetries of BP under factor graph isomorphisms. BPNN-D is a parametrized iterative operator that strictly generalizes BP while preserving many of BP's guarantees. Like BP, BPNN-D is guaranteed to converge on tree structured factor graphs and return the exact partition function. For factor graphs with loops, BPNN-D computes a lower bound whenever the Bethe approximation obtained from fixed points of BP is a provable lower bound (with mild restrictions on BPNN-D). BPNN-B performs regression from the trajectory of beliefs (over a fixed number of iterations) to the partition function of the input factor graph. While this sacrifices some guarantees, the additional flexibility introduced by BPNN-B generally improves estimation performance.

Experimentally, we show that on Ising models BPNN-D is able to converge faster than standard BP and frequently finds better fixed points that provide tighter lower bounds. BPNN-D generalizes well to Ising models sampled from a different distribution than seen during training and to models with nearly twice as many variables as seen during training, providing estimates of the log partition function that are significantly better than BP or a standard graph neural network (GNN) in these settings. We also perform experiments on community detection problems, where BP is known to perform well both empirically and theoretically, and show improvements over BP and a standard GNN. We then perform experiments on approximate model counting [46, 27, 28, 8], the problem of computing the number of solutions to a Boolean satisfiability (SAT) problem. Unlike the first two experiments, it is very difficult for BP to converge in this setting. Still, we find that BPNN learns to estimate accurate model counts from a training set of tens of problems and generalizes to problems that are significantly harder for an exact model counter to solve. Compared to handcrafted approximate model counters, BPNN returns comparable estimates hundreds of times faster using GPU computation.

2 Background: Factor Graphs And Belief Propagation

In this section we provide background on factor graphs and belief propagation [31] . A factor graph is a representation of a discrete probability distribution that takes advantage of independencies between variables to make the representation more compact. Belief propagation is a method for approximating the normalization constant, or partition function, of a factor graph. Let p(x) be a discrete probability distribution defined in terms of a factor graph as

$$ p(x) = \frac{1}{Z} \prod_{a=1}^{M} f_a(x_a), \qquad (1) $$

where x = {x_1, x_2, . . . , x_n}, f_a(x_a) > 0 are factors, and Z is the partition function. As a data structure, a factor graph is a bipartite graph with n variable nodes and M factor nodes. A factor node and a variable node are connected if and only if the variable is in the scope of the factor.

Belief Propagation. Belief propagation performs iterative message passing between neighboring variable and factor nodes. Variable to factor messages, m^{(k)}_{i→a}(x_i), and factor to variable messages, m^{(k)}_{a→i}(x_i), are updated at iteration k according to

$$ m^{(k)}_{i \to a}(x_i) = \prod_{c \in N(i) \setminus a} m^{(k-1)}_{c \to i}(x_i), \qquad m^{(k)}_{a \to i}(x_i) = \sum_{x_a \setminus x_i} f_a(x_a) \prod_{j \in N(a) \setminus i} m^{(k)}_{j \to a}(x_j). \qquad (2) $$

Messages are typically initialized either randomly or as constants. The BP algorithm estimates approximate marginal probabilities over the sets of variables x_a associated with each factor f_a. We denote the belief over variables x_a, after message passing iteration k is complete, as b^{(k)}_a(x_a). The belief propagation algorithm proceeds by iteratively updating variable to factor messages and factor to variable messages until they converge to fixed values, referred to as a fixed point of Equation 2, or a predefined maximum number of iterations is reached. At this point the beliefs are used to compute a variational approximation of the factor graph's partition function. This approximation, originally developed in statistical physics, is known as the Bethe free energy, F_Bethe = U_Bethe − H_Bethe ≈ −ln Z [10]. It is defined in terms of the Bethe average energy

$$ U_{\text{Bethe}} := -\sum_{a=1}^{M} \sum_{x_a} b_a(x_a) \ln f_a(x_a) $$

and the Bethe entropy

$$ H_{\text{Bethe}} := -\sum_{a=1}^{M} \sum_{x_a} b_a(x_a) \ln b_a(x_a) + \sum_{i=1}^{N} (d_i - 1) \sum_{x_i} b_i(x_i) \ln b_i(x_i), $$

where d_i denotes the degree of variable node i.

Numerically Stable Belief Propagation. For numerical stability, belief propagation is generally performed in log-space and messages are normalized at every iteration. It is also standard to add a damping parameter, α ∈ [0, 1), to improve convergence by taking partial update steps. BP without damping is recovered when α = 0, while α = 1 would correspond to not updating messages and instead retaining their values from the previous iteration. With these modifications, the variable to factor messages from Equation 2 are rewritten as follows, where terms scaled by α represent the difference in the message's value from the previous iteration:

$$ m^{(k)}_{i \to a} = \tilde{m}^{(k)}_{i \to a} + \alpha \left( m^{(k-1)}_{i \to a} - \tilde{m}^{(k)}_{i \to a} \right), \quad \text{where } \tilde{m}^{(k)}_{i \to a} = -z_{i \to a} + \sum_{c \in N(i) \setminus a} m^{(k-1)}_{c \to i}. \qquad (3) $$

Similarly, the factor to variable messages from Equation 2 are rewritten as

$$ m^{(k)}_{a \to i} = \tilde{m}^{(k)}_{a \to i} + \alpha \left( m^{(k-1)}_{a \to i} - \tilde{m}^{(k)}_{a \to i} \right), \quad \text{where } \tilde{m}^{(k)}_{a \to i} = -z_{a \to i} + \text{LSE}_{x_a \setminus x_i} \Big( \phi_a(x_a) + \sum_{j \in N(a) \setminus i} m^{(k)}_{j \to a}(x_j) \Big). \qquad (4) $$

Note that m_{a→i} and m_{i→a} are vectors of length |X_i|, φ_a(x_a) = ln(f_a(x_a)) denotes log factors, z_{i→a} and z_{a→i} are normalization terms, and we use the shorthand LSE for the log-sum-exp function:

$$ \text{LSE}_{x_a \setminus x_i}\, \phi_a(x_a) = \ln \sum_{x_a \setminus x_i} \exp\big( \phi_a(x_a) \big). $$
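As a concrete illustration of the update rules above, the sketch below runs damped, log-space BP on a toy chain-structured factor graph with pairwise factors and computes the resulting variable beliefs. The toy graph, the dictionary-based message layout, and all variable names are illustrative assumptions rather than the authors' implementation.

```python
# Minimal sketch (assumed data layout, not the authors' code) of damped,
# log-space BP on a chain of 3 binary variables with pairwise factors,
# following the damped updates of Eqs. (3)-(4).
import numpy as np
from scipy.special import logsumexp

J = 0.5
phi = np.array([[J, -J], [-J, J]])            # log factor: J * x_i * x_j with x in {-1, +1}
factors = [((0, 1), phi), ((1, 2), phi)]      # (scope, log potential) per factor
n_vars, n_states, damping, iters = 3, 2, 0.5, 50

# Log-space messages: variable->factor (m) and factor->variable (n), initialized to 0.
m = {(i, a): np.zeros(n_states) for a, (scope, _) in enumerate(factors) for i in scope}
n = {(a, i): np.zeros(n_states) for a, (scope, _) in enumerate(factors) for i in scope}

for _ in range(iters):
    for (i, a) in list(m):                    # variable-to-factor updates (Eq. 3)
        incoming = [n[(b, j)] for (b, j) in n if j == i and b != a]
        msg = np.sum(incoming, axis=0) if incoming else np.zeros(n_states)
        msg = msg - logsumexp(msg)            # normalization term z_{i->a}
        m[(i, a)] = msg + damping * (m[(i, a)] - msg)
    for (a, i) in list(n):                    # factor-to-variable updates (Eq. 4)
        scope, phi_a = factors[a]
        j = scope[1] if scope[0] == i else scope[0]
        axis = 1 if scope[0] == i else 0      # axis of the eliminated variable x_j
        msg = logsumexp(phi_a + np.expand_dims(m[(j, a)], 1 - axis), axis=axis)
        msg = msg - logsumexp(msg)            # normalization term z_{a->i}
        n[(a, i)] = msg + damping * (n[(a, i)] - msg)

# Variable beliefs: product (sum in log-space) of incoming factor-to-variable messages.
for i in range(n_vars):
    b = sum(n[(a, j)] for (a, j) in n if j == i)
    print(i, np.exp(b - logsumexp(b)))
```

On a tree such as this chain, the resulting beliefs are exact marginals and the Bethe free energy computed from them equals −ln Z.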

3 Belief Propagation Neural Networks

We design belief propagation neural networks (BPNNs) as a family of graph neural networks that operate on factor graphs. Unlike standard graph neural networks (GNNs), BPNNs do not send a message back to the node from which it was received, a property taken from BP known as avoiding 'double counting' the evidence. This property guarantees that BPNN-D, described below, is exact on trees (Theorem 3). BPNN-D is a strict generalization of BP (Proposition 1), but is still guaranteed to give a lower bound to the partition function upon convergence for a class of factor graphs (Theorem 3) by finding fixed points of BP (Theorem 2). Like BP, BPNN preserves the symmetries inherent to factor graphs (Theorem 4).

BPNNs consist of two parts. First, iterative BPNN layers output messages, analogous to standard BP. These messages are used to compute beliefs using the same equations as for BP. Second, the beliefs are passed into a Bethe free energy layer (BPNN-B) which generalizes the Bethe approximation by performing regression from beliefs to Z. Alternatively, when the standard Bethe approximation is used in place of BPNN-B, BPNN provides many of BP's guarantees.

BPNN Iterative Layers. BPNN iterative layers are flexible neural operators that can operate on beliefs or messages in a variety of ways. Here, we focus on a specific variant, BPNN-D, due to its strong convergence properties, and we refer the reader to Appendix C for information on other variants. The BPNN iterative damping layer (BPNN-D) modifies factor-to-variable messages (Equation 4) using the output of a learned operator H : ℝ^{∑_{i=1}^{n} d_i |X_i|} → ℝ^{∑_{i=1}^{n} d_i |X_i|} in place of the conventional damping term α(m^{(k−1)}_{a→i} − m̃^{(k)}_{a→i}), where d_i denotes the degree and |X_i| the cardinality of variable X_i. This learned operator H(·) takes as input the difference between iterations k − 1 and k of every factor-to-variable message, and modifies these differences jointly. It can thus be much richer than a scalar multiplier. BPNN-D factor-to-variable messages are given by

$$ n^{(k)}_{a \to i} = \tilde{n}^{(k)}_{a \to i} + \Delta^{(k)}_{a \to i}, \quad \text{where } \tilde{n}^{(k)}_{a \to i} = -z_{a \to i} + \text{LSE}_{x_a \setminus x_i} \Big( \phi_a(x_a) + \sum_{j \in N(a) \setminus i} n^{(k)}_{j \to a}(x_j) \Big), \qquad (5) $$

where Δ^{(k)} = H(n^{(k−1)} − ñ^{(k)}) denotes the result of applying H(·) to all factor-to-variable message differences and Δ^{(k)}_{a→i} is the output corresponding to the modified a → i message difference. Variable-to-factor messages are unchanged from Eq. 3, except that they operate on the messages n in place of m (Equation 6). Note that a broad class of highly expressive learnable operators are invertible [7]. Enforcing that every fixed point of BPNN-D is also a fixed point of BP is particularly useful, as it immediately follows that BPNN-D returns a lower bound whenever the Bethe approximation obtained from fixed points of BP returns a provable lower bound (Theorem 3). When a BPNN-D layer is applied iteratively until convergence, fast convergence is guaranteed for tree structured factor graphs (Proposition 2). As mentioned, BPNN iterative layers are flexible and can additionally be modified to operate directly on message values or factor beliefs at the expense of no longer returning a lower bound (see Appendix C).

$$ n^{(k)}_{i \to a} = \tilde{n}^{(k)}_{i \to a} + \alpha \left( n^{(k-1)}_{i \to a} - \tilde{n}^{(k)}_{i \to a} \right), \quad \text{where } \tilde{n}^{(k)}_{i \to a} = -z_{i \to a} + \sum_{c \in N(i) \setminus a} n^{(k-1)}_{c \to i}. \qquad (6) $$

Theorem 3. If zero is the unique fixed point of H(·), the Bethe approximation computed from beliefs at a fixed point of BPNN-D (1) is exact for tree structured graphs and (2) lower bounds the partition function of any factor graph with binary variables and log-supermodular potential functions.

Proposition 2. BPNN-D converges within ℓ iterations on tree structured factor graphs with height ℓ.
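To illustrate Equation 5, the following sketch replaces the scalar damping term α(n^{(k−1)} − ñ^{(k)}) with a learned operator H applied jointly to all factor-to-variable message differences. Representing H with a small MLP over a flattened message vector is an assumption for illustration; the guarantees above additionally require restrictions on H (e.g., that zero is its unique fixed point) that a generic MLP does not enforce.

```python
# Sketch (assumed shapes and module names, not the authors' implementation) of a
# single BPNN-D factor-to-variable update: n^(k) = n~^(k) + Delta^(k), where
# Delta^(k) = H(n^(k-1) - n~^(k)) acts jointly on all message differences.
import torch
import torch.nn as nn

class DampingOperator(nn.Module):
    """Learned operator H over the flattened factor-to-variable message differences."""
    def __init__(self, msg_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(msg_dim, hidden), nn.ReLU(), nn.Linear(hidden, msg_dim))

    def forward(self, diff: torch.Tensor) -> torch.Tensor:
        # diff and the output both have shape [sum_i d_i * |X_i|].
        return self.net(diff)

def bpnn_d_step(n_prev: torch.Tensor, n_tilde: torch.Tensor, H: nn.Module) -> torch.Tensor:
    """One BPNN-D damping step (Eq. 5); with H(x) = alpha * x this reduces to damped BP."""
    delta = H(n_prev - n_tilde)
    return n_tilde + delta

msg_dim = 24                                   # e.g., sum_i d_i * |X_i| for a small factor graph
H = DampingOperator(msg_dim)
n_prev, n_tilde = torch.randn(msg_dim), torch.randn(msg_dim)
n_next = bpnn_d_step(n_prev, n_tilde, H)       # updated factor-to-variable messages
```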

Bethe Free Energy Layer (BPNN-B). When convergence to a fixed point is unnecessary, we can increase the flexibility of our architecture by building a K-layer BPNN from iterative layers that do not share weights. Additionally we define a Bethe free energy layer (BPNN-B, Equation 7) using two MLPs that take the trajectories of learned beliefs from each factor and variable as input and output scalars:

EQUATION (7): Not extracted; please refer to original document.

This parameterization subsumes the standard Bethe approximation, so we can initialize the parameters of f BPNN to output the Bethe approximation computed from the final layer beliefs (see the appendix for details). Note that |x a | is the number of variables in the scope of factor a, S |xa| denotes the symmetric group (all permutations of {1, 2, . . . , |x a |}), and the permutation σ is applied to the dimensions of all 2k concatenated terms. We ensure that BPNN preserves the symmetries of BP (Theorem 4) by passing all factor permutations through MLP BF and averaging the result.
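The permutation averaging described above can be sketched as follows: the factor-belief MLP is applied to every permutation of a factor's dimensions and the outputs are averaged, making the result invariant to the factor's local variable indexing. The MLP size, input flattening, and function names are illustrative assumptions rather than the paper's exact architecture.

```python
# Sketch of the averaging over S_{|x_a|} used by BPNN-B: apply an MLP to every
# permutation of a factor belief tensor's dimensions and average the outputs,
# so the result does not depend on the local variable indexing.
import itertools
import torch
import torch.nn as nn

def invariant_factor_term(factor_belief: torch.Tensor, mlp_bf: nn.Module) -> torch.Tensor:
    """factor_belief has one dimension per variable in the factor's scope."""
    dims = range(factor_belief.dim())
    outputs = [mlp_bf(factor_belief.permute(*perm).reshape(-1))
               for perm in itertools.permutations(dims)]
    return torch.stack(outputs).mean(dim=0)

# Example: a factor over two binary variables; permuting the belief's dimensions
# (i.e., relabeling the local variable indices) leaves the output unchanged.
mlp_bf = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
belief = torch.softmax(torch.randn(4), dim=0).reshape(2, 2)
print(invariant_factor_term(belief, mlp_bf), invariant_factor_term(belief.T, mlp_bf))
```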

BPNN Preserves the Symmetries of BP. BPNN is designed so that equivalent input factor graphs are mapped to equivalent outputs. This is a property that BP satisfies by default. Standard GNNs are also designed to satisfy this property, however the notion of 'equivalence' between graphs is different than 'equivalence' between factor graphs. In this section we formalize these statements.

Graph isomorphism defines an equivalence relationship between graphs that is respected by standard GNNs. Two isomorphic graphs are structurally equivalent and indistinguishable if the nodes are appropriately matched. More formally, there exists a bijection between nodes (or their indices) in the two graphs that defines this matching. Standard GNNs are designed so that output node representations are equivariant to the input node indexing; the indexing of output node representations matches the indexing of input nodes. Output node representations of a GNN run on two isomorphic graphs can be matched using the same bijection that defines the isomorphism. Further, standard GNNs are designed to map isomorphic graphs to the same graph-level output representation. These two properties are achieved by using a message aggregation function and a graph-level output function that are both invariant to node indexing.

We formally define factor graph isomorphism in Definition 1 (Appendix A). This equivalence relationship is more complicated than for standard graphs because factor potentials define a structured relationship between factor and variable nodes. As in a standard graph, variable nodes are indexed globally (X 1 , X 2 , . . . , X n ) in the representation of a factor graph. Additionally, variable nodes are also indexed locally by factors that contain them. This is required because each factor dimension (note that factors are tensors) corresponds to a unique variable, unless the factor happens to be symmetric. Local variable indices define a mapping between factor dimensions and the variables' global indices. These local variable indices lead to additional bijections in the definition of isomorphic factor graphs (condition 2 in Definition 1). Note that standard GNNs do not respect factor graph isomorphisms because of these additional bijections.

In contrast to standard GNNs, BP respects factor graph isomorphisms. When BP is run on two isomorphic factor graphs for the same number of iterations with constant message initialization, the output beliefs and messages satisfy bijections corresponding to those of the input factor graphs.

4 Experiments

In our experiments we trained BPNN to estimate the partition function of factor graphs from a variety of domains. First, experiments on synthetic Ising models show that BPNN-D can learn to find better fixed points than BP and converge faster. Additionally, BPNN generalizes to Ising models with nearly twice as many variables as those seen during training and that were sampled from a different distribution. Second, experiments and an ablation study on the stochastic block model from community detection show that maintaining properties of BP in BPNN improves results over standard GNNs. Finally, model counting experiments performed on real world SAT problems show that BPNN can learn from tens of training problems, generalize to problems that are harder for an exact model counter, and compute estimates hundreds of times faster than handcrafted approximate model counters. We implemented our BPNN and the baseline GNN using PyTorch Geometric [19]. We refer the reader to Appendix B.2 for details on the GNN.

4.1 Ising Models

We followed a common experimental setup used to evaluate approximate integration methods [21, 17]. We randomly generated grid structured attractive Ising models whose partition functions can be computed exactly using the junction tree algorithm [33] for training and validation. BP computes a provable lower bound for these Ising models [41]. This family of Ising models is only slightly more general than the one studied in [30], where BP was proven to quickly converge to the Bethe free energy's global optimum. We found that an iterative BPNN-D layer was able to converge faster than standard BP and could find tighter lower bounds for these problems. Additionally we trained a 10 layer BPNN and evaluated its performance against a 10 layer GNN architecture (details in Appendix). Compared to the GNN, BPNN has improved generalization when tested on larger Ising models and Ising models sampled from a different distribution than seen during training. We recorded the number of iterations that BPNN-D and BP run with parallel updates took to converge, defined as a maximum factor-to-variable message difference of 10^{−5}. BPNN-D had a median improvement ratio of 1.7x over BP; please refer to the appendix for complete convergence plots. Among the 44 models where BP converged, the RMSE between the exact log partition function and BPNN-D's estimate was .97 compared with 7.20 for BP. For 10 of the 44 models, BPNN-D found fixed points corresponding to lower bounds on the log partition function that were larger (i.e., better) than BP's by 3 to 22 (corresponding to bounds on the partition function that were roughly 20 to e^{22} times larger). In contrast, the log lower bound found by BP was never larger than the bound found by BPNN-D by more than 1.7.

Out of Distribution Generalization. We tested BPNN's ability to generalize to larger factor graphs and to shifts in the test distribution. Again we used a training set of 50 Ising models of size 10x10 (100 variables). We sampled test Ising models from distributions with generative parameters increased by factors of 2 and 10 from their training values (see appendix for details) and with their size increased to 14x14 (196 variables instead of the 100 seen during training). For this experiment we used a BPNN architecture with 10 iterative layers whose weights were not tied and with MLPs that operate on factor messages (without a BPNN-B layer). As a baseline we trained a 10 layer GNN (maximally powerful GIN architecture) with width 4 on the same dataset. We also compute the Bethe approximation from running standard loopy belief propagation and the mean field approximation. We used the libDAI [35] implementation for both. We tested loopy belief propagation with and without damping and with both parallel and sequential message update strategies. We show results for two settings whose estimates of the partition function differ most drastically:

(1) run for a maximum of 10 iterations with parallel updates and damping set to .5, and (2) run for a maximum of 1000 iterations with sequential updates using a random sequence and no damping. Full test results are shown in Figure 1 . The leftmost point in the left figure shows results for test data that was drawn from the same distribution used for training the BPNN and GNN. The BPNN and GNN perform similarly for data drawn from the same distribution seen during training. However, our BPNN significantly outperforms the GNN when the test distribution differs from the training distribution and when generalizing to the larger models. Our BPNN also significantly outperforms loopy belief propagation, both for test data drawn from the training distribution and for out of distribution data.

Figure 1: Each point represents the root mean squared error (RMSE, y-axis) of the specified method on a test set of 50 Ising models sampled with the parameters fmax and cmax (x-axis). The leftmost point shows results for test data drawn from the same distribution as training. BPNN significantly improves upon loopy belief propagation (LBP) for both in and out of distribution data. BPNN also significantly outperforms GNN on out of distribution data and larger models.

4.2 Stochastic Block Model

The Stochastic Block Model (SBM) is a generative model describing the formation of communities and is often used to benchmark community detection algorithms [1] . While BP does not lower bound the partition functions of associated factor graphs for SBMs, it has been shown that BP asymptotically (in the number of nodes) reaches the information theoretic threshold for community recovery on SBMs with fewer than 4 communities [1] . We trained a BPNN to estimate the partition function of the associated factor graph and observed improvements over estimates obtained by BP or a maximally powerful GNN, which lead to more accurate marginals that can be used to better quantify uncertainty in SBM community membership. We refer the reader to Appendix F for a formal definition of SBMs as well as our procedure for constructing factor graphs from a sampled SBM.

Dataset And Methods

In our experiments, we consider SBMs with 2 classes and 15-20 nodes, so that exact inference is possible using the Junction Tree algorithm. In this non-asymptotic setting, BP is a strong baseline and can almost perfectly recover communities [14] , but is not optimal and thus does not compute exact marginals or partition functions. For training, we sample 10 two class SBMs with 15 nodes, class probabilities of .75 and .25, and edge probability of .93 within and .067 between classes along with four such graphs for validation. For each graph, we fix each node to each class and calculate the exact log partition using the Junction Tree Algorithm, producing 300 training and 120 validation graphs. We explain in Appendix F how these graphs can be used to calculate marginals.

To estimate SBM partition functions, we trained a BPNN with 30 iterative BPNN layers that operate on messages (see Appendix C), followed by a BPNN-B layer. Since BP does not provide a lower bound for SBM partitions, we took advantage of BPNN's flexibility and chose greater expressive power over BPNN-D's superior convergence properties. We compared against BP and a GNN as baseline methods. Additionally, we performed 2 ablation experiments. We trained a BPNN with a BPNN-B layer that was not permutation invariant to local variable indexing, by removing the sum over permutations in S |xa| from Equation 7 and only passing in the original beliefs. We refer to this non-invariant version as BPNN-NI. We then forced BPNN-NI to 'double count' messages by changing the sums in Equations 5 and 6 to be over j ∈ N (a). We refer to this non-invariant version that performs double counting as BPNN-DC. We refer the reader to Appendix F for further details on models and training.

Results As shown in Table 1 , BPNN provides the best estimates for the partition function. Critically, we see that not 'double counting' messages and preserving the symmetries of BP are key improvements of BPNN over GNN. Additionally, BPNN outperforms BP and GNN on out of distribution data and larger graphs and can learn more accurate marginals. We refer the reader to Appendix F for more details on these additional experiments.

Table 1: RMSE of SBM ln(Z) estimates. BPNN outperforms BP, GNN, and ablated versions of BPNN.

4.3 Model Counting

In this section we use a BPNN to estimate the number of satisfying solutions to a Boolean formula, a challenging problem for BP, which generally fails to converge due to the complex logical constraints and 0 probability states. Computing the exact number of satisfying solutions (exact model counting) is a #P-complete problem [47]. Model counting is a fundamental problem that arises in many domains including probabilistic reasoning [40, 9], network reliability [16], and detecting private information leakage from programs [11]. However, the computational complexity of exact model counting has led to a significant body of work on approximate model counting [46, 27, 28, 8, 20, 18, 24, 3, 5, 44], with the goal of estimating the number of satisfying solutions at a lower computational cost.

Training Setup. All BPNNs trained in this section were composed of 5 BPNN-D layers followed by a BPNN-B layer and were trained to predict the natural logarithm of the number of satisfying solutions to an input formula in CNF form. This is accomplished by converting the CNF formula into a factor graph whose partition function is the number of satisfying solutions to the input formula. We evaluated the performance of our BPNN using benchmarks from [44], with ground truth model counts obtained using DSharp [37]. The benchmarks fall into 7 categories, including network DQMR (Quick Medical Reference) problems [26], network grid problems, and bit-blasted versions of satisfiability modulo theories library (SMTLIB) benchmarks [12]. Each category contains 14 to 105 problems allocated for training and validation. See the appendix for additional details on training, the dataset, and our use of minimal independent support variable sets.

Baseline Approximate Model Counters. For comparison we ran two state-of-the-art approximate model counters on all benchmarks, ApproxMC3 [12, 44] and F2 [4, 5] . ApproxMC3 is a randomized hashing algorithm that returns an estimate of the model count that is guaranteed to be within a multiplicative factor of the exact model count with high probability. F2 gives up the probabilistic guarantee that the returned estimate will be within a multiplicative factor of the true model count in return for significantly increased computational efficiency. We also attempted to train a GNN, using the architecture from [43] adapted from classification to regression. We used the author's code, slightly modified to perform regression, but were not successful in achieving non-trivial learning.

BPNNs Provide Excellent Computational Efficiency. Figure 2 shows runtimes and estimates for BPNN, ApproxMC3, and F2 on all benchmarks from the category 'or_50'.

BPNN Learning from Limited Data. We trained a separate BPNN on a random sampling of 70% of the problems in each training category. This gave each BPNN only 9 to 73 benchmarks to learn from. In contrast, prior work has performed approximate model counting on Boolean formulas in disjunctive normal form (DNF) by creating a large training set of 100k examples whose model counts can be approximated with an efficient polynomial time algorithm [2]. Such an algorithm does not exist for model counting on CNF formulas, making this approach intractable. Nonetheless, BPNN achieves training and validation RMSE comparable to or better than F2 across the range of benchmark categories (see the appendix for complete results). This demonstrates that BPNNs can capture the distribution of diverse families of SAT problems in an extremely data limited regime.

Figure 2: Left: cactus plot of runtimes for the 105 instances in the ’or_50’ category solved by BPNN, F2, and ApproxMC3. BPNN-P denotes the time taken to run BPNN in parallel on a GPU divided by the number of instances per batch (batch size=103). Median speedups of BPNN-P over F2 and ApproxMC among the plotted benchmarks are 248 and 3,689 respectively. BPNN-S denotes the time taken to run BPNN sequentially on each instance (using a CPU). Median speedups of BPNN-S over F2 and ApproxMC among the plotted benchmarks are 2.2 and 32, resp. While BPNN solved each instance within 1 second, ApproxMC3 timed out on 12 instances (out of 105) after 5000 seconds, which are not plotted. Right: error in estimated log model count (base e) plotted against the exact model count for ‘or_50’ training and validation benchmarks. BPNN’s validation RMSE was .30 on this category compared with a RMSE of 2.5 for F2.

Generalizing from Easy Data to Hard Data. We repeated the same experiment from the previous paragraph, but trained each BPNN on the 70% of the problems from each category that DSharp solved fastest. Validation was performed on the remaining 30% of problems that took longest for DSharp to solve. These hard validation sets are significantly more challenging for Dsharp. The median runtime in each category's hard validation set is 4 to 15 times longer than the longest runtime in each corresponding easy training set. Validation RMSE on these hard problems was within 33% of validation error when trained and validated on a random sampling for 3 of the 7 categories. This demonstrates that BPNNs have the potential to be trained on available data and then generalize to related problems that are too difficult for any current methods. See the appendix for complete results.

Learning Across Diverse Domains. We trained a BPNN on a random sampling of 70% of problems from all categories, spanning network grid problems, bit-blasted versions of SMTLIB benchmarks, and network DQMR problems. The BPNN achieved a final training RMSE of 3.9 and validation RMSE of 5.31, demonstrating that the BPNN is capable of capturing a broad distribution that spans multiple domains from a small training set.

5 Related Work

[2] use a graph neural network to perform approximate weighted disjunctive normal form (DNF) counting. Weighted DNF counting is a #P-complete problem. However, in contrast to model counting on CNF formulas, there exists an O(nm) polynomial time approximation algorithm for weighted DNF counting (where n is the number of variables and m is the number of clauses). The authors leverage this to generate a large training dataset of 100k DNF formulas with approximate solutions. In comparison, our BPNN can learn and generalize from a very small training dataset of less than 50 problems. This result delivers on the direction for significant future work alluded to in the conclusion of [2].

Recently, [42] designed a graph neural network that operates on factor graphs and exchanges messages with BP to perform error correction decoding. In contrast, BPNN-D preserves all of BP's fixed points, computes the exact partition function on tree structured factor graphs, and returns a lower bound whenever the Bethe approximation obtained from fixed points of BP is a provable lower bound. All BPNN layers preserve BP's symmetries (invariances and equivariances) to permutations of both variable and factor indices. Finally, BPNN avoids 'double counting' during message passing.

Prior work has shown that neural networks can learn how to solve NP-complete decision problems and optimization problems [43, 39, 23] . [53] perform marginal inference in relatively small graphical models using GNNs. [22] consider improving message passing in expectation propagation for probabilistic programming, when users can specify arbitrary code to define factors and the optimal updates are intractable. [50] consider learning Markov random fields and address the problem of estimating marginal likelihoods (generally intractable to compute precisely). They use a transformer network that is faster than LBP but computes comparable estimates. This allows for faster amortized inference during training when likelihoods must be computed at every training step. In contrast, BPNNs significantly outperform LBP and generalize to out of distribution data.

6 Conclusion

We introduced belief propagation neural networks, a strict generalization of BP that learns to find better fixed points faster. The BPNN architecture resembles that of a standard GNN, but preserves BP's invariances and equivariances to permutations of variable and factor indices. We empirically demonstrated that BPNNs can learn from tiny data sets containing only 10s of training points and generalize to test data drawn from a different distribution than seen during training. BPNNs significantly outperform loopy belief propagation and standard graph neural networks in terms of accuracy.

BPNNs provide excellent computational efficiency, running orders of magnitude faster than state-of-the-art randomized hashing algorithms while maintaining comparable accuracy.

Broader Impact

This work makes both a theoretical contribution and a practical one by advancing the state-of-the-art in approximate inference on some benchmark problems. Our theoretical analysis of neural fixed point iterators is unlikely to have a direct impact on society. BPNN, on the other hand, can make approximate inference more scalable. Because approximate inference is a key computational problem underlying, for example, much of Bayesian statistics, it is applicable to many domains, both beneficial and harmful to society. Among the beneficial ones, we have applications of probabilistic inference to medical diagnosis and applications of model counting to reliability, safety, and privacy analysis.

A Proofs

Theorem 1. Every fixed point of BP satisfies m̃^{(k)}_{a→i} = m^{(k−1)}_{a→i} by definition, and the computation of ñ^{(k)}_{a→i} in BPNN-D matches that of m̃^{(k)}_{a→i} in BP. Conversely, every fixed point of BPNN-D satisfies n^{(k)}_{a→i} = n^{(k−1)}_{a→i} by definition. Equation 5 gives n^{(k−1)}_{a→i} − ñ^{(k)}_{a→i} = Δ^{(k)}_{a→i} = H(n^{(k−1)} − ñ^{(k)})_{a→i}. Given the restriction on H(·) that H(x) = x only if x = 0, it follows that n^{(k−1)}_{a→i} = ñ^{(k)}_{a→i}, i.e., the messages satisfy the BP fixed point equations.

Proposition 2. If we consider a BPNN with weight tying, then regardless of the number of iterations or layers, the output messages are the same if the input messages are the same. Without loss of generality, let us first consider any node r as the root node, and consider all the messages on the path from the leaf nodes through r. Let d_{r,i} denote the depth of the sub-tree with root i when we consider r as the root (e.g. for a leaf node i, d_{r,i} = 1). We use the following induction argument:

  • At iteration 1, the messages from all nodes with d_{r,i} = 1 to their parents will be fixed for subsequent iterations since the inputs to the BPNN for these messages are the same.

  • If at iteration t−1, the messages from all nodes with d_{r,i} ≤ t−1 to their parents are fixed for all subsequent iterations, then the inputs to the BPNN for all the messages from all nodes with d_{r,i} = t to their parents will be fixed (since they depend on lower level messages that are fixed). Therefore, at iteration t, the messages from all the nodes with d_{r,i} ≤ t to their parents will be fixed because of weight tying between BPNN layers.

  • The maximum tree depth is ℓ, so max_i d_{r,i} ≤ ℓ. From the induction argument above, after at most ℓ iterations, all the messages along the path from leaf nodes to r will be fixed.

Since the BPNN layer performs the same operation over all nodes, the above argument is valid for all nodes when we consider them as root nodes. Therefore, all messages will be fixed after at most ℓ iterations, which completes the proof.

Isomorphic Factor Graphs. To prove Theorem 4 we define isomorphic factor graphs, an equivalence relation among factor graph representations, and break Theorem 4 into the lemmas in this section. Standard GNNs are built on the assumption that isomorphic graphs should be mapped to the same representation and non-isomorphic graphs should be mapped to different representations [51]. This is a challenging goal; in fact, [51] showed that standard GNNs are at most as powerful as the Weisfeiler-Lehman graph isomorphism test. The input to a BPNN is a factor graph. With the same motivation as for standard GNNs, BPNNs should map isomorphic factor graphs to the same output representation. A factor graph is represented as G = (A, F^p, F^{idx}). (Note that a factor graph can be viewed as a weighted hypergraph where factors define hyperedges and factor potentials define hyperedge weights for every variable assignment within the factor. For readability, we use a and b to index factors and i and j to index variables throughout this section.) A ∈ {0, 1}^{M×N} is an adjacency matrix over M factor nodes and N variable nodes, where A_{ai} = 1 if the i-th variable is in the scope of the a-th factor and A_{ai} = 0 otherwise. F^p is an ordered list of M factor potentials, where the a-th factor potential, F^p_a, corresponds to the a-th factor (row) in A and is represented as a tensor with one dimension for every variable in the scope of F^p_a. F^{idx} is an ordered list of ordered lists that locally indexes variables within each factor. F^{idx}_a is an ordered list specifying the local indexing of variables within the a-th factor (in A and F^p). F^{idx}_{ak} = i specifies that the k-th dimension of the tensor F^p_a corresponds to the i-th variable (column) in A. We define two factor graphs to be isomorphic when they meet the conditions of Definition 1.

Definition 1. Factor graphs G = (G(A), G(F^p), G(F^{idx})) and G′ = (G′(A), G′(F^p), G′(F^{idx})) with G(A) ∈ {0, 1}^{M×N} and G′(A) ∈ {0, 1}^{M′×N′} are isomorphic if the following two conditions hold, where for K ∈ ℕ we use [K] to denote {1, 2, . . . , K}:

1. There exist bijections over factor and variable indices, f_F : [M] → [M′] and f_V : [N] → [N′], such that

$$ G(A)_{ai} = G'(A)_{f_F(a)\, f_V(i)} \quad \text{for all } a \in [M],\ i \in [N]. \qquad (8) $$

2. There exists a bijection for every factor, f^{idx}_a : [|G(F^{idx}_a)|] → [|G′(F^{idx}_b)|], such that f_V(G(F^{idx}_{ak})) = G′(F^{idx}_{bl}) and G(F^p_a) = σ_a(G′(F^p_b)), where b = f_F(a), l = f^{idx}_a(k), σ_a = (f^{idx}_a(1), f^{idx}_a(2), . . . , f^{idx}_a(|G(F^{idx}_a)|)), and σ_a(G′(F^p_b)) denotes permuting the dimensions of the tensor G′(F^p_b) according to σ_a.

Condition 1 in Definition 1 states that permuting the global indices of variables or factors in a factor graph results in an isomorphic factor graph. Condition 2 in Definition 1 states that permuting the local indices of variables within factors also results in an isomorphic factor graph. In Lemmas 1, 2, and 3 we formalize the equivariance of messages and beliefs obtained by applying BPNN iterative layers, using the bijections from Definition 1 to construct bijective mappings between messages and beliefs. In Lemma 4 we use the equivariance of beliefs between isomorphic factor graphs to show that the output of BPNN-B is identical for isomorphic factor graphs.

Proof of Lemma 1. We use a proof by induction.

Base case: the initial messages are all equal when constant initialization is used and therefore satisfy any bijective mapping. (Any message initialization strategy can be used, as long as the initial messages are equivariant, i.e., they satisfy the bijective mappings g^{(0)}_{i→a} = h^{(0)}_{j→b} and g^{(0)}_{a→i} = h^{(0)}_{b→j}.)

Inductive step: Writing the definition of variable to factor messages, we have

$$ g^{(k)}_{i \to a}(x_i) = \prod_{c \in N(i) \setminus a} g^{(k-1)}_{c \to i}(x_i) = \prod_{c \in N(j) \setminus b} h^{(k-1)}_{c \to j}(x_j) = h^{(k)}_{j \to b}(x_j), \qquad (9) $$

since the bijective mapping holds for factor to variable messages at iteration k − 1 by the inductive hypothesis. Writing the definition of factor to variable messages, we have

$$ g^{(k)}_{a \to i}(x_i) = \sum_{x_a \setminus x_i} G(F^p_a)(x_a) \prod_{l \in N(a) \setminus i} g^{(k)}_{l \to a}(x_l) = \sum_{x_b \setminus x_j} \sigma_a\big(G'(F^p_b)\big)(x_b) \prod_{l \in N(b) \setminus j} h^{(k)}_{l \to b}(x_l) = h^{(k)}_{b \to j}(x_j), \qquad (10) $$

showing that the bijective mapping continues to hold at iteration k.

Proof extension to BPNN-D: the logic of the proof is unchanged when BP is performed in log-space with damping. The only difference between BPNN-D and standard BP is the replacement of the term α(m^{(k−1)}_{a→i} − m̃^{(k)}_{a→i}) in the computation of factor to variable messages with Δ^{(k)}_{a→i}, where Δ^{(k)} = H(n^{(k−1)} − ñ^{(k)}). If H(·) is equivariant to global node indexing (the bijective mapping Δ^{(k)}_{a→i}(G) = Δ^{(k)}_{b→j}(G′) holds, where Δ^{(k)}_{a→i}(G) denotes applying the operator H(·) to the k-th iteration's message differences when the input factor graph is G and taking the output corresponding to message a → i), then equality is maintained in Equation 10 and the bijective mapping between messages holds.

Lemma 2 states that the variable beliefs of isomorphic factor graphs satisfy g^{(k)}_i = h^{(k)}_j, where j = f_V(i).

Proof. By the definition of variable beliefs,

$$ g^{(k)}_i(x_i) = \frac{1}{z_i} \prod_{a \in N(i)} g^{(k)}_{a \to i}(x_i) = \frac{1}{z_j} \prod_{b \in N(j)} h^{(k)}_{b \to j}(x_j) = h^{(k)}_j(x_j), \qquad (11) $$

where the second equality holds due to factor to variable message equivariance from Lemma 1.

Proof of Lemma 3 (factor belief equivariance). By the definition of factor beliefs,

$$ g^{(k)}_a(x_a) = \frac{1}{z_a} G(F^p_a)(x_a) \prod_{i \in N(a)} g^{(k)}_{i \to a}(x_i) = \frac{1}{z_b} \sigma_a\big(G'(F^p_b)\big)(x_b) \prod_{j \in N(b)} h^{(k)}_{j \to b}(x_j) = h^{(k)}_b(x_b), \qquad (12) $$

where the second equality holds due to variable to factor message equivariance from Lemma 1.

Proof of Lemma 4. By the definition of the Bethe approximation (or the negative Bethe free energy),

$$ -F_{\text{Bethe}}(G) = \sum_{a=1}^{M} \sum_{x_a} g_a(x_a) \ln G(F^p_a)(x_a) - \sum_{a=1}^{M} \sum_{x_a} g_a(x_a) \ln g_a(x_a) + \sum_{i=1}^{N} (d_i - 1) \sum_{x_i} g_i(x_i) \ln g_i(x_i) $$
$$ = \sum_{b=1}^{M'} \sum_{x_b} h_b(x_b) \ln G'(F^p_b)(x_b) - \sum_{b=1}^{M'} \sum_{x_b} h_b(x_b) \ln h_b(x_b) + \sum_{j=1}^{N'} (d_j - 1) \sum_{x_j} h_j(x_j) \ln h_j(x_j) = -F_{\text{Bethe}}(G'), \qquad (13) $$

where b = f_F(a) (equivalently a′ = f^{−1}_F(b)), the second equality follows from the equivariance of variable and factor beliefs (Lemmas 2 and 3), and the final equality follows from the commutative property of addition.

Proof extension to BPNN-B: the proof holds for BPNN-B because every permutation (in S |xa| ) of factor belief terms is input to MLP BF .

B Extended Background

We provide background on belief propagation and graph neural networks (GNN) to motivate and clarify belief propagation neural networks (BPNN).

B.1 Belief Propagation

We describe a general version of belief propagation [52] that operates on factor graphs.

Factor Graphs. A factor graph [32, 52] is a general representation of a distribution over n discrete random variables, {X_1, X_2, . . . , X_n}. Let x_i denote a possible state of the i-th variable. We use the shorthand p(x) = p(X_1 = x_1, . . . , X_n = x_n) for the joint probability mass function, where x = {x_1, x_2, . . . , x_n} is a specific realization of all n variables. Without loss of generality, p(x) can be written as the product

$$ p(x) = \frac{1}{Z} \prod_{a=1}^{m} f_a(x_a). \qquad (14) $$

The functions f 1 , f 2 , . . . , f m each take some subset of variables as arguments; function f a takes x a ⊂ {x 1 , x 2 , . . . , x n }. We require that all functions are non-negative and finite. This makes p(x) a well defined probability distribution after normalizing by the distribution's partition function

$$ Z = \sum_{x} \prod_{a=1}^{m} f_a(x_a). \qquad (15) $$

A factor graph is a bipartite graph that expresses the factorization of the distribution in Equation 14: it contains one variable node for each variable x_i, one factor node for each function f_a, and an edge between them whenever x_i is in the scope of f_a. During message passing iteration k, variable node i sends a message to each neighboring factor node a according to the rule

$$ m^{(k)}_{i \to a}(x_i) = \prod_{c \in N(i) \setminus a} m^{(k-1)}_{c \to i}(x_i). \qquad (16) $$

The message m^{(k)}_{a→i}(x_i) from factor node a to variable node i during iteration k is then computed according to the rule

$$ m^{(k)}_{a \to i}(x_i) = \sum_{x_a \setminus x_i} f_a(x_a) \prod_{j \in N(a) \setminus i} m^{(k)}_{j \to a}(x_j). \qquad (17) $$

The BP algorithm estimates approximate marginal probabilities for each variable, referred to as beliefs. We denote the belief at variable node i, after message passing iteration k is complete, as b^{(k)}_i(x_i), which is computed as

$$ b^{(k)}_i(x_i) = \frac{1}{z_i} \prod_{a \in N(i)} m^{(k)}_{a \to i}(x_i). \qquad (18) $$

Similarly, BP computes joint beliefs over the sets of variables x_a associated with each factor f_a. We denote the belief over variables x_a, after message passing iteration k is complete, as b^{(k)}_a(x_a), which is computed as

$$ b^{(k)}_a(x_a) = \frac{1}{z_a} f_a(x_a) \prod_{i \in N(a)} m^{(k)}_{i \to a}(x_i). \qquad (19) $$

The Bethe average energy U_Bethe := −∑_{a=1}^{M} ∑_{x_a} b_a(x_a) ln f_a(x_a) and the Bethe entropy H_Bethe := −∑_{a=1}^{M} ∑_{x_a} b_a(x_a) ln b_a(x_a) + ∑_{i=1}^{N} (d_i − 1) ∑_{x_i} b_i(x_i) ln b_i(x_i) together define the Bethe free energy F_Bethe = U_Bethe − H_Bethe ≈ −ln Z, where d_i denotes the degree of variable node i.

B.2 Gnn Background

This section provides background on graph neural networks (GNNs), a form of neural network used to perform representation learning on graph structured data. GNNs perform iterative message passing operations between neighboring nodes in graphs, updating the learned, hidden representation of each node after every iteration. Xu et al. [51] showed that graph neural networks are at most as powerful as the Weisfeiler-Lehman graph isomorphism test [49] , which is a strong test that generally works well for discriminating between graphs. Additionally, [51] presented a GNN architecture called the Graph Isomorphism Network (GIN), which they showed has discriminative power equal to that of the Weisfeiler-Lehman test and thus strong representational power. We will use GIN as a baseline GNN for comparison in our experiments because it is provably as discriminative as any GNN that aggregates information from 1-hop neighbors.

We now describe in detail the GIN architecture that we use as a baseline. Our architecture performs regression on graphs, learning a function f_GIN : 𝒢 → ℝ from graphs to a real number. Our input is a graph G = (V, E) ∈ 𝒢 with node feature vectors h^{(0)}_v for v ∈ V and edge feature vectors e_{u,v} for (u, v) ∈ E. Our output is the number f_GIN(G), which should ideally be close to the ground truth value y_G. Let h^{(k)}_v denote the representation vector corresponding to node v after the k-th message passing operation. We use a slightly modified GIN update to account for edge features as follows:

EQUATION (20): Not extracted; please refer to original document.

A K-layer GIN network with width M is defined by K successive GIN updates as given by Equation 20, where h^{(k)}_v ∈ ℝ^M is an M-dimensional representation; the first update differs only in that its input dimensionality is given by the dimensionality of the original node feature representations. The final output of our GIN network is given by

$$ f_{\text{GIN}}(G) = \text{MLP}^{(K+1)}\Big( \text{CONCAT}\Big( \sum_{v \in V} h^{(k)}_v \;\Big|\; k = 1, \dots, K \Big) \Big), \qquad (21) $$

where we concatenate summed node feature vectors from all layers and MLP^{(K+1)} is a multilayer perceptron with a single hidden layer. Its input and hidden layers have dimensionality M · K and its output layer has dimensionality 1.
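Since the exact form of Equation 20 is not reproduced above, the sketch below shows one plausible way to fold edge features into a GIN-style update (embedding each [neighbor state, edge feature] pair before sum aggregation); it is an assumption for illustration and not necessarily the exact update used in the experiments.

```python
# Sketch of a GIN-style layer with edge features (an assumed variant): neighbor
# states are concatenated with the corresponding edge features, embedded,
# sum-aggregated, and combined with the (1 + eps)-scaled self representation.
import torch
import torch.nn as nn

class GINEdgeLayer(nn.Module):
    def __init__(self, node_dim: int, edge_dim: int, out_dim: int):
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))
        self.embed = nn.Linear(node_dim + edge_dim, node_dim)
        self.mlp = nn.Sequential(nn.Linear(node_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim))

    def forward(self, h, edge_index, edge_attr):
        # h: [num_nodes, node_dim]; edge_index: [2, num_edges]; edge_attr: [num_edges, edge_dim]
        src, dst = edge_index
        messages = self.embed(torch.cat([h[src], edge_attr], dim=-1))
        agg = torch.zeros_like(h).index_add_(0, dst, messages)   # sum over incoming edges
        return self.mlp((1 + self.eps) * h + agg)

# Toy usage: 3 nodes, 2 undirected edges stored in both directions.
h = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.randn(4, 1)
print(GINEdgeLayer(node_dim=4, edge_dim=1, out_dim=4)(h, edge_index, edge_attr).shape)
```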

C Bpnn Iterative Layer Additional Variants

When the convergence properties of BPNN-D are not needed (e.g., if BP is not a lower bound to the partition function of a particular problem), we have the flexibility to create BPNN iterative layers that directly operate on a combination of messages and beliefs by modifying m̃^{(k)}_{i→a} and m̃^{(k)}_{a→i} from Equations 3 and 4. We can introduce a variant that parameterizes both factor to variable messages and factor beliefs and computes factor to variable messages as:

EQUATION (22): Not extracted; please refer to original document.

where we use the shorthand

EQUATION (23): Not extracted; please refer to original document.

and MLP θi is a multilayer perceptron parameterized by θ i . We exponentiate before applying the multilayer perceptron because we empirically find that this improves training as opposed to having MLPs operate directly in log space.

We can also introduce additional variants that operate only on messages and parameterize both variable to factor and factor to variable messages:

EQUATION (24): Not extracted; please refer to original document.

EQUATION (25): Not extracted; please refer to original document.

BPNN iterative layers allow for great flexibility and different combinations of these MLPs can be applied in a specific layer depending on the task at hand. These MLPs can even be combined with the damping MLPs found in BPNN-D layers in lieu of the fixed scalar damping coefficient α found in Equations 3 and 4.

BPNN Initialization. Note that any BPNN architecture built from iterative layers, with or without a BPNN-B layer, can be initialized to perform BP run for a fixed number of iterations by initializing all MLPs to the identity function f(x) = x; e.g., weight matrices are set to the identity, bias terms to zero, and any nonlinearities are chosen so as to avoid affecting the input at initialization.
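A minimal sketch of this identity initialization is shown below for a two-layer MLP whose hidden width equals its input width; the constant shift around the ReLU is one illustrative way to keep the nonlinearity from affecting the input at initialization, and is an assumption rather than the paper's exact construction.

```python
# Sketch of initializing an MLP to the identity f(x) = x so that a BPNN built
# from such layers reproduces plain BP at initialization. Weight matrices are
# set to the identity, biases to zero, and a constant shift keeps the ReLU
# inactive for inputs greater than -shift (an illustrative choice).
import torch
import torch.nn as nn

class IdentityInitMLP(nn.Module):
    def __init__(self, dim: int, shift: float = 10.0):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)
        self.shift = shift
        with torch.no_grad():
            for lin in (self.lin1, self.lin2):
                lin.weight.copy_(torch.eye(dim))
                lin.bias.zero_()

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x) + self.shift) - self.shift)

x = torch.randn(5, 8)
print(torch.allclose(IdentityInitMLP(8)(x), x))   # True at initialization (for x > -shift)
```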

D Ising Model Experiments

Data Generation. An N × N Ising model is defined over binary variables x i ∈ {−1, 1} for i = 1, 2, . . . , N 2 , where each variable represents a spin. Each spin has a local field parameter J i which corresponds to its local potential function J i (x i ) = J i x i . Each spin variable has 4 neighbors, unless it occupies a grid edge. Neighboring spins interact with coupling potentials

J i,j (x i , x j ) = J i,j x i x j .

The probability of a complete variable configuration x = {x 1 , . . . , x N 2 } is defined to be

$$ p(x) = \frac{1}{Z} \exp\Big( \sum_{i=1}^{N^2} J_i x_i + \sum_{(i,j) \in E} J_{i,j} x_i x_j \Big), \qquad (26) $$

where the normalization constant Z, or partition function, is defined to be

$$ Z = \sum_{x} \exp\Big( \sum_{i=1}^{N^2} J_i x_i + \sum_{(i,j) \in E} J_{i,j} x_i x_j \Big). \qquad (27) $$

We performed experiments using datasets of randomly generated Ising models. Each dataset was created by first choosing N , c max , and f max . We sampled N × N Ising models according to the following process

c ∼ Unif[0, c max ), f ∼ Unif[0, f max ), (J i ) i∈V i.i.d. ∼ Unif[−f, f ), (J i,j ) (i,j)∈E i.i.d. ∼ Unif[0, c).
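A sketch of this sampling procedure for a single grid Ising model is shown below; the parameter names mirror the text (f_max, c_max), while the data layout is an illustrative choice rather than the authors' data-generation code.

```python
# Sketch of sampling one N x N attractive Ising model as described above:
# draw c and f, then draw local fields in [-f, f) and couplings in [0, c).
import numpy as np

def sample_ising(N: int, f_max: float, c_max: float, rng: np.random.Generator):
    c = rng.uniform(0.0, c_max)
    f = rng.uniform(0.0, f_max)
    J_i = rng.uniform(-f, f, size=(N, N))              # local field parameters J_i
    J_right = rng.uniform(0.0, c, size=(N, N - 1))     # couplings to right grid neighbors
    J_down = rng.uniform(0.0, c, size=(N - 1, N))      # couplings to lower grid neighbors
    return J_i, J_right, J_down

rng = np.random.default_rng(0)
J_i, J_right, J_down = sample_ising(N=10, f_max=0.1, c_max=5.0, rng=rng)
print(J_i.shape, J_right.shape, J_down.shape)
```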

Baselines. We trained a 10 layer GNN (GIN architecture) with width 4 on the same dataset of attractive Ising models that we used for our BPNN. We set edge features to the coupling potentials; that is, e_{u,v} = J_{u,v}. We set the initial node representations to the local field potentials of each node, h^{(0)}_v = J_v.

We used the same training loss and optimizer as for our BPNN. We used an initial learning rate of 0.001 and trained for 5k epochs, decaying the learning rate by .5 every 2k epochs.

We consider two additional baselines: Bethe approximation from running standard loopy belief propagation and mean field approximation. We used the libDAI [35] implementation for both. We test loopy belief propagation with and without damping and with both parallel and sequential message update strategies. We show results for two settings whose estimates of the partition function differ most drastically: (1) run for a maximum of 10 iterations with parallel updates and damping set to .5, and (2) run for a maximum of 1000 iterations with sequential updates using a random sequence and no damping.

Improved Lower Bounds and Faster Convergence. We trained a BPNN-D to estimate the partition function on a training set of 50 random Ising models. We randomly sampled the number of iterations of BPNN-D to apply during training between 5 and 30. When BPNN-D is then run to convergence on a validation set of random Ising models, we find that (1) it finds fixed points that provide tighter lower bounds on the partition function as explained in the main text and (2) it converges faster than BP as shown in Figure 3 .

Figure 3: The maximum difference in factor to variable message values between iterations is plotted against the message passing iteration for BPNN-D and BPNN run on 50 validation Ising models. BPNN-D converges to a maximum difference of 10−5 faster than BP, with a median speedup of 1.7x.

Out of Distribution Generalization. We tested BPNN's ability to generalize to larger factor graphs and to shifts in the test distribution. Again we used a training set of 50 Ising models. We sampled test data from distributions with c max and f max increased by factors of 2 and 10 from their training values, with N set to 14 (for 196 variables instead of the 100 seen during training). For this experiment we used a BPNN architecture with 10 iterative layers whose weights were not tied and with MLPs that operate on factor messages. For the out of distribution experiments, we did not use a final BPNN-B layer, we set the residual parameters to α 0 = α 1 = α 2 = .5, and trained on 50 attractive Ising models generated with N = 10, f max = .1, and c max = 5. We used mean squared error as our training loss. We used the Adam optimizer [29] with an initial learning rate of .0005 and trained for 100 epochs, with a decay of .5 after 50 epochs. Batching was over the entire training set (of size 50) with one optimization step per epoch.

E Sat Experiments

Additional Dataset Details We evaluated the performance of our BPNN using the suite of benchmarks from Soos and Meel [44]. Some of these benchmarks come with a sampling set. The sampling set redefines the model counting problem, asking how many configurations of variables in the sampling set correspond to at least one complete variable configuration that satisfies the formula. (A formula with n variables may have at most 2^n satisfying solutions, but a sampling set over i variables will restrict the number of solutions to at most 2^i.) We stripped all problems of sampling sets since they are outside the scope of this work. We also stripped all problems of minimal independent support variable sets and recomputed these when possible (we discuss this further later in this section). We ran the exact model counter DSharp [37] on all benchmarks with a timeout of 5k seconds to obtain ground truth model counts for 928 of the 1,896 benchmarks. Only 50 of these problems had more than 5 variables in the largest factor, so we discarded these problems and set the BPNN architecture to run on factors over 5 variables. We categorized the remaining 878 by their arcane names into groupings. With some sleuthing we determined that categories 'or_50', 'or_60', 'or_70', and 'or_100' contain network DQMR problems with 150, 121, 111, and 138 benchmarks per category respectively. Categories '75' and '90' contain network grid problems with 20 and 107 benchmarks per category respectively. Category 'blasted' contains bit-blasted versions of SMTLIB (satisfiability modulo theories library) benchmarks [12] and has 147 benchmarks. Category 's' contains representations of circuits with a subset of outputs randomly xor-ed and has 68 benchmarks. We discarded 4 categories that contained fewer than 10 benchmarks. For each category that contained more than 10 benchmarks, we split 70% into the training set and left the remaining benchmarks in the test set. We then performed two splits of the training set for training and validation; for each category we (1) trained on a random sampling of 70% of the training problems and performed validation on the remaining 30% and (2) trained on 70% of the training problems that DSharp solved fastest and performed validation on the remaining 30% that took longest for DSharp to solve. These hard validation sets are significantly more challenging for DSharp. The median runtime in each category's hard validation set is 4 to 15 times longer than the longest runtime in each corresponding easy training set.

Minimal Independent Support As a pre-processing step for ApproxMC3 and F2, we attempted to find a set of variables that define a minimal independent support (MIS) [25] for each benchmark using the authors' code with a timeout of 1k seconds. Randomized hashing algorithms can run significantly faster when given a set of variables that define a MIS. When we could find a set of variables that define a MIS, we recorded the time that each randomized hashing algorithm required without the MIS and the sum of the time to find the MIS and perform randomized hashing with the MIS. We report the minimum of these two times.

Baseline Approximate Model Counters. For comparison, we ran the state of the art approximate model counter ApproxMC3 [12, 44] on all benchmarks. ApproxMC3 is a randomized hashing algorithm that returns an estimate of the model count that is guaranteed to be within a multiplicative factor of the exact model count with high probability. Improving the guarantee, either by tightening the multiplicative factor or increasing the confidence, will increase the algorithm's runtime. We ran ApproxMC3 with the default parameters; confidence set to 0.81 and epsilon set to 16.

We also compare with the state of the art randomized hashing algorithm F2 [4, 5], run with CryptoMiniSat5 [45, 44]. This algorithm gives up the probabilistic guarantee that the returned estimate will be within a multiplicative factor of the true model count in return for significantly increased computational efficiency. We computed only a lower bound and ran F2 with variables appearing in only 3 clauses. This significantly speeds up the reported results [4, p.14], at some additional cost to accuracy. For example, on the problem 'blasted_case37', [4, p.14] report an estimate of log_2(#models) ≈ 151.02 and a runtime of 4149.9 seconds. Running F2 with variables appearing in only 3 clauses, we computed a lower bound on log_2(#models) of 148 in 2 seconds.

We also attempted to train a GNN, using the architecture from [43] to perform regression instead of classification. We used the authors' code, making slight modifications to perform regression. However, we were not successful in achieving non-trivial learning.

BPNN Training Protocol. We trained our BPNN to predict the natural logarithm of the number of satisfying solutions to an input Boolean formula. We consider the general case of an input formula over n Boolean variables, {X_1, X_2, . . . , X_n}, in conjunctive normal form (CNF). Formulas in CNF are a conjunction of clauses, where each clause is a disjunction of literals. A literal is either a variable or its negation. We converted Boolean formulas into factor graphs where each clause corresponds to a factor. Factors take the value of 1 for variable configurations that satisfy the clause and 0 for variable configurations that do not satisfy the clause. The partition function of this factor graph is equal to the number of satisfying solutions. We trained a BPNN architecture composed of 5 BPNN-D layers followed by a BPNN-B layer. We used the Adam optimizer [29] with learning rate decay.
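As an illustration of the conversion described above, the sketch below turns a single CNF clause into a 0/1 factor over its variables; the clause encoding (DIMACS-style signed integers) and the function name are assumptions for illustration.

```python
# Sketch of converting one CNF clause into a factor: the factor is 1 for
# variable assignments that satisfy the clause and 0 otherwise, so the factor
# graph's partition function equals the number of satisfying solutions.
import itertools
import numpy as np

def clause_to_factor(clause):
    """clause: list of non-zero ints; k stands for X_k and -k for NOT X_k."""
    variables = [abs(lit) for lit in clause]
    factor = np.zeros([2] * len(clause))
    for assignment in itertools.product([0, 1], repeat=len(clause)):
        satisfied = any((lit > 0) == bool(val) for lit, val in zip(clause, assignment))
        factor[assignment] = 1.0 if satisfied else 0.0
    return variables, factor

# Example: the clause (X1 OR NOT X2) is violated only by X1 = 0, X2 = 1.
variables, factor = clause_to_factor([1, -2])
print(variables)
print(factor)
```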

Ablation Study. The columns labeled BPNN-NI and BPNN-DC in Table 2 correspond to ablated versions of our BPNN model. We trained a BPNN with a BPNN-B layer that was not permutation invariant to local variable indexing, by removing the sum over permutations in S_{|x_a|} from Equation 7 and only passing in the original beliefs. We refer to this non-invariant version as BPNN-NI. We then forced BPNN-NI to 'double count' messages by changing the sums in Equations 5 and 6 to be over j ∈ N(a). We refer to this non-invariant version that performs double counting as BPNN-DC. We observe a validation improvement of BPNN over these ablated versions when generalization is particularly challenging, e.g. on 'blasted' problems individually and on all categories.

Table 3 shows the root mean squared error (RMSE) of estimates from the approximate model counters ApproxMC3 and F2 across all training benchmarks in each category. Error was computed as the difference between the natural logarithm of the number of satisfying solutions and the estimate. The fraction of benchmarks that each approximate counter was able to complete within the time limit of 5k seconds is also shown.

Table 2: RMSE of BPNN for each training/validation set, along with ablation results. BPNN corresponds to a model with 5 BPNN-D layers followed by a Bethe layer that is invariant to the factor graph representation. BPNN-NI corresponds to removing invariance from the Bethe layer. BPNN-DC corresponds to performing 'double counting' as is standard for GNNs, rather than subtracting previously sent messages as is standard for BP. 'Random Split' rows show that BPNNs are capable of learning a distribution from a tiny dataset of only 10s of training problems. 'Easy / Hard' rows additionally show that BPNNs are able to generalize from simple training problems to significantly more complex validation problems.
Table 3: Root mean squared error (RMSE) of estimates of the natural logarithm of the number of satisfying solutions is shown. The fraction of benchmarks within each category that each approximate counter was able to complete within the time limit of 5k seconds is shown in parentheses.

Additional Baseline Approximate Model Counter Information

For each benchmark category we show runtime percentiles for ApproxMC3, F2, and the exact model counter DSharp in Table 4. The DSharp runtime at the 70th percentile is the threshold dividing each category's easy training set from its hard validation set; the 85th percentile corresponds to the median runtime of each hard validation set, since the hard set contains the slowest 30% of benchmarks. The median runtime in each category's hard validation set is 4 to 15 times longer than the longest runtime in the corresponding easy training set. We observe that F2 is generally tens or hundreds of times faster than ApproxMC3. On these benchmarks DSharp is generally faster than F2; however, there exist problems that can be solved much faster by randomized hashing (ApproxMC3 or F2) than by DSharp [5, 44].
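A small sketch of how such a percentile-based easy/hard split can be computed from per-benchmark runtimes; the runtime values below are made up for illustration and the code is not the pipeline used in the paper.

```python
import numpy as np

# Hypothetical DSharp runtimes (seconds) for one benchmark category.
runtimes = np.array([1.2, 3.5, 7.8, 12.0, 40.0, 95.0, 300.0, 800.0, 2100.0, 4900.0])

split = np.percentile(runtimes, 70)   # threshold between easy and hard
easy = runtimes[runtimes <= split]    # easy training set
hard = runtimes[runtimes > split]     # hard validation set

# The median of the hard set sits at roughly the 85th percentile overall.
print(split, np.median(hard), np.percentile(runtimes, 85))
```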

Table 4: Runtime percentiles (in seconds) are shown for DSharp (0th/70th/85th/100th percentiles) and for ApproxMC3 and F2 (0th/70th/100th percentiles). Percentiles are computed separately for each category's training dataset. In comparison, BPNN sequential runtime is nearly constant and BPNN parallel runtime is limited by GPU memory.

F Stochastic Block Model Experiments

Stochastic Block Model Definition. A $C$-class Stochastic Block Model (SBM) is a randomly generated graph with $N$ vertices, class assignment probabilities $p_i$ for $i \in \{1, \ldots, C\}$ with $\sum_{i=1}^{C} p_i = 1$, and edge probabilities $e_{ij}$ for $i, j \in \{1, \ldots, C\}$. To generate the graph, we first sample a class $c_m$ for each node $m \in \{1, \ldots, N\}$ according to the class assignment probabilities. We then sample the edge set $E$ as follows: for every pair of nodes $x_m, x_n$ with $m, n \in \{1, \ldots, N\}$, we assign an edge between them with probability $e_{c_m, c_n}$.
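For concreteness, the following sketch samples a graph from a two-class SBM using the generative process just described; all parameter values are illustrative placeholders, not the settings used in our experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

N = 15                                  # number of nodes (illustrative)
p = np.array([0.5, 0.5])                # class assignment probabilities, sums to 1
e = np.array([[0.9, 0.1],               # e[i, j]: edge probability between
              [0.1, 0.9]])              # a class-i node and a class-j node

# Sample a class for each node, then sample each potential edge independently.
classes = rng.choice(len(p), size=N, p=p)
edges = [
    (m, n)
    for m in range(N) for n in range(m + 1, N)
    if rng.random() < e[classes[m], classes[n]]
]
print(classes, edges[:5])
```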

SBM Factor Graph Construction. For a given SBM with $N$ nodes, $C$ classes, class assignment probabilities $p_i$ for $i \in \{1, \ldots, C\}$, sampled class assignments $c_m$ for $m \in \{1, \ldots, N\}$, edge probabilities $e_{ij}$ for $i, j \in \{1, \ldots, C\}$, and sampled edge set $E$, we define the following unary factor potentials for every node $x_m$, $m \in \{1, \ldots, N\}$:

$$f_i(x_m) = p_i, \quad i \in \{1, \ldots, C\}.$$

We construct binary factor potentials $f_{ij}(x_m, x_n)$ for $i, j \in \{1, \ldots, C\}$ between nodes $x_m, x_n$ with $(m, n) \in E$ as

$$f_{ij}(x_m, x_n) = e_{ij},$$

and between nodes $x_m, x_n$ with $(m, n) \notin E$ as

$$f_{ij}(x_m, x_n) = 1 - e_{ij}.$$

Note that when we fix a variable to a specific value, we simply set all factor potentials involving that variable that disagree with that value to zero.
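The construction above can be written compactly in code. The sketch below continues from the SBM sampling sketch (reusing `p`, `e`, `N`, and `edges`), builds the unary and pairwise potential tables, and shows how clamping a variable to a class zeroes out the disagreeing entries; all names are our own and this is not the paper's implementation.

```python
import numpy as np

# Continuing from the SBM sampling sketch above.
C = len(p)
edge_set = set(edges)

# Unary potentials: f_i(x_m) = p_i for every node m.
unary = {m: p.copy() for m in range(N)}

# Pairwise potentials: e[i, j] if (m, n) is an edge, 1 - e[i, j] otherwise.
pairwise = {}
for m in range(N):
    for n in range(m + 1, N):
        pairwise[(m, n)] = e if (m, n) in edge_set else 1.0 - e

def clamp(node, value):
    """Fix `node` to class `value` by zeroing all disagreeing potential entries."""
    mask = np.zeros(C)
    mask[value] = 1.0
    unary[node] = unary[node] * mask
    for (m, n), table in pairwise.items():
        if m == node:
            pairwise[(m, n)] = table * mask[:, None]  # rows index x_m's class
        elif n == node:
            pairwise[(m, n)] = table * mask[None, :]  # columns index x_n's class

clamp(0, 0)  # e.g. fix node 0 to class 0 before computing ln(Z_1)
```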

Marginal Calculation from Log Partitions. Training a model to estimate partition functions with fixed variables is advantageous because the same computation can be used directly to obtain marginals, i.e., the probabilities that a node belongs to a specific class. These marginals can be used to perform community detection or to quantify uncertainty and rare events in community membership. To see how marginals are computed in our experimental setup, take a two-class SBM and select a node $x_m$. Fix its value to class 0 to obtain the log partition function $\ln Z_1$, then fix its value to class 1 to obtain $\ln Z_2$. The log marginals are then

$$\ln p(x_m = 0) = \ln Z_1 - \ln(Z_1 + Z_2), \qquad \ln p(x_m = 1) = \ln Z_2 - \ln(Z_1 + Z_2),$$

where $\ln(Z_1 + Z_2)$ can be computed in a numerically stable fashion from $\ln Z_1$ and $\ln Z_2$ using the logsumexp trick.
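A minimal sketch of this logsumexp-based marginal computation (the function name is our own):

```python
import numpy as np

def log_marginals(ln_z1, ln_z2):
    """Log marginals of a binary variable from the two clamped log partition
    functions, computed stably with the logsumexp trick."""
    ln_z_total = np.logaddexp(ln_z1, ln_z2)  # ln(Z_1 + Z_2) without overflow
    return ln_z1 - ln_z_total, ln_z2 - ln_z_total

# Example: ln(Z_1) = 700, ln(Z_2) = 650; naive exponentiation would overflow,
# but the log marginals are still well defined in log space.
print(log_marginals(700.0, 650.0))
```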

Model and Training Details. For baselines, we ran Belief Propagation to convergence with parallel updates and a damping coefficient of 0.5, as well as a Graph Isomorphism Network (GIN) with 30 layers and width 8. GIN is maximally discriminative among GNNs that consider 1-hop neighbors, which makes it computationally comparable to BPNN. In our evaluations, GIN performs comparably to more computationally expensive two-hop GNNs on the related problem of SBM community detection [14]. We trained our GIN architecture on the 5-class graph coloring community detection setting described in [14] and compared it to the performance of the two-hop GNNs described there. Our GNN had 20 layers with a width of 8 and achieved a permutation invariant validation overlap score of 0.166 when trained for the same number of iterations, nearly identical to the two-hop GNN performance reported in [14]. Since one-hop GNNs train significantly faster than two-hop GNNs, we managed to obtain overlap scores as high as 0.185 when training for longer. In any case, our one-hop GNN performs comparably with two-hop GNN architectures on the related task of SBM community detection and thus, in addition to its convenience, makes for a strong baseline. For all models, we trained for 300 epochs on 1 GPU with the Adam optimizer (learning rate 2e-4, batch size 8), minimizing the mean squared error between the estimated and true log partition functions.

Table 5: RMSE of ln(Z) of BPNN against BP and GNN for SBMs generated from different distributions and larger graphs than the training or validation set. We see that BPNN outperforms both methods across different edge probabilities, class probabilities, and on larger graphs. Furthermore, it generalizes better than GNN in all these settings.
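A skeletal version of the training loop described before Table 5, using the stated hyperparameters (Adam, learning rate 2e-4, batch size 8, 300 epochs, MSE loss); the model class and data loader are placeholders, not the paper's implementation.

```python
import torch
from torch import nn, optim

def train(model, loader, epochs=300, lr=2e-4, device="cuda"):
    """`model` stands in for the BPNN or GNN being trained; `loader` yields
    (factor_graph, ln_Z_true) pairs in batches of 8. Both are placeholders."""
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        for factor_graph, ln_z_true in loader:
            optimizer.zero_grad()
            ln_z_pred = model(factor_graph.to(device))
            loss = loss_fn(ln_z_pred, ln_z_true.to(device))
            loss.backward()
            optimizer.step()
    return model
```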

Out Of Distribution Generalization

We test the capacity of our BPNN (with no double counting and an invariant BPNN-B layer) to generalize to out-of-distribution graphs, comparing it against the GNN model and the BP benchmark. Since the factor graphs are fully connected, slight changes to the initial parameters can produce rather large differences in the graphs and their log partition functions. In addition to perturbing the initial class probabilities and edge probabilities, we also test the ability of BPNN to generalize to larger graphs. This is a desirable property because the Junction Tree algorithm for exact inference becomes exponentially more expensive as the graph grows, due to the fully connected nature of SBM factor graphs. For each scenario, we generate five separate graphs and construct test examples as described previously. We present our results in Table 5. We observe that BPNN performs best of the three methods when class and edge probabilities are changed, and it generalizes better than GNN in these settings as well. Furthermore, when graph size is increased, BPNN can outperform BP and GNN on graphs with as many as 20 nodes (a setting with over 80% more edges than training) and generalizes significantly better than GNN.

Since our SBM factor graphs are fully connected, increasing the number of nodes by a factor of $n$ leads to an $O(n^2)$ increase in the number of edges, which may make it harder for the model to generalize to larger and larger graphs. Using an auxiliary field approximation for SBM message passing, as described in [15], could help generalization to larger graphs, since in that case the number of edges would grow only linearly with graph size; this is something to investigate further.
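To make the edge-count growth concrete, a fully connected factor graph over $N$ nodes has $N(N-1)/2$ pairwise factors. Assuming 15-node training graphs (the size used in the marginal experiments below), the quick check here reproduces the "over 80% more edges" figure quoted for 20-node graphs.

```python
def num_pairwise_factors(n):
    # Fully connected SBM factor graph: one pairwise factor per node pair.
    return n * (n - 1) // 2

train_edges = num_pairwise_factors(15)   # 105
test_edges = num_pairwise_factors(20)    # 190
print(test_edges / train_edges - 1)      # ~0.81, i.e. over 80% more edges
```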

Marginal Estimation

We also compared BPNN to BP for marginal estimation, using the estimated log partition functions with single nodes fixed to each value to compute marginals for those nodes, as described above. Under the graph parameters used in these experiments, the marginals are usually extremely close to 1 and 0; however, in such dense graphs, changes in the magnitude of these marginals can have large effects on the log partition function. In some cases BP computes the correct marginals under these conditions, but in others it is off by 20-30 orders of magnitude on the smaller marginal. Such errors do not affect community recovery, but when we care about very rare outcomes they can have a large effect on quantifying uncertainty in community membership. On 15-node graphs, BPNN, by learning more accurate log partition functions, is on average almost 5 orders of magnitude closer to the true marginals than BP, but only about an order of magnitude more accurate than GNN. On marginals, BPNN's overall performance and generalization ability relative to GNN is not as strong as for estimating partition functions, likely because it is not specifically trained to estimate marginals; estimating partition functions and marginals, while closely related, are still different tasks. Training explicitly to estimate marginals, e.g., by predicting the difference in log partition functions between graphs with one variable fixed to either value, may improve BPNN's performance and generalization on marginals; this is an area for further investigation.


For lack of space, all proofs are deferred to Appendix A.

Any message initialization can be used, as long as initial messages are equivariant; see Lemma 1.

An early version of our paper, concurrent with [42], was submitted to UAI 2020: https://github.com/jkuck/jkuck.github.io/blob/master/files/BPNN_UAI_submission.pdf

Code for the solvers used in our experiments: DSharp (https://github.com/QuMuLab/dsharp), MIS (https://github.com/meelgroup/mis), ApproxMC3 (https://github.com/meelgroup/ApproxMC), F2 (https://github.com/ptheod/F2), and CryptoMiniSat5 (https://github.com/msoos/cryptominisat).