A quick intro to Bayesian neural networks
Making neural networks shrug their shoulders
This post is an exploration of some recent work that I did with the Tensorflow Team at Google. Most recently, Google added Tensorflow Probability (TFP) to the Tensorflow ecosystem. The TFP Github has been updated with plenty of examples. However, I decided I’d take the time to explore one part that I’m particularly excited about: Bayesian Neural Networks.
Bayesian neural networks differ from regular neural networks in that each weight and bias is described by a probability distribution instead of a single point value. These distributions reflect uncertainty in the weights and biases, and can therefore be used to convey predictive uncertainty. Instead of fitting the weights directly with backpropagation, the parameters of these weight distributions are learned through variational inference.
In this post, I go over some of the conceptual requirements for Bayesian machine learning, outline just what Bayesian ML has that deterministic ML doesn’t, and show you how to build the “Hello World” of Bayesian networks: a Bayesian LeNet trained using the method described in Weight Uncertainty in Neural Networks.
Outline
- Part 1: The Basics
- Part 2: More Math for BNN training
- Part 3: Simple Regression Example
- Part 4: Bayesian LeNet5 in Tensorflow Probability
- Part 5: Giving our LeNet unfamiliar data
- Part 6: Design considerations
- Part 7: Closing Remarks
Part 1: The Basics
What deterministic NNs lack
Knowing only the input-output mapping a NN has learned is not enough when we need to gauge the uncertainty in its predictions.
Neural networks (NNs) have people pretty excited, and it’s easy to see why. The main selling point is that neural networks can act as universal function approximators, capable of fitting even the extremely complex input-output mappings involved in image and natural language processing.
That being said, one of the main limitations of deterministic NNs is that they are fundamentally a frequentist tool. This is part of the reason why, when there’s not much data to work with, deterministic NNs will often overfit. These overfit models can produce extrapolations that are obviously unfounded to a human observer. This is most glaring when we apply a deterministic neural network to data far outside the realm of what it was trained on.
In a regression task, this can take the form of trendlines that extend far beyond the range of the original training data, with no indication of which parts of the trendline are more reliable than others. In classification tasks, a network will often have to assign the data to one of its categories, even if the instance doesn’t belong to any of them. If you have an MNIST classifier and you feed in a letter from the alphabet, the classifier will have to choose the closest digit no matter how remote the similarity is. This overfitting, and the need to stick to the model even on uncertain data, is part of why adversarial attacks work so well.
Ideally, we would want some kind of predictive uncertainty from our model, something akin to a confidence interval around its predictions, but that’s deceptively tricky to get out of a regular deterministic NN.
Probabilistic Models and Incomplete Solutions
Even if you modify the softmax layer of your neural network, you still won’t get a truly reliable confidence estimate. The reason comes down to the kind of problem NNs try to solve.
Let’s imagine our neural network as a probabilistic model $p(y \mid x, w)$: given an input $x$ and parameters $w$, it assigns a probability to each possible output $y$. The form of that distribution depends on the task (a short TFP sketch follows the table below):
| NN Type | Variable Type | Distribution Type |
| --- | --- | --- |
| Regressor | Continuous | Gaussian |
| Classifier | Categorical | Categorical |
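To make the table concrete, here is a minimal sketch of those two likelihood choices using TFP distributions. The tensors standing in for network outputs and observed targets are placeholders invented for this illustration:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Placeholder network outputs, just to make the sketch runnable.
mean_output = tf.constant([0.3, -1.2])            # regressor: predicted means
logits_output = tf.constant([[2.0, 0.1, -1.0]])   # classifier: predicted logits

# Regressor: a Gaussian likelihood over a continuous target.
regression_likelihood = tfd.Normal(loc=mean_output, scale=1.0)

# Classifier: a categorical likelihood over discrete classes.
classification_likelihood = tfd.Categorical(logits=logits_output)

# The negative log likelihood of observed targets is the usual training loss.
nll_regression = -tf.reduce_sum(regression_likelihood.log_prob(tf.constant([0.5, -1.0])))
nll_classification = -tf.reduce_sum(classification_likelihood.log_prob(tf.constant([0])))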
With a training dataset $\mathcal{D} = \{(x^{(i)}, y^{(i)})\}$, we can calculate the likelihood $p(\mathcal{D} \mid w)$ as a function of the data and the parameters $w$. Maximizing this likelihood function gives the maximum likelihood estimate (MLE) of $w$. In other words, we’re maximizing the probability of the observed data given the network parameters $w$.
Note: the usual optimization objective during training is the negative log likelihood. For a categorical distribution this is the cross-entropy error function; for a Gaussian distribution it is proportional to the sum-of-squares error function.
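To spell out the Gaussian case with a fixed noise standard deviation $\sigma$: writing $\hat{y}^{(i)}(w)$ for the network output on input $x^{(i)}$, the negative log likelihood is, up to an additive constant, a scaled sum-of-squares error:

$$-\log p(\mathcal{D} \mid w) = -\sum_{i} \log \mathcal{N}\left(y^{(i)} \mid \hat{y}^{(i)}(w), \sigma^2\right) = \frac{1}{2\sigma^2} \sum_{i} \left(y^{(i)} - \hat{y}^{(i)}(w)\right)^2 + \text{const.}$$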
For large numbers of parameters, gradient descent via backpropagation is our algorithm of choice for MLE optimization. Since it only tries to maximize the probability of the observed data, the consequence can be overfitting and failure to generalize. This is the probabilistic equivalent of “If all you have is a hammer, everything looks like a nail”. A useful fix is, instead of calculating the MLE, to calculate the maximum a posteriori (MAP) point estimate. MAP maximizes the posterior $p(w \mid \mathcal{D}) \propto p(\mathcal{D} \mid w)\, p(w)$, so a prior over the parameters enters the objective and acts as a regularizer.
Initially, this doesn’t seem like too big a leap. For example, L2 regularization in the MLE setting corresponds to a Gaussian prior in the MAP setting, and L1 regularization corresponds to a Laplace prior.
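As a quick check of that correspondence, take a zero-mean Gaussian prior $p(w) = \mathcal{N}(w \mid 0, \lambda^{-1} I)$. Since $\log p(w) = -\frac{\lambda}{2} \lVert w \rVert_2^2 + \text{const.}$, the MAP objective picks up exactly an L2 weight penalty:

$$\arg\max_{w}\, \log p(w \mid \mathcal{D}) = \arg\max_{w}\, \left[\log p(\mathcal{D} \mid w) + \log p(w)\right] = \arg\max_{w}\, \left[\log p(\mathcal{D} \mid w) - \frac{\lambda}{2} \lVert w \rVert_2^2\right]$$

A Laplace prior yields an L1 penalty in the same way.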
The issue is that this change in the optimization problem doesn’t completely fix the problem of unwanted extrapolation. As mentioned before, we want not just predictions but estimates of the confidence in or uncertainty about the predictions. Uncertainty should be highest away from the data and lower within the range of the training data (and vice versa for confidence).

A complete solution
Both MLE and MAP give point estimates of the parameters. If we instead had a full posterior distribution over parameters, we could make predictions that take weight uncertainty into account. This is captured by the posterior predictive distribution $p(y \mid x, \mathcal{D}) = \int p(y \mid x, w)\, p(w \mid \mathcal{D})\, dw$, in which the parameters have been marginalized out. It is equivalent to averaging the predictions of an ensemble of neural networks, weighted by the posterior probabilities of their parameters $w$.
Doing full Bayesian inference, estimating the entire posterior distribution, would allow us to do exactly this. Bayesian inference adjusts our beliefs about the parameters in light of data, or evidence.

Full Bayesian inference applies Bayes’ rule to the observed data to estimate a full posterior distribution over the parameters:

$$p(w \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid w)\, p(w)}{p(\mathcal{D})}$$

where $p(w \mid \mathcal{D})$ is our posterior parameter distribution, $p(w)$ is our prior parameter distribution, $p(\mathcal{D})$ is our evidence, and $p(\mathcal{D} \mid w)$ is our data likelihood. Since the evidence $p(\mathcal{D}) = \int p(\mathcal{D} \mid w)\, p(w)\, dw$ is an intractable integral over all possible parameter settings, we cannot compute the posterior exactly and instead have to approximate it with a simpler distribution $q(w \mid \theta)$:

$$p(w \mid \mathcal{D}) \approx q(w \mid \theta)$$
The prediction for a new input $x^*$ is made by taking an expectation of the output over the posterior parameter distribution. If our posterior parameter distribution is $p(w \mid \mathcal{D})$, then our prediction function is

$$p(y^* \mid x^*, \mathcal{D}) = \int p(y^* \mid x^*, w)\, p(w \mid \mathcal{D})\, dw = \mathbb{E}_{p(w \mid \mathcal{D})}\left[p(y^* \mid x^*, w)\right]$$
Functionally, this is roughly equivalent to averaging the predictions of an infinite ensemble of NNs, weighting each network by the posterior probability of its parameters. The advantage is a built-in model-averaging component, which makes the model more resistant to noise. The obvious problem is that both the exact computation of the posterior and the prediction step shown above are computationally intractable: calculating them exactly would require more computational power than anyone has at their disposal (and if you’re reading this, I’m guessing you don’t have that kind of compute lying around). On top of that, there is no closed form we can differentiate with respect to the weight distributions, which backpropagation needs.

Our solution is to approximate the posterior through a process called variational inference. The next section goes into the mathematical details of how we do this.
Part 2: More Math for BNN training
Variational inference
As mentioned before, we need to approximate the true posterior $p(w \mid \mathcal{D})$ with a variational distribution $q(w \mid \theta)$ of known functional form whose parameters $\theta$ we want to estimate. This can be done by minimizing the Kullback-Leibler divergence between $q(w \mid \theta)$ and the true posterior w.r.t. $\theta$. It can be shown that the corresponding optimization objective or cost function can be written as

$$\mathcal{F}(\mathcal{D}, \theta) = \mathrm{KL}\left[q(w \mid \theta) \,\|\, p(w)\right] - \mathbb{E}_{q(w \mid \theta)}\left[\log p(\mathcal{D} \mid w)\right]$$

This is known as the variational free energy. The first term is the Kullback-Leibler divergence between the variational distribution $q(w \mid \theta)$ and the prior $p(w)$ and is called the complexity cost. The second term is the expected value of the log likelihood w.r.t. the variational distribution and is called the likelihood cost. By re-arranging the KL term, the cost function can also be written as

$$\mathcal{F}(\mathcal{D}, \theta) = \mathbb{E}_{q(w \mid \theta)}\left[\log q(w \mid \theta)\right] - \mathbb{E}_{q(w \mid \theta)}\left[\log p(w)\right] - \mathbb{E}_{q(w \mid \theta)}\left[\log p(\mathcal{D} \mid w)\right]$$

We see that all three terms in this equation are expectations w.r.t. the variational distribution $q(w \mid \theta)$. The cost function can therefore be approximated by drawing $N$ Monte Carlo samples $w^{(i)}$ from $q(w \mid \theta)$:

$$\mathcal{F}(\mathcal{D}, \theta) \approx \frac{1}{N} \sum_{i=1}^{N} \left[\log q(w^{(i)} \mid \theta) - \log p(w^{(i)}) - \log p(\mathcal{D} \mid w^{(i)})\right]$$

In the following example, we’ll use a Gaussian distribution for the variational posterior, parameterized by $\theta = (\mu, \sigma)$, where $\mu$ is the mean vector of the distribution and $\sigma$ the standard deviation vector. The elements of $\sigma^2$ are the entries of a diagonal covariance matrix, which means that the weights are assumed to be uncorrelated. Instead of parameterizing the neural network with the weights $w$ directly, we parameterize it with $\mu$ and $\sigma$, and therefore double the number of parameters compared to a plain neural network.
Network training
A training iteration consists of a forward pass and a backward pass. During the forward pass, a single sample $w$ is drawn from the variational posterior distribution and used to evaluate the approximate cost function given by the Monte Carlo estimate above. The first two terms of the cost function are data-independent and can be evaluated layer-wise; the last term is data-dependent and is evaluated at the end of the forward pass. During the backward pass, gradients of $\mu$ and $\sigma$ are calculated via backpropagation so that their values can be updated by the optimizer.
Since a forward pass involves a stochastic sampling step, we have to apply the so-called re-parameterization trick for backpropagation to work. The trick is to sample $\epsilon$ from a parameter-free distribution and then transform it with a deterministic function $t(\mu, \sigma, \epsilon) = \mu + \sigma \odot \epsilon$ for which a gradient can be defined. Here, $\epsilon$ is drawn from a standard normal distribution, i.e. $\epsilon \sim \mathcal{N}(0, I)$, and $t$ shifts the sample by the mean $\mu$ and scales it with $\sigma$, where $\odot$ denotes element-wise multiplication.

For numeric stability we parameterize the network with $\rho$ instead of $\sigma$ directly and transform $\rho$ with the softplus function to obtain $\sigma = \log(1 + e^{\rho})$. This ensures that $\sigma$ is always positive. As prior, a scale mixture of two Gaussians is used, $p(w) = \pi\, \mathcal{N}(w \mid 0, \sigma_1^2) + (1 - \pi)\, \mathcal{N}(w \mid 0, \sigma_2^2)$, where $\sigma_1$, $\sigma_2$ and $\pi$ are shared parameters. Their values are learned during training (in contrast to the paper, where a fixed prior is used).
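To make the sampling step concrete, here is a minimal TF2 sketch of the re-parameterization trick for a single weight matrix. The variational parameters mu and rho (and their shapes) are hypothetical; the TFP layers used later in this post handle this step internally:

import tensorflow as tf

mu = tf.Variable(tf.zeros([5, 3]))          # variational means
rho = tf.Variable(tf.fill([5, 3], -3.0))    # unconstrained stddev parameters

sigma = tf.math.softplus(rho)               # sigma = log(1 + exp(rho)) > 0
epsilon = tf.random.normal(mu.shape)        # parameter-free noise, eps ~ N(0, I)
w = mu + sigma * epsilon                    # differentiable w.r.t. mu and rho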
Uncertainty characterization
Uncertainty in predictions that arises from uncertainty in the weights is called epistemic uncertainty. This kind of uncertainty can be reduced if we get more data; consequently, epistemic uncertainty is higher in regions with little or no training data and lower in regions with more training data. Epistemic uncertainty is covered by the variational posterior distribution. Uncertainty coming from the inherent noise in the training data is an example of aleatoric uncertainty; it cannot be reduced by collecting more data. Aleatoric uncertainty is covered by the probability distribution used to define the likelihood function.
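For a regression model with a Gaussian likelihood (as in the example below), one standard way to make this split concrete is the law of total variance applied to the predictive distribution:

$$\operatorname{Var}\left[y^* \mid x^*\right] = \underbrace{\mathbb{E}_{q(w \mid \theta)}\left[\operatorname{Var}\left[y^* \mid x^*, w\right]\right]}_{\text{aleatoric}} + \underbrace{\operatorname{Var}_{q(w \mid \theta)}\left[\mathbb{E}\left[y^* \mid x^*, w\right]\right]}_{\text{epistemic}}$$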
Part 3: Simple Regression Example
Variational inference of neural network parameters is now demonstrated on a simple regression problem, so we use a Gaussian distribution for the likelihood $p(y \mid x, w)$. The training dataset consists of 32 noisy samples `X`, `y` drawn from a sinusoidal function.
import numpy as np
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def f(x, sigma):
epsilon = np.random.randn(*x.shape) * sigma
return 10 * np.sin(2 * np.pi * (x)) + epsilon
train_size = 32
noise = 1.0
X = np.linspace(-0.5, 0.5, train_size).reshape(-1, 1)
y = f(X, sigma=noise)
y_true = f(X, sigma=0.0)
plt.scatter(X, y, marker='+', label='Training data')
plt.plot(X, y_true, label='Truth')
plt.title('Noisy training data and ground truth')
plt.legend();
The noise in the training data gives rise to aleatoric uncertainty. To cover epistemic uncertainty we build the model from TFP’s `DenseVariational` Keras layers. The scale-mixture prior parameters $\sigma_1$, $\sigma_2$, and $\pi$ are shared across layers, and the complexity cost (the KL term) is computed layer-wise and folded into the total loss via each layer’s `kl_weight` argument. The `posterior_mean_field` and `prior_trainable` functions defined below specify the variational posterior and the prior, and directly follow the equations above.
from tensorflow.keras import backend as K

def mixture_prior_params(sigma_1, sigma_2, pi):
    # Trainable parameters of the scale mixture prior (sigma_1, sigma_2, pi).
    params = K.variable([sigma_1, sigma_2, pi], name='mixture_prior_params')
    # Effective standard deviation of the mixture, useful for initialization.
    sigma = np.sqrt(pi * sigma_1 ** 2 + (1 - pi) * sigma_2 ** 2)
    return params, sigma

def log_mixture_prior_prob(w):
    # Log density of the scale mixture of two zero-mean Gaussians, log p(w).
    comp_1_dist = tfd.Normal(0.0, prior_params[0])
    comp_2_dist = tfd.Normal(0.0, prior_params[1])
    comp_1_weight = prior_params[2]
    return K.log(comp_1_weight * comp_1_dist.prob(w) +
                 (1 - comp_1_weight) * comp_2_dist.prob(w))

# Mixture prior parameters shared across DenseVariational layer instances
prior_params, prior_sigma = mixture_prior_params(sigma_1=1.0, sigma_2=0.1, pi=0.2)
Our model is a neural network with two `DenseVariational` hidden layers, each having 20 units, and one `DenseVariational` output layer with one unit. Instead of modeling a full probability distribution as output, the network simply outputs the mean of the corresponding Gaussian distribution. In other words, we do not model aleatoric uncertainty here and assume it is known. We only model epistemic uncertainty via the `DenseVariational` layers.
Since the training dataset has only 32 examples we train the network with all 32 examples per epoch, so that the number of batches per epoch is 1. For other configurations, the complexity cost (`kl_loss`) must be weighted by $\frac{1}{M}$ as described in section 3.4 of the paper, where $M$ is the number of mini-batches per epoch.
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
batch_size = train_size
num_batches = train_size / batch_size
kl_loss_weight = 1.0 / num_batches
# Specify the surrogate posterior over `keras.layers.Dense` `kernel` and `bias`.
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
n = kernel_size + bias_size
c = np.log(np.expm1(1.))
return tf.keras.Sequential([
tfp.layers.VariableLayer(2 * n, dtype=dtype),
tfp.layers.DistributionLambda(lambda t: tfd.Independent(
tfd.Normal(loc=t[..., :n],
scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
reinterpreted_batch_ndims=1)),
])
# Specify the prior over `keras.layers.Dense` `kernel` and `bias`.
def prior_trainable(kernel_size, bias_size=0, dtype=None):
n = kernel_size + bias_size
return tf.keras.Sequential([
tfp.layers.VariableLayer(n, dtype=dtype),
tfp.layers.DistributionLambda(lambda t: tfd.Independent(
tfd.Normal(loc=t, scale=1),
reinterpreted_batch_ndims=1)),
])
# Build model.
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(1,)),
tfp.layers.DenseVariational(units=20,
make_posterior_fn=posterior_mean_field,
make_prior_fn=prior_trainable,
kl_weight=kl_loss_weight,
activation='relu'),
tfp.layers.DenseVariational(units=20,
make_posterior_fn=posterior_mean_field,
make_prior_fn=prior_trainable,
kl_weight=kl_loss_weight,
activation='relu'),
tfp.layers.DenseVariational(units=1,
make_posterior_fn=posterior_mean_field,
make_prior_fn=prior_trainable,
kl_weight=kl_loss_weight)
])
The network can now be trained with a Gaussian negative log likelihood function (`neg_log_likelihood`) as the loss function, assuming a fixed standard deviation (`noise`). This corresponds to the likelihood cost, the last term in the approximate cost function above.
def neg_log_likelihood(y_true, y_pred, sigma=noise):
    # Gaussian negative log likelihood with a fixed, known noise standard deviation.
    dist = tfd.Normal(loc=y_pred, scale=sigma)
    return K.sum(-dist.log_prob(y_true))
model.compile(loss=neg_log_likelihood, optimizer=Adam(lr=0.03), metrics=['mse'])
model.fit(X, y, batch_size=batch_size, epochs=1500, verbose=0);
When calling `model.predict` we draw a random sample from the variational posterior distribution and use it to compute the output value of the network. This is equivalent to obtaining the output from a single member of a hypothetical ensemble of neural networks. Drawing 500 samples means that we get predictions from 500 ensemble members. From these predictions we can compute statistics such as the mean and standard deviation. In our example, the standard deviation is a measure of epistemic uncertainty.
import tqdm
X_test = np.linspace(-1.5, 1.5, 1000).reshape(-1, 1)
y_pred_list = []
for i in tqdm.tqdm(range(500)):
y_pred = model.predict(X_test)
y_pred_list.append(y_pred)
y_preds = np.concatenate(y_pred_list, axis=1)
y_mean = np.mean(y_preds, axis=1)
y_sigma = np.std(y_preds, axis=1)
plt.plot(X_test, y_mean, 'r-', label='Predictive mean');
plt.scatter(X, y, marker='+', label='Training data')
plt.fill_between(X_test.ravel(),
y_mean + 2 * y_sigma,
y_mean - 2 * y_sigma,
alpha=0.5, label='Epistemic uncertainty')
plt.title('Prediction')
plt.legend();
We can clearly see that epistemic uncertainty is much higher in regions of no training data than it is in regions of existing training data. The predictive mean could also have been obtained with a single forward pass, i.e. a single `model.predict` call, by using only the mean of the variational posterior distribution, which is equivalent to sampling from the variational posterior with $\sigma$ set to $0$. The corresponding implementation is omitted here but is trivial to add.
For an example of how to model both epistemic and aleatoric uncertainty, I recommend reading Regression with Probabilistic Layers in TensorFlow Probability, which uses probabilistic Keras layers from the upcoming Tensorflow Probability 0.7.0 release. Their approach to variational inference is similar to the one described here but differs in some details: for example, they compute the complexity cost analytically instead of estimating it from Monte Carlo samples.
Part 4: Bayesian LeNet5 in Tensorflow Probability
Thanks to Tensorflow Probability, we can extend our Bayesian example to an image classification task with relative ease. Much of the model-building process is similar. We import the usual dependencies (along with TFP), load MNIST, and set the image dimensions to the usual 28×28 pixels. We set our learning rate to 0.001, our maximum number of steps to 6000, and our batch size to 128. We also set aside a directory for saving visualizations at select steps (in our case, every 400 steps).

There are some parameters to set that are uncommon in typical deep learning. Our Bayesian network’s predictive probabilities are computed from repeated draws of the network’s weights, so we need to choose the number of Monte Carlo draws (for demo purposes, we’ll stick with 50 for now).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
# Dependency imports
import matplotlib
matplotlib.use("Agg")
from matplotlib import figure # pylint: disable=g-import-not-at-top
from matplotlib.backends import backend_agg
import numpy as np
import seaborn as sns  # used by the plotting helpers below
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.learn.python.learn.datasets import mnist
warnings.simplefilter(action="ignore")
tfd = tfp.distributions
IMAGE_SHAPE = [28, 28, 1]
learning_rate = 0.001 # Initial learning rate.
max_steps = 6000 #Number of training steps to run.
batch_size = 128 #Batch size.
# Directory where the MNIST data is stored (if using real data).
data_dir = os.path.join(os.getenv("TEST_TMPDIR", "/tmp"),
"bayesian_neural_network/data")
# Directory to put the model's fit.
model_dir = os.path.join(os.getenv("TEST_TMPDIR", "/tmp"),
"bayesian_neural_network/")
viz_steps = 400 #Frequency (in steps) at which to save visualizations.
num_monte_carlo = 50 #Network draws to compute predictive probabilities.
fake_data = None #If true, uses fake data. Defaults to real data.
As our model trains we’ll want to visualize how the weight posteriors change. Here is our function for visualizing the distributions of posteriors in layers at various depths throughout the network.
def plot_weight_posteriors(names, qm_vals, qs_vals, fname):
fig = figure.Figure(figsize=(6, 3))
canvas = backend_agg.FigureCanvasAgg(fig)
ax = fig.add_subplot(1, 2, 1)
for n, qm in zip(names, qm_vals):
sns.distplot(qm.flatten(), ax=ax, label=n)
ax.set_title("weight means")
ax.set_xlim([-1.5, 1.5])
ax.legend()
ax = fig.add_subplot(1, 2, 2)
for n, qs in zip(names, qs_vals):
sns.distplot(qs.flatten(), ax=ax)
ax.set_title("weight stddevs")
ax.set_xlim([0, 1.])
fig.tight_layout()
canvas.print_figure(fname, format="png")
print("saved {}".format(fname))
We also want to be able to plot the posterior uncertainty on data that’s been omitted from the training process. The function below does exactly that, and neatly saves the resulting plots to PNG files.
def plot_heldout_prediction(input_vals, probs, fname, n=10, title=""):
fig = figure.Figure(figsize=(9, 3*n))
canvas = backend_agg.FigureCanvasAgg(fig)
for i in range(n):
ax = fig.add_subplot(n, 3, 3*i + 1)
ax.imshow(input_vals[i, :].reshape(IMAGE_SHAPE[:-1]), interpolation="None")
ax = fig.add_subplot(n, 3, 3*i + 2)
for prob_sample in probs:
sns.barplot(np.arange(10), prob_sample[i, :], alpha=0.1, ax=ax)
ax.set_ylim([0, 1])
ax.set_title("posterior samples")
ax = fig.add_subplot(n, 3, 3*i + 3)
sns.barplot(np.arange(10), np.mean(probs[:, i, :], axis=0), ax=ax)
ax.set_ylim([0, 1])
ax.set_title("predictive probs")
fig.suptitle(title)
fig.tight_layout()
canvas.print_figure(fname, format="png")
print("saved {}".format(fname))
One thing that remains the same between Bayesian networks and deterministic networks is the need to maintain separation between training and test data. The function below constructs an iterator that can alternate between training and validation/heldout data where appropriate.
def build_input_pipeline(mnist_data, batch_size, heldout_size):
training_dataset = tf.data.Dataset.from_tensor_slices(
(mnist_data.train.images, np.int32(mnist_data.train.labels)))
training_batches = training_dataset.shuffle(
50000, reshuffle_each_iteration=True).repeat().batch(batch_size)
training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)
heldout_dataset = tf.data.Dataset.from_tensor_slices(
(mnist_data.validation.images,
np.int32(mnist_data.validation.labels)))
heldout_frozen = (heldout_dataset.take(heldout_size).
repeat().batch(heldout_size))
heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_frozen)
handle = tf.compat.v1.placeholder(tf.string, shape=[])
feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
handle, training_batches.output_types, training_batches.output_shapes)
images, labels = feedable_iterator.get_next()
return images, labels, handle, training_iterator, heldout_iterator
And of course, we also have our random noise image generating function for testing our network on non-standard inputs (with a bonus of being useful for unit testing).
def build_fake_data(num_examples=10):
  # Build a dummy MNIST-style dataset filled with random-noise images.
  class Dummy(object):
    pass
  mnist_data = Dummy()
mnist_data.train = Dummy()
mnist_data.train.images = np.float32(np.random.randn(
num_examples, *IMAGE_SHAPE))
mnist_data.train.labels = np.int32(np.random.permutation(
np.arange(num_examples)))
mnist_data.train.num_examples = num_examples
mnist_data.validation = Dummy()
mnist_data.validation.images = np.float32(np.random.randn(
num_examples, *IMAGE_SHAPE))
mnist_data.validation.labels = np.int32(np.random.permutation(
np.arange(num_examples)))
mnist_data.validation.num_examples = num_examples
return mnist_data
if tf.io.gfile.exists(model_dir):
tf.io.gfile.rmtree(model_dir)
tf.io.gfile.makedirs(model_dir)
if fake_data:
mnist_data = build_fake_data()
else:
mnist_data = mnist.read_data_sets(data_dir, reshape=False)
(images, labels, handle,
training_iterator, heldout_iterator) = build_input_pipeline(
mnist_data, batch_size, mnist_data.validation.num_examples)
Now for the actual network architecture you’ve been waiting for. We’ll construct a Bayesian LeNet5 network. We use the Flipout Monte Carlo estimator for the convolution and fully-connected layers: this enables lower variance stochastic gradients than naive reparameterization.
And yes, like in other tensorflow graph implementations, we can set a name scope so we can navigate our model in Tensorboard.
with tf.compat.v1.name_scope("bayesian_neural_net", values=[images]):
neural_net = tf.keras.Sequential([
tfp.layers.Convolution2DFlipout(6,
kernel_size=5,
padding="SAME",
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
strides=[2, 2],
padding="SAME"),
tfp.layers.Convolution2DFlipout(16,
kernel_size=5,
padding="SAME",
activation=tf.nn.relu),
tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
strides=[2, 2],
padding="SAME"),
tfp.layers.Convolution2DFlipout(120,
kernel_size=5,
padding="SAME",
activation=tf.nn.relu),
tf.keras.layers.Flatten(),
tfp.layers.DenseFlipout(84, activation=tf.nn.relu),
tfp.layers.DenseFlipout(10)
])
logits = neural_net(images)
labels_distribution = tfd.Categorical(logits=logits)
We compute the negative ELBO as the loss: the negative log likelihood of the labels averaged over the batch, plus the KL divergence over the network’s weights scaled by the number of training examples.
neg_log_likelihood = -tf.reduce_mean(
input_tensor=labels_distribution.log_prob(labels))
kl = sum(neural_net.losses) / mnist_data.train.num_examples
elbo_loss = neg_log_likelihood + kl
Next we build metrics for evaluation. Predictions are formed from a single forward pass of the probabilistic layers. As you can imagine, these are noisy predictions. Their main redeeming quality is that they’re computationally cheap enough to allow us to do thousands of training steps.
predictions = tf.argmax(input=logits, axis=1)
accuracy, accuracy_update_op = tf.compat.v1.metrics.accuracy(
labels=labels, predictions=predictions)
We want to be sure to extract weight posterior statistics from the layers that have weight distributions, for later visualization.
names = []
qmeans = []
qstds = []
for i, layer in enumerate(neural_net.layers):
try:
q = layer.kernel_posterior
except AttributeError:
continue
names.append("Layer {}".format(i))
qmeans.append(q.mean())
qstds.append(q.stddev())
And finally, we can run the training loop.
with tf.compat.v1.name_scope("train"):
optimizer = tf.compat.v1.train.AdamOptimizer(
learning_rate=learning_rate)
train_op = optimizer.minimize(elbo_loss)
init_op = tf.group(tf.compat.v1.global_variables_initializer(),
tf.compat.v1.local_variables_initializer())
with tf.compat.v1.Session() as sess:
sess.run(init_op)
# Run the training loop.
train_handle = sess.run(training_iterator.string_handle())
heldout_handle = sess.run(heldout_iterator.string_handle())
for step in range(max_steps):
_ = sess.run([train_op, accuracy_update_op],
feed_dict={handle: train_handle})
if step % 100 == 0:
loss_value, accuracy_value = sess.run(
[elbo_loss, accuracy], feed_dict={handle: train_handle})
print("Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f}".format(step, loss_value, accuracy_value))
if (step+1) % viz_steps == 0:
# Compute log prob of heldout set by averaging draws from the model:
# p(heldout | train) = int_model p(heldout|model) p(model|train)
# ~= 1/n * sum_{i=1}^n p(heldout | model_i)
# where model_i is a draw from the posterior p(model|train).
probs = np.asarray([sess.run((labels_distribution.probs),
feed_dict={handle: heldout_handle}) for _ in range(num_monte_carlo)])
mean_probs = np.mean(probs, axis=0)
image_vals, label_vals = sess.run((images, labels),
feed_dict={handle: heldout_handle})
heldout_lp = np.mean(np.log(mean_probs[np.arange(mean_probs.shape[0]),
label_vals.flatten()]))
print(" ... Held-out nats: {:.3f}".format(heldout_lp))
qm_vals, qs_vals = sess.run((qmeans, qstds))
plot_weight_posteriors(names, qm_vals, qs_vals,
fname=os.path.join(
model_dir,
"step{:05d}_weights.png"
.format(step)))
plot_heldout_prediction(image_vals, probs,
fname=os.path.join(
model_dir,
"step{:05d}_pred.png".format(step)),
title="mean heldout logprob {:.2f}"
.format(heldout_lp))
Once training is done, we should have a folder full of visualization images saved at regular intervals throughout training.
For example, see how our weight distributions change between Training Step 400 and Training Step 6000:


Likewise, we can see the class probabilities our network assigns to various held-out digits:

It does pretty well, but let’s see how this looks at Step 6000:

Much better. That particular digit still causes some uncertainty (you can see how it could be confused with a similar-looking digit), but for all the others the predictive probabilities are much more concentrated.
Part 5: Giving our LeNet unfamiliar data
So we saw how our Bayesian LeNet makes its decisions. It takes samples from the distributions that make up the weights, and then uses them to construct the class probabilities. We also saw that these predictive probabilities come from averaging over many sampled networks rather than from a single softmax pass; when the sampled networks disagree, no single class ends up with a high averaged probability.

This makes a pretty big impact on how our neural network makes its decisions. We can demonstrate this by straying outside the MNIST dataset.
We can take data from our fake-MNIST-generating function, which just makes images out of random noise. After 6000 training steps, this is how our model treats the noisy images.

Pretty impressive! This illustrates one of the other add-ons we can easily make for bayesian neural networks: a probability cutoff. In this case, if none of our probabilities exceed 0.2, we can get our network to refuse to classify the images.
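As a rough illustration, here is a minimal sketch of such a cutoff, reusing the `probs` array of sampled class probabilities computed in the training loop above. The threshold value and the `-1` “refuse” marker are arbitrary choices for this example:

import numpy as np

# probs has shape (num_monte_carlo, num_images, 10): sampled class probabilities.
mean_probs = probs.mean(axis=0)          # average over posterior draws
top_probs = mean_probs.max(axis=1)       # confidence in the most likely class
predictions = mean_probs.argmax(axis=1)  # most likely class per image

CUTOFF = 0.2
# Mark low-confidence images with -1, i.e. refuse to classify them.
predictions = np.where(top_probs >= CUTOFF, predictions, -1)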
Of course, it wasn’t always like this. It took our network a while before it could correctly refuse to classify the noise. Take the same images at step 400 of the 6000 training steps:

This should serve as a reminder that no new machine learning framework is magic. There are best practices and best use-cases for them. In our case, one of the best practices is training the network for an adequate amount of time.
Part 6: Design considerations
How do we choose the priors?
This is one of the more controversial questions regarding bayesian models, and even probabilistic programming in general. There isn’t universal agreement on this, but there are a few helpful ideas to take into account.
The central limit theorem states that the sum (or average) of many independent random variables tends toward a normal distribution, regardless of the distribution of the individual variables.
It may not be perfect, but choosing a normal distribution as a prior is one of the more defensible defaults for initializing a Bayesian neural network (though I will admit that’s a fairly low bar, and finding better prior-selection strategies is still an open problem).
What happens if we make a network part Bayesian and part deterministic?
The obvious result is that you will still be able to draw samples, but the model will be worse at conveying uncertainty than a fully Bayesian network. And as we saw earlier, even fully Bayesian networks have their failure cases.

It should also be stressed that full BNNs, which place distributions over their parameters, belong in a different category from NNs that place distributions over their hidden units. The full BNN is designed to express uncertainty about the observations in question; a regular NN with stochastic hidden units is, at best, a useful tool for regularization and model averaging.
What kinds of hardware optimization options are available?
This is one area where GPUs and TPUs aren’t quite as optimized as they could be: most deep learning hardware is built with deterministic models in mind, not sampling-based ones.
How do we convert Deterministic Layers to Probabilistic Layers in Tensorflow?
A few of the layers for Bayesian networks have very close analogs to deterministic layers. For example, `DenseFlipout` corresponds to `Dense`, `Convolution2DFlipout` corresponds to `Conv2D`, and so on, as in the sketch below.
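For illustration, here is a minimal sketch of the same tiny classifier written with deterministic Keras layers and with their Flipout counterparts. The layer sizes are arbitrary choices for this example:

import tensorflow as tf
import tensorflow_probability as tfp

# Deterministic version: point-estimate weights.
deterministic_net = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, kernel_size=5, activation=tf.nn.relu),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Bayesian version: the same architecture with Flipout weight distributions.
bayesian_net = tf.keras.Sequential([
    tfp.layers.Convolution2DFlipout(6, kernel_size=5, activation=tf.nn.relu),
    tf.keras.layers.Flatten(),
    tfp.layers.DenseFlipout(10),
])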
That being said, there are still plenty of deterministic layers that don’t have close probabilistic analogs. At its core, a Bayesian neural network that relies on Monte Carlo sampling is updated in a fundamentally different way from a deterministic one. Many of the layers offered by packages like TFP go beyond what was previously easy to implement in neural networks.
Are there other ways of updating bayesian models?
Yes. Exact Bayesian inference is intractable for neural networks, so various approximation schemes have been developed, which gives us the wide variety of BNNs we have today ([4], [5], [6], [7], [8], [9], [10], [11]). Given how young this part of machine learning is, it would probably be helpful to set up alerts on Google Scholar for Bayesian machine learning.
Part 7: Closing Remarks
In conclusion, BNNs are useful for integrating and modeling uncertainties. They have also been shown to improve predictive performance ([4], [14]) and to enable systematic exploration ([13]). Recent advances in deep learning and hardware allow us to approximate the relevant quantities scalably using off-the-shelf optimizers. The fundamental difficulty in developing BNNs, or any probabilistic model, is the intractable computation of the posterior distribution and its expectations, so we have to resort to approximations. There are broadly two categories of approximation methods: stochastic (e.g. Markov chain Monte Carlo) and deterministic (e.g. variational inference). For readers interested in knowing more about them, I would point to two resources.
Further Reading
- Chapters 10 and 11 of the book Pattern Recognition and Machine Learning by Christopher Bishop,
- Talk on Scalable Bayesian Inference by David Dunson during NeurIPS 2018, Montreal.
References
- Hornik, Kurt, Maxwell Stinchcombe, and Halbert White. “Multilayer feedforward networks are universal approximators.” Neural networks 2, no. 5 (1989): 359-366.
- Cybenko, George. “Approximations by superpositions of a sigmoidal function.” Mathematics of Control, Signals and Systems 2 (1989): 183-192.
- Goodfellow, Ian, Yoshua Bengio, Aaron Courville, and Yoshua Bengio. Deep learning. Vol. 1. Cambridge: MIT press, 2016.
- Blundell, Charles, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. “Weight uncertainty in neural networks.” arXiv preprint arXiv:1505.05424 (2015).
- Gal, Yarin, and Zoubin Ghahramani. “Dropout as a Bayesian approximation: Representing model uncertainty in deep learning.” In international conference on machine learning, pp. 1050-1059. 2016.
- Bui, Thang D., José Miguel Hernández-Lobato, Yingzhen Li, Daniel Hernández-Lobato, and Richard E. Turner. “Training deep Gaussian processes using stochastic expectation propagation and probabilistic backpropagation.” arXiv preprint arXiv:1511.03405 (2015).
- Minka, Thomas P. “Expectation propagation for approximate Bayesian inference.” In Proceedings of the Seventeenth conference on Uncertainty in artificial intelligence, pp. 362-369. Morgan Kaufmann Publishers Inc., 2001.
- Hernández-Lobato, José Miguel, and Ryan Adams. “Probabilistic backpropagation for scalable learning of bayesian neural networks.” In International Conference on Machine Learning, pp. 1861-1869. 2015.
- Neal, Radford M. Bayesian learning for neural networks. Vol. 118. Springer Science & Business Media, 2012.
- MacKay, David JC. “A practical Bayesian framework for backpropagation networks.” Neural computation 4, no. 3 (1992): 448-472.
- Jylänki, Pasi, Aapo Nummenmaa, and Aki Vehtari. “Expectation propagation for neural networks with sparsity-promoting priors.” The Journal of Machine Learning Research 15, no. 1 (2014): 1849-1901.
- Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).
- Houthooft, Rein, Xi Chen, Yan Duan, John Schulman, Filip De Turck, and Pieter Abbeel. “Curiosity-driven exploration in deep reinforcement learning via bayesian neural networks.” arXiv preprint arXiv:1605.09674 (2016).
- Yoon, Jaesik, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn. “Bayesian Model-Agnostic Meta-Learning.” In Advances in Neural Information Processing Systems, pp. 7342-7352. 2018.
Cited as:
@article{mcateer2019bayesnn,
title = "A quick intro to Bayesian neural networks",
author = "McAteer, Matthew",
journal = "matthewmcateer.me",
year = "2019",
url = "https://matthewmcateer.me/blog/a-quick-intro-to-bayesian-neural-networks/"
}
If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me]
and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.
See you in the next post 😄