# Exploring DeepMind's PonderNet

## How to get a neural network to pause and think more about an answer

DeepMind recently released PonderNet, a technique that lets a neural network adapt its computational complexity to the complexity of the input sample.

## What is going on under the Hood?

There has been plenty of research (including by DeepMind itself) on the versatility of RNNs in creating networks that can perform complex tasks. This was the core of the Neural Turing Machine paper. The main advance of PonderNet is to tune just how many RNN cycles are needed for a given input. For simpler problems, the network is able to take fewer steps through the RNN. All of this is done with end-to-end gradient descent.

PonderNet has a step function of the form $\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$ where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step. $s$ can be any neural network (e.g. an LSTM, MLP, GRU, or attention layer).

The unconditioned probability of halting at step $n$ is then $p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$, i.e. the probability of not having halted at any of the previous steps and halting at step $n$.

During inference, we halt by sampling based on the halting probability $\lambda_n$ and take the prediction at the halting step, $\hat{y}_n$, as the final output. During training, we get the predictions from all the steps, calculate a loss for each of them, and then take the weighted average of these losses based on the probabilities of halting at each step, $p_n$. The step function is applied up to a maximum number of steps, denoted by $N$. The overall loss of PonderNet is
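As a quick sanity check, the halting distribution can be unrolled in a few lines of plain Python; the $\lambda_n$ values below are made up purely for illustration:

```python
# Unroll p_n = lambda_n * prod_{j=1}^{n-1} (1 - lambda_j) for made-up lambdas.
lambdas = [0.2, 0.5, 1.0]  # lambda_N = 1 at the last step

p = []
not_halted = 1.0  # running product prod_{j=1}^{n-1} (1 - lambda_j)
for lam in lambdas:
    p.append(not_halted * lam)
    not_halted *= 1.0 - lam

print(p)       # the unconditioned halting probabilities p_1, p_2, p_3
print(sum(p))  # sums to 1 because lambda_N = 1
```

Forcing $\lambda_N = 1$ at the last step is what makes the $p_n$ a valid probability distribution over steps.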

$L = L_{Rec} + \beta L_{Reg}$

$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$

$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

$\mathop{KL}$ is the Kullback–Leibler divergence. $p_G$ is the geometric distribution parameterized by $\lambda_p$, i.e. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$. ($\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.) The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero halting probabilities for all steps; i.e. it promotes exploration.
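For intuition, here is what that geometric prior looks like for $\lambda_p = 0.5$, sketched in plain Python:

```python
# Geometric prior Pr(X = k) = (1 - lambda_p)^k * lambda_p, for lambda_p = 0.5.
lambda_p = 0.5
p_g = [(1.0 - lambda_p) ** k * lambda_p for k in range(5)]

print(p_g)  # [0.5, 0.25, 0.125, 0.0625, 0.03125]
# The prior decays geometrically, biasing the network towards
# roughly 1 / lambda_p = 2 steps.
```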

## Ponder Parity with GRU

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task where the input is a vector of n_elems.

Each element of the vector is either 0, 1, or -1, and the output is the parity: a binary value that is true if the number of 1s is odd and false otherwise. The prediction of the model is the log probability of the parity being $1$.

• n_elems is the number of elements in the input vector
• n_hidden is the state vector size of the GRU
• max_steps is the maximum number of steps $N$

The step function is made up of three components, plus an inference-time flag:

• A GRU cell for the state transition: $h_{n+1} = s_h(x, h_n)$
• An output layer for the prediction: $\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we went with this for simplicity.
• A halting-probability layer: $\lambda_n = s_\lambda(h_n)$
• An option to set during inference so that computation is actually halted at inference time
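A minimal sketch of these three components in PyTorch; the sizes and variable names here are illustrative, not taken from the original code:

```python
import torch
from torch import nn

n_elems, n_hidden = 8, 16

# s_h: a GRU cell for the state transition h_{n+1} = s_h(x, h_n)
s_h = nn.GRUCell(n_elems, n_hidden)
# s_y: the prediction head, producing the parity logit
s_y = nn.Linear(n_hidden, 1)
# s_lambda: the halting-probability head, squashed into (0, 1)
s_lambda = nn.Sequential(nn.Linear(n_hidden, 1), nn.Sigmoid())

x = torch.randn(4, n_elems)           # a dummy batch
h = s_h(x, torch.zeros(4, n_hidden))  # one state-transition step
y_hat, lam = s_y(h)[:, 0], s_lambda(h)[:, 0]
print(y_hat.shape, lam.shape)         # one logit and one lambda per sample
```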

• x is the input of shape [batch_size, n_elems]

This outputs a tuple of four tensors:

1. $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size] - the log probabilities of the parity being $1$
3. $p_m$ of shape [batch_size] - $p_n$ at the step $m$ where the computation was halted
4. $\hat{y}_m$ of shape [batch_size] - the prediction at the step $m$ where the computation was halted

The forward pass proceeds as follows:

1. Get the initial state $h_1 = s_h(x)$.
2. Initialize lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$, a running product $\prod_{j=1}^{n-1} (1 - \lambda_j)$, a vector to maintain which samples have halted computation, and $p_m$ and $\hat{y}_m$ for the step $m$ where the computation was halted.
3. Iterate for $N$ steps. At each step $n$:

• Compute $\lambda_n = s_\lambda(h_n)$, setting the halting probability $\lambda_N = 1$ for the last step.
• Compute $\hat{y}_n = s_y(h_n)$.
• Compute $p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$ and update the running product $\prod_{j=1}^{n-1} (1 - \lambda_j)$.
• Decide whether to halt by sampling based on the halting probability $\lambda_n$.
• Collect $p_n$ and $\hat{y}_n$; update $p_m$ and $\hat{y}_m$ for the samples that halted at the current step $n$; and mark those samples as halted.
• Get the next state $h_{n+1} = s_h(x, h_n)$.
• Stop the computation early (if the inference-halting option is set) once all samples have halted.
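Putting those steps together, here is a hedged sketch of the full model in PyTorch; the class and attribute names (`ParityPonderGRU`, `is_halt`) are my own choices, not necessarily those of the original code:

```python
import torch
from torch import nn


class ParityPonderGRU(nn.Module):
    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()
        self.max_steps = max_steps
        self.n_hidden = n_hidden
        # s_h: GRU cell for the state transition h_{n+1} = s_h(x, h_n)
        self.gru = nn.GRUCell(n_elems, n_hidden)
        # s_y: prediction head, outputs the logit of the parity being 1
        self.output_layer = nn.Linear(n_hidden, 1)
        # s_lambda: halting-probability head
        self.lambda_layer = nn.Sequential(nn.Linear(n_hidden, 1), nn.Sigmoid())
        # whether to actually stop the loop at inference time
        self.is_halt = False

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # initial state h_1 = s_h(x), starting from a zero state
        h = x.new_zeros((batch_size, self.n_hidden))
        h = self.gru(x, h)

        p, y = [], []                                # p_n and y_hat_n per step
        un_halted_prob = h.new_ones((batch_size,))   # prod_{j<n} (1 - lambda_j)
        halted = h.new_zeros((batch_size,))          # 1 if the sample has halted
        p_m = h.new_zeros((batch_size,))             # p_n at the halting step
        y_m = h.new_zeros((batch_size,))             # prediction at the halting step

        for n in range(1, self.max_steps + 1):
            # lambda_N = 1 at the last step, so the p_n sum to one
            if n == self.max_steps:
                lambda_n = h.new_ones((batch_size,))
            else:
                lambda_n = self.lambda_layer(h)[:, 0]
            y_n = self.output_layer(h)[:, 0]
            # p_n = lambda_n * prod_{j<n} (1 - lambda_j), then update the product
            p_n = un_halted_prob * lambda_n
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # sample whether to halt, only for samples not yet halted
            halt = torch.bernoulli(lambda_n) * (1 - halted)

            p.append(p_n)
            y.append(y_n)
            # record p_m and y_hat_m for samples that halted at this step
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt
            halted = halted + halt

            # next state h_{n+1} = s_h(x, h_n)
            h = self.gru(x, h)

            # stop early at inference time once every sample has halted
            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m
```

During training `is_halt` stays off, so all $N$ steps run and every $p_n$, $\hat{y}_n$ is available for the weighted loss.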

## Loss functions

Reconstruction loss

$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

• loss_func is the loss function $\mathcal{L}$
• p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
• y_hat is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size, ...]
• y is the target of shape [batch_size, ...]

We compute the total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$ by iterating up to $N$ and accumulating $p_n \mathcal{L}(y, \hat{y}_n)$ for each sample, averaged over the batch.
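A minimal sketch of the reconstruction loss in PyTorch, assuming the per-step loss function does not reduce over the batch (e.g. `reduction='none'`); the class name is illustrative:

```python
import torch
from torch import nn


class ReconstructionLoss(nn.Module):
    def __init__(self, loss_func: nn.Module):
        super().__init__()
        # per-step loss L(y, y_hat_n); must keep per-sample values
        # (e.g. nn.BCEWithLogitsLoss(reduction='none'))
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        """p: [N, batch_size], y_hat: [N, batch_size, ...], y: [batch_size, ...]"""
        total = p.new_tensor(0.0)
        for n in range(p.shape[0]):
            # p_n * L(y, y_hat_n) for each sample, then the mean of them
            total = total + (p[n] * self.loss_func(y_hat[n], y)).mean()
        return total
```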

Regularization loss: $L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$, where $\mathop{KL}$ is the Kullback–Leibler divergence and $p_G$ is the geometric distribution parameterized by $\lambda_p$, i.e. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$. ($\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.) The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero halting probabilities for all steps; i.e. it promotes exploration.

• lambda_p is $\lambda_p$ - the success probability of geometric distribution
• max_steps is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$

To pre-compute $p_G(\lambda_p)$, we start with an empty vector and a running factor $(1 - \lambda_p)^k$ initialized to $1$. Iterating up to max_steps, at each step $k$ we save $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$ and update the running factor $(1 - \lambda_p)^k$.

KL-divergence loss

• p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]

We transpose p to [batch_size, N], take $Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension, and then calculate the KL-divergence. Note that the PyTorch KL-divergence implementation accepts log probabilities as input.
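A hedged sketch of the regularization loss in PyTorch; the class name is illustrative, and the `kl_div(p.log(), p_g)` call follows the pattern of public PonderNet implementations (PyTorch's `nn.KLDivLoss` expects its first argument as log probabilities):

```python
import torch
from torch import nn


class RegularizationLoss(nn.Module):
    def __init__(self, lambda_p: float, max_steps: int):
        super().__init__()
        # pre-compute the geometric prior Pr_{p_G(lambda_p)}(X = k)
        p_g = torch.zeros((max_steps,))
        not_halted = 1.0  # running factor (1 - lambda_p)^k
        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p  # (1 - lambda_p)^k * lambda_p
            not_halted = not_halted * (1 - lambda_p)
        self.register_buffer('p_g', p_g)
        # PyTorch's KL-divergence takes the first argument as log probabilities
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        """p: [N, batch_size] -> scalar KL-divergence loss"""
        p = p.transpose(0, 1)                             # [batch_size, N]
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)    # prior across the batch
        return self.kl_div(p.log(), p_g)
```

When the step distribution `p` exactly matches the geometric prior, the loss is zero.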

## References

Cited as:

```bibtex
@article{mcateer2021dpe,
  title   = "Exploring DeepMind's PonderNet in JAX",
  author  = "McAteer, Matthew",
  journal = "matthewmcateer.me",
  year    = "2021",
  url     = "https://matthewmcateer.me/blog/deepmind-pondernet/"
}
```

If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
