Sampling via Moment Sharing: A New Framework for Distributed Bayesian Inference for Big Data
(Oxford)
In collaboration with: Minjie Xu, Jun Zhu, Bo Zhang (Tsinghua); Balaji Lakshminarayanan (Gatsby)
Bayesian Inference
• Parameter vector $X$.
• Data items $Y = y_1, y_2, \ldots, y_N$.
• Model: $p(X, Y) = p(X) \prod_{i=1}^{N} p(y_i \mid X)$
• Aim: $p(X \mid Y) = \dfrac{p(X)\, p(Y \mid X)}{p(Y)}$
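To make the model and aim concrete, here is a minimal sketch assuming a scalar parameter $X$ with a Gaussian prior and Gaussian likelihood $y_i \sim \mathcal{N}(X, 1)$ (our choice of model, not from the slides). It evaluates $p(X)\prod_i p(y_i \mid X)$ on a grid, normalises it (the role played by $p(Y)$), and compares the result with the conjugate closed form. The functions `grid_posterior` and `conjugate_posterior` are hypothetical names.

```python
import numpy as np

def grid_posterior(y, prior_mean=0.0, prior_var=1.0, noise_var=1.0):
    """p(x | y) on a grid for y_i ~ N(x, noise_var) with prior x ~ N(prior_mean, prior_var)."""
    grid = np.linspace(-5.0, 5.0, 2001)
    # log p(x) + sum_i log p(y_i | x), dropping terms that do not depend on x
    log_joint = (-0.5 * (grid - prior_mean) ** 2 / prior_var
                 - 0.5 * ((y[:, None] - grid[None, :]) ** 2).sum(axis=0) / noise_var)
    unnorm = np.exp(log_joint - log_joint.max())   # subtract max to avoid underflow
    dx = grid[1] - grid[0]
    return grid, unnorm / (unnorm.sum() * dx)      # normalise: the role of p(Y)

def conjugate_posterior(y, prior_mean=0.0, prior_var=1.0, noise_var=1.0):
    """Closed-form posterior mean and variance for the same conjugate model."""
    prec = 1.0 / prior_var + len(y) / noise_var
    mean = (prior_mean / prior_var + y.sum() / noise_var) / prec
    return mean, 1.0 / prec

rng = np.random.default_rng(0)
y = rng.normal(1.5, 1.0, size=50)                  # synthetic data items y_1..y_N
grid, post = grid_posterior(y)
dx = grid[1] - grid[0]
grid_mean = (grid * post).sum() * dx
mean, var = conjugate_posterior(y)
print(grid_mean, mean)                             # should agree closely
```

Grid evaluation only works in very low dimension; the rest of the talk is about what to do when it, and single-machine MCMC, do not scale.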
Why Bayes for Machine Learning?
• An important framework for formulating learning problems.
• Quantification of uncertainty.
• Flexible and intuitive construction of complex models.
• Straightforward derivation of learning algorithms.
• Mitigation of overfitting.
Big Data and Bayesian Inference?
• Large scale datasets are fast becoming the norm.
• Analysing and extracting understanding from these data is a driver of progress in many sectors of society.
• Current successes in scalable learning are optimization-based and non-Bayesian.
• What is the role of Bayesian learning in the world of Big Data?
Generic (Bayesian) Learning on Big Data
• Stochastic optimisation using mini-batches.
  – Stochastic gradient descent.
  > Stochastic Gradient Langevin Dynamics [Welling & Teh 2011, Patterson & Teh 2013, Teh et al (forthcoming)]
• Distributed/parallel computations on cores/clusters/GPUs.
  – MapReduce, parameter server.
  – Bringing the computations to the data, not the reverse.
  – High communication costs.
  > Distributed Bayesian Posterior Sampling via Moment Sharing [Xu et al 2014]
  – High synchronisation costs.
  > Asynchronous Anytime Sequential Monte Carlo [Paige et al 2014]
Machine Learning on Distributed Systems
• Distributed storage.
• Distributed computation.
• Network communication costs.
(Figure: four worker machines holding data subsets $y_{1i}, y_{2i}, y_{3i}, y_{4i}$.)
Embarrassingly Parallel MCMC Sampling
• Treat as independent inference problems: each worker j samples using only its data subset $y_{ji}$.
• Collect samples $\{X_{ji}\}_{j=1 \ldots m,\ i=1 \ldots n}$.
• Combine samples together into $\{X_i\}_{i=1 \ldots n}$.
• Only communication at the combination stage.
Local and Global Posteriors
• Each worker machine j has access only to its data subset:
$$p_j(X \mid y_j) \propto p_j(X) \prod_{i=1}^{I} p(y_{ji} \mid X),$$
where $p_j(X)$ is a local prior and $p_j(X \mid y_j)$ is the local posterior.
• The (target) global posterior is
$$p(X \mid y) \propto p(X) \prod_{j=1}^{m} p(y_j \mid X) \propto p(X) \prod_{j=1}^{m} \frac{p_j(X \mid y_j)}{p_j(X)}.$$
• If the prior factorises as $p(X) = \prod_{j} p_j(X)$, then $p(X \mid y) \propto \prod_{j=1}^{m} p_j(X \mid y_j)$.
• Given collections of samples $\{X_{ji}\}_{i=1}^{n}$ from $p_j(\cdot \mid y_j)$, how do we get samples $\{X_i\}_{i=1}^{n}$ from $p(\cdot \mid y)$?
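The identity $p(X \mid y) \propto \prod_j p_j(X \mid y_j)$ when $p(X) = \prod_j p_j(X)$ can be checked in a toy conjugate model. This sketch rests on our own assumptions: scalar $X$, prior $\mathcal{N}(0, 2)$, Gaussian likelihood with unit noise, and data split across $m = 4$ hypothetical workers. Each worker gets the local prior $p_j(X) = p(X)^{1/m}$, which for a Gaussian prior is $\mathcal{N}(0, 2m)$ up to a constant, and the product of the Gaussian local posteriors is compared with the global posterior in natural parameters. `gaussian_posterior` is a hypothetical helper.

```python
import numpy as np

def gaussian_posterior(y, prior_mean, prior_var, noise_var):
    """Natural parameters (precision, precision * mean) of the posterior over x
    for y_i ~ N(x, noise_var) and prior x ~ N(prior_mean, prior_var)."""
    prec = 1.0 / prior_var + len(y) / noise_var
    shift = prior_mean / prior_var + y.sum() / noise_var
    return prec, shift

rng = np.random.default_rng(1)
m, noise_var, prior_var = 4, 1.0, 2.0
y = rng.normal(0.7, 1.0, size=400)
shards = np.array_split(y, m)                        # worker j holds data subset y_j

# Local prior p_j(x) = p(x)^(1/m): for a N(0, prior_var) prior this is N(0, m * prior_var)
local = [gaussian_posterior(y_j, 0.0, m * prior_var, noise_var) for y_j in shards]

# Product of Gaussian densities: natural parameters add
prod_prec = sum(prec for prec, _ in local)
prod_shift = sum(shift for _, shift in local)

glob_prec, glob_shift = gaussian_posterior(y, 0.0, prior_var, noise_var)
print(prod_shift / prod_prec, glob_shift / glob_prec)  # same posterior mean
```

In the Gaussian case the product is available in closed form; the difficulty addressed next is that real local posteriors are only available through samples.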
Consensus Monte Carlo
• Each worker machine j collects n samples $\{X_{ji}\}_{i=1}^{n}$ from
$$p_j(X \mid y_j) \propto p(X)^{1/m} \prod_{i=1}^{I} p(y_{ji} \mid X).$$
• The master machine combines the samples by a weighted average:
$$X_i = \Big(\sum_{j=1}^{m} W_j\Big)^{-1} \sum_{j=1}^{m} W_j X_{ji}$$
[Scott et al 2013]
Consensus Monte Carlo
• The combination is correct if the local posteriors are Gaussian.
• The weights are the local posterior precisions.
• If the local posteriors are not Gaussian, the method makes strong assumptions, and it is unclear which local priors and weights would make it work.
[Scott et al 2013]
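A minimal sketch of the Consensus Monte Carlo combination step, assuming the worker draws are already collected in an array of shape `(m, n, d)` and taking each weight $W_j$ to be the local posterior precision, estimated here as the inverse sample covariance of worker $j$'s draws. The toy local posteriors are Gaussian (our assumption), which is exactly the case where the combination is correct, so the combined draws should reproduce the moments of the product density. `consensus_combine` is a hypothetical name, not the API of any released package.

```python
import numpy as np

def consensus_combine(draws):
    """Consensus Monte Carlo combination step.

    draws: array of shape (m, n, d); draws[j, i] is sample X_ji from worker j.
    Returns an (n, d) array with X_i = (sum_j W_j)^{-1} sum_j W_j X_ji, where W_j is
    estimated as the inverse sample covariance (precision) of worker j's draws.
    """
    weights = [np.linalg.inv(np.cov(d_j, rowvar=False)) for d_j in draws]
    total = sum(weights)
    # row i of d_j @ W_j.T equals (W_j X_ji)^T; sum over workers
    weighted = sum(d_j @ w_j.T for d_j, w_j in zip(draws, weights))
    return np.linalg.solve(total, weighted.T).T

rng = np.random.default_rng(2)
m, n, d = 4, 20000, 2
means = rng.normal(size=(m, d))
covs = [np.diag(rng.uniform(0.5, 2.0, size=d)) for _ in range(m)]
draws = np.stack([rng.multivariate_normal(mu, c, size=n) for mu, c in zip(means, covs)])
combined = consensus_combine(draws)

# Exact moments of the product of the Gaussian local posteriors, for comparison
precs = [np.linalg.inv(c) for c in covs]
exact_cov = np.linalg.inv(sum(precs))
exact_mean = exact_cov @ sum(p @ mu for p, mu in zip(precs, means))
print(combined.mean(axis=0), exact_mean)
```

With skewed or multimodal local posteriors the same averaging still runs, but nothing guarantees that its output follows $p(X \mid y)$.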
Approximating Local Posterior Densities
• [Neiswanger et al 2013] proposed combining estimates of the local posterior densities instead of the samples themselves:
  – Parametric: Gaussian approximation.
  – Nonparametric: kernel density estimation (KDE) based on samples.
  – Semiparametric: product of a parametric Gaussian approximation with a nonparametric KDE correction term.
• Combination: product of (approximate) densities, e.g. for the nonparametric estimate
$$p(X \mid y) \propto \prod_{j=1}^{m} p_j(X \mid y_j) \approx \prod_{j=1}^{m} \frac{1}{n} \sum_{i=1}^{n} K_{h_j}(X; X_{ji}).$$
• Sampling: resort to Metropolis-within-Gibbs.
• The Weierstrass sampler of [Wang & Dunson 2013] is similar, but uses rejection sampling instead.
[Neiswanger et al 2013, Wang & Dunson 2013]
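For intuition about the nonparametric combination, the sketch below multiplies per-worker Gaussian-kernel density estimates in one dimension. Assumptions (ours, not from the slides): three workers whose draws come from Gaussian local posteriors, Silverman's rule-of-thumb bandwidth for each $h_j$, and evaluation of the product on a grid in place of the Metropolis-within-Gibbs sampler used by [Neiswanger et al 2013]. A grid is only feasible because $d = 1$. `kde` is a hypothetical helper.

```python
import numpy as np

def kde(x, samples, h):
    """Gaussian-kernel estimate (1/n) sum_i K_h(x; X_ji) evaluated at the points x."""
    z = (x[:, None] - samples[None, :]) / h
    return np.exp(-0.5 * z ** 2).sum(axis=1) / (len(samples) * h * np.sqrt(2 * np.pi))

rng = np.random.default_rng(3)
local_means = np.array([-0.3, 0.1, 0.4])
local_sds = np.array([0.8, 1.0, 0.9])
worker_draws = [rng.normal(mu, sd, size=2000) for mu, sd in zip(local_means, local_sds)]

grid = np.linspace(-4.0, 4.0, 1601)
dx = grid[1] - grid[0]
log_prod = np.zeros_like(grid)
for draws_j in worker_draws:
    h_j = 1.06 * draws_j.std() * len(draws_j) ** (-1 / 5)   # Silverman's rule (our choice)
    log_prod += np.log(kde(grid, draws_j, h_j) + 1e-300)     # product of densities = sum of logs
dens = np.exp(log_prod - log_prod.max())
dens /= dens.sum() * dx                                      # normalise on the grid

# Exact product of the (Gaussian) local posteriors, for comparison
prec = (1 / local_sds ** 2).sum()
exact_mean = (local_means / local_sds ** 2).sum() / prec
kde_mean = (grid * dens).sum() * dx
print(kde_mean, exact_mean)
```

Each factor is a mixture of n kernels, so the product is a mixture with $n^m$ components; this is why sampling from it needs Metropolis-within-Gibbs rather than direct evaluation once $d$ grows.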
Approximating Local Posterior Densities
• The parametric approximation can be quite bad unless each worker has enough data for the Bernstein-von Mises theorem to make its local posterior approximately Gaussian.
• The combination step is complex and expensive for the non- and semi-parametric estimates.
• KDE suffers from the curse of dimensionality.
• Performs poorly if the local posteriors differ significantly.
Intuition and Desiderata
• Distributed system with independent MCMC sampling.
• Identify regions of high (global) posterior probability mass.
• Each local sampler is based on local data, but concentrates on the high probability regions.
• High probability regions are found using samples, by allowing a small amount of communication.
(Not Quite) Embarrassingly Parallel MCMC
• Allow some amount of communication to align the worker MCMC samplers.
• High probability region defined by low order moments.
• Align using Expectation Propagation (EP).
• Asynchronous and infrequent updates.
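Expectation Propagation
• If N is large, the worker j likelihood term $p(y_j \mid X)$ should be well approximated by a Gaussian:
$$p(y_j \mid X) \approx q_j(X) = \mathcal{N}(X; \mu_j, \Sigma_j)$$
• Parameters are fit iteratively using a variational approach to minimize a KL divergence:
$$p(X \mid y) \approx p_j(X \mid y) \propto p(y_j \mid X)\, \underbrace{p(X) \prod_{k \neq j} q_k(X)}_{p_{-j}(X)}$$
$$q_j^{\text{new}}(\cdot) = \operatorname*{arg\,min}_{\mathcal{N}(\cdot;\, \mu, \Sigma)} \mathrm{KL}\Big(p_j(\cdot \mid y) \,\Big\|\, \mathcal{N}(\cdot;\, \mu, \Sigma)\, p_{-j}(\cdot)\Big)$$
[Minka 2001]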
Expectation Propagation
• The update of $q_j$ is performed as follows:
  – Compute (or estimate) the first two moments $\mu^*, \Sigma^*$ of $p_j(X \mid y)$.
  – Compute $\mu_j, \Sigma_j$ so that $\mathcal{N}(X; \mu_j, \Sigma_j)\, p_{-j}(X) / Z$ has moments $\mu^*, \Sigma^*$.
• Computations are done on natural parameters.
• Generalizes to other exponential families.
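The two update steps can be written in natural parameters. This sketch assumes the tilted moments $\mu^*, \Sigma^*$ are already available (exactly or as estimates) and that the cavity $p_{-j}$ is Gaussian with precision $\Sigma_{-j}^{-1}$ and shift $\Sigma_{-j}^{-1}\mu_{-j}$; `ep_site_update` is a hypothetical helper. Matching moments means $q_j\, p_{-j}$ must have the natural parameters of $\mathcal{N}(\mu^*, \Sigma^*)$, and since natural parameters add under multiplication, the site's natural parameters are the difference: $\Sigma_j^{-1} = (\Sigma^*)^{-1} - \Sigma_{-j}^{-1}$ and $\Sigma_j^{-1}\mu_j = (\Sigma^*)^{-1}\mu^* - \Sigma_{-j}^{-1}\mu_{-j}$.

```python
import numpy as np

def ep_site_update(tilted_mean, tilted_cov, cavity_prec, cavity_shift):
    """Moment-matching EP update for site q_j, in natural parameters.

    tilted_mean, tilted_cov: moments (mu*, Sigma*) of p_j(X | y), exact or estimated.
    cavity_prec, cavity_shift: Sigma_{-j}^{-1} and Sigma_{-j}^{-1} mu_{-j} of p_{-j}(X).
    Returns (Sigma_j^{-1}, Sigma_j^{-1} mu_j) such that N(X; mu_j, Sigma_j) p_{-j}(X) / Z
    has moments (mu*, Sigma*). Nothing here forces Sigma_j^{-1} to be positive definite.
    """
    tilted_prec = np.linalg.inv(tilted_cov)
    site_prec = tilted_prec - cavity_prec
    site_shift = tilted_prec @ tilted_mean - cavity_shift
    return site_prec, site_shift

# Check: for a Gaussian likelihood term the tilted distribution is itself Gaussian,
# so a single update recovers that term's natural parameters exactly.
lik_prec = np.array([[2.0, 0.3], [0.3, 1.0]])
lik_shift = np.array([0.5, -1.0])
cav_prec, cav_shift = np.eye(2), np.array([0.2, 0.1])
tilted_cov = np.linalg.inv(lik_prec + cav_prec)
tilted_mean = tilted_cov @ (lik_shift + cav_shift)
site_prec, site_shift = ep_site_update(tilted_mean, tilted_cov, cav_prec, cav_shift)
print(site_prec, site_shift)
```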
Expectation Propagation
• Variational parameters are fit iteratively until convergence.
• EP tends to converge very quickly (when it does).
• Damping the updates can help convergence.
• At convergence, all local posteriors agree on their first two moments.
• Generalizes to hierarchical and graphical models [infer.net, Gelman et al 2014].
(Figure: prior $p(X)$ shared by four workers; each likelihood term $p(y_j \mid X)$ is approximated by a Gaussian site $q_j(X)$, $j = 1, \ldots, 4$.)
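To see the iteration and damping in action, here is a self-contained EP loop for a toy problem of our own choosing: one-dimensional Bayesian logistic regression with a $\mathcal{N}(0, 1)$ prior and data split across $m = 4$ hypothetical workers. Because $X$ is scalar, the tilted moments are computed by quadrature on a grid, and `damping` (an assumed value of 0.5) mixes old and new site parameters. At the end, every tilted distribution has the same mean and variance as the global EP approximation, which is the agreement property stated on the slide.

```python
import numpy as np

def log_sigmoid(z):
    # log(1 / (1 + exp(-z))), computed stably
    return -np.logaddexp(0.0, -z)

rng = np.random.default_rng(4)
m, n_per, true_x = 4, 50, 1.2
a = rng.normal(size=(m, n_per))                       # covariates on each worker
b = (rng.uniform(size=(m, n_per)) < 1 / (1 + np.exp(-true_x * a))).astype(float)
s = 2 * b - 1                                         # labels in {-1, +1}

grid = np.linspace(-6, 6, 4001)
# log p(y_j | x) on the grid for each worker j: sum_i log sigmoid(s_ji * a_ji * x)
log_lik = np.stack([log_sigmoid(np.outer(grid, a[j] * s[j])).sum(axis=1) for j in range(m)])

def moments(log_density):
    """Mean and variance of a 1-D density given by its log on the grid (up to a constant)."""
    w = np.exp(log_density - log_density.max())
    w /= w.sum()
    mean = (w * grid).sum()
    return mean, (w * (grid - mean) ** 2).sum()

prior_prec, prior_shift = 1.0, 0.0                    # N(0, 1) prior in natural parameters
site_prec, site_shift = np.zeros(m), np.zeros(m)
damping = 0.5
for sweep in range(50):
    for j in range(m):
        # cavity p_{-j}(x) = p(x) * prod_{k != j} q_k(x)
        cav_prec = prior_prec + site_prec.sum() - site_prec[j]
        cav_shift = prior_shift + site_shift.sum() - site_shift[j]
        # tilted p_j(x | y) is proportional to p(y_j | x) p_{-j}(x)
        mu, var = moments(log_lik[j] - 0.5 * cav_prec * grid ** 2 + cav_shift * grid)
        # moment matching, then a damped update of site j
        new_prec, new_shift = 1 / var - cav_prec, mu / var - cav_shift
        site_prec[j] = (1 - damping) * site_prec[j] + damping * new_prec
        site_shift[j] = (1 - damping) * site_shift[j] + damping * new_shift

ep_prec = prior_prec + site_prec.sum()
ep_mean = (prior_shift + site_shift.sum()) / ep_prec
tilted = [moments(log_lik[j] - 0.5 * (ep_prec - site_prec[j]) * grid ** 2
                  + (ep_prec * ep_mean - site_shift[j]) * grid) for j in range(m)]
exact_mean, exact_var = moments(log_lik.sum(axis=0) - 0.5 * prior_prec * grid ** 2
                                + prior_shift * grid)
print(ep_mean, exact_mean, tilted)
```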
Sampling via Moment Sharing (SMS)
• KL minimized by matching the moments of $p_j(X \mid y)$.
• Moments computed by drawing MCMC samples.
• All samples from all machines can be treated as approximate samples from the full posterior given all data.
• Communicate only moments, synchronously or asynchronously.
(Figure: at worker j, samples $\{X_{ji}\}$ → moments $(\mu, \Sigma)$ → site parameters $(\mu_j, \Sigma_j)$ of $q_j(X)$.)
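A minimal single-process sketch of the SMS idea, assuming a toy one-dimensional Bayesian logistic regression with a $\mathcal{N}(0, 1)$ prior split over $m = 4$ hypothetical workers (our choice, not the experimental setup of the talk). Each worker runs random-walk Metropolis, standing in for a base sampler such as NUTS, on its tilted distribution, estimates the mean and variance from its draws, and turns them into a damped site update. All workers build their cavities from the same previous-round sites, mimicking a synchronous scheme; an asynchronous variant would let each worker refresh its cavity whenever new moments arrive. Step size, burn-in, round count and `damping` are illustrative settings.

```python
import numpy as np

def log_sigmoid(z):
    # log(1 / (1 + exp(-z))), computed stably
    return -np.logaddexp(0.0, -z)

def metropolis(log_target, x0, n_samples, step, rng):
    """Random-walk Metropolis; a stand-in for a base sampler such as NUTS."""
    x, lp = x0, log_target(x0)
    out = np.empty(n_samples)
    for t in range(n_samples):
        prop = x + step * rng.normal()
        lp_prop = log_target(prop)
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = prop, lp_prop
        out[t] = x
    return out

rng = np.random.default_rng(5)
m, n_per, true_x = 4, 50, 1.2
a = rng.normal(size=(m, n_per))
b = (rng.uniform(size=(m, n_per)) < 1 / (1 + np.exp(-true_x * a))).astype(float)
s = 2 * b - 1                                   # labels in {-1, +1}

prior_prec, prior_shift = 1.0, 0.0              # N(0, 1) prior in natural parameters
site_prec, site_shift = np.zeros(m), np.zeros(m)
state = np.zeros(m)                             # current MCMC state on each worker
damping, T = 0.5, 2000
for k in range(10):                             # EP rounds
    # every worker builds its cavity from the same previous-round site parameters
    cav_prec = prior_prec + site_prec.sum() - site_prec
    cav_shift = prior_shift + site_shift.sum() - site_shift
    draws = []
    for j in range(m):
        def log_tilted(x):
            # log p(y_j | x) + log p_{-j}(x), up to a constant
            return (log_sigmoid(x * a[j] * s[j]).sum()
                    - 0.5 * cav_prec[j] * x ** 2 + cav_shift[j] * x)
        xs = metropolis(log_tilted, state[j], T, 0.3, rng)
        state[j] = xs[-1]
        kept = xs[T // 2:]                      # discard first half as burn-in
        draws.append(kept)
        mu, var = kept.mean(), kept.var()       # only these moments need to leave worker j
        new_prec, new_shift = 1 / var - cav_prec[j], mu / var - cav_shift[j]
        site_prec[j] = (1 - damping) * site_prec[j] + damping * new_prec
        site_shift[j] = (1 - damping) * site_shift[j] + damping * new_shift

ep_prec = prior_prec + site_prec.sum()
ep_mean = (prior_shift + site_shift.sum()) / ep_prec
pooled = np.concatenate(draws)                  # last-round draws from all workers

# Reference: exact posterior moments by quadrature on a grid
grid = np.linspace(-6, 6, 4001)
log_post = (-0.5 * prior_prec * grid ** 2 + prior_shift * grid
            + sum(log_sigmoid(np.outer(grid, a[j] * s[j])).sum(axis=1) for j in range(m)))
w = np.exp(log_post - log_post.max())
w /= w.sum()
exact_mean = (w * grid).sum()
exact_var = (w * (grid - exact_mean) ** 2).sum()
print(ep_mean, pooled.mean(), exact_mean)
```

Because the moments are Monte Carlo estimates, the sites keep fluctuating around a fixed point instead of converging exactly, which is the stochastic-update issue raised in the closing remarks.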
Bayesian Logistic Regression
• Simulated dataset.
• d = 20, # data items N = 1000.
• NUTS base sampler.
• # workers m = 4, 10, 50.
• # MCMC iterations T = 1000, 1000, 10000.
• # EP iterations k given as vertical lines.
(Figure: one panel per value of m; x-axis kTN/m ($\times 10^3$).)
Bayesian Logistic Regression
• MSE of the posterior mean, as a function of the total # iterations.
(Figure: SMS(s), SMS(a), SCOT, NEIS(p), NEIS(n), WANG; x-axis kTm ($\times 10^5$), log-scale y-axis.)
Bayesian Logistic Regression
• Approximate KL and MSE of predictive probabilities, as functions of the total # iterations.
(Figure: two panels, with legends SMS(s), SMS(a), SCOT, WANG and SMS(s), SMS(a), SCOT, NEIS(n), WANG; x-axis kTm ($\times 10^5$), log-scale y-axes.)
Bayesian Logistic Regression
• Approximate KL as a function of # nodes.
(Figure: SMS(s,s), SMS(s,e), SMS(a,s), SMS(a,e), SCOT, XING(p) at m = 8, 16, 32, 48, 64.)
Bayesian Logistic Regression
• Approximate KL, as a function of # iterations per node and # likelihood evaluations.
(Figure: two panels for SMS(s) and SMS(a) with m = 8, 16, 32, 48, 64; x-axes kT ($\times 10^4$) and kTN/m ($\times 10^8$), log-scale y-axes.)
Spike-and-Slab Sparse Regression
• Posterior mean coefficients.
(Figure: two panels; x-axis kTN/m ($\times 10^3$).)
Some Remarks
• Scalable distributed MCMC sampling.
• A bit of communication goes a long way.
• Issues with the stochasticity of moment estimates:
  – EP theory does not cover stochastic updates.
  – It is not clear which stochastic update is best to use.
  – Nor is it clear how to characterise convergence and the quality of the approximation.
• Matlab source: https://github.com/chokkyvista/smssample
Other Approaches to Scalable Bayes
• Median posterior [Minsker et al 2014]:
  – Embeds local posteriors into an RKHS and computes their geometric median.
  – Improves robustness to outliers in the data.
• Stochastic gradient MCMC approaches:
  – Reduce the cost of each MCMC step by using a data subset.
  – A distributed version has been developed.
  – [Welling & Teh 2011, Ahn et al 2012, 2014, Teh, Thiery & Vollmer (forthcoming), Bardenet et al 2014]
• Variational approaches:
  – Faster convergence, with possibly significant bias.
  – Recent works successfully extend these to large scale datasets using stochastic approximation techniques [Hoffman et al 2010, 2013, etc] and to flexible parameterized variational distributions [Mnih & Gregor 2014, Rezende et al 2014, Kingma & Welling 2014].
Bigger Picture
• The probabilistic modelling/Bayesian inference approach offers a principled and powerful data analysis framework.
• Standard methodologies do not extend easily to Big Data.
• It is important to develop generic methodologies that make these approaches applicable to Big Data.
• Bias/variance trade-offs are becoming more important.
• Low-bias exact methods do not scale as well to Big Data.
Thank you
• Thanks for funding: