# Exploring DeepMind's PonderNet

## How to get a neural network to pause and think more about an answer


DeepMind recently released PonderNet, a technique that allows a neural network to adapt its computational complexity to the complexity of the input sample.

## What is PonderNet?

PonderNet is a neural network architecture that learns to adapt the amount of computation it performs to the complexity of the problem at hand. Rather than running a fixed number of layers or recurrent steps for every input, it learns a halting policy end-to-end, trading off accuracy against computational cost.

## What is going on under the hood?

There has been plenty of research (including by DeepMind itself) on the versatility of RNNs in creating networks that can perform complex tasks. This was the core of the Neural Turing Machine paper. The main advance of PonderNet is to tune just how many RNN cycles are needed for a given input. For simpler problems, the network is able to take fewer steps through the RNN. All of this is done with end-to-end gradient descent.

PonderNet has a step function of the form $\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$ where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step. $s$ can be any neural network (e.g. an LSTM, MLP, GRU, or attention layer).

The unconditioned probability of halting at step $n$ is then

$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$$

That is, the probability of not having halted at any of the previous steps, multiplied by the probability of halting at step $n$.

During inference, we halt by sampling based on the halting probability $\lambda_n$ and take the prediction at the halting layer, $\hat{y}_n$, as the final output. During training, we get the predictions from all the layers, calculate the loss for each of them, and then take the weighted average of the losses based on the probability of halting at each layer, $p_n$. The step function is applied for a maximum number of steps, denoted by $N$. The overall loss of PonderNet is
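As a concrete sanity check of the halting distribution, here is a small pure-Python sketch; the $\lambda_n$ values are made up for illustration, with the last one forced to $1$:

```python
# Hypothetical per-step halting probabilities lambda_1 .. lambda_4.
# The last step is forced to halt (lambda_N = 1) so the p_n sum to 1.
lambdas = [0.1, 0.3, 0.5, 1.0]

# p_n = lambda_n * prod_{j<n} (1 - lambda_j)
p = []
not_halted = 1.0  # running product of (1 - lambda_j)
for lam in lambdas:
    p.append(lam * not_halted)
    not_halted *= 1.0 - lam

print(p)       # unconditioned halting probabilities p_1 .. p_4
print(sum(p))  # 1.0, because lambda_N = 1
```

Note how forcing $\lambda_N = 1$ makes $p_1 \dots p_N$ a proper probability distribution over steps.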

$$L = L_{Rec} + \beta L_{Reg}$$

$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$

$$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$, $\mathop{KL}$ is the Kullback–Leibler divergence, and $p_G$ is the geometric distribution parameterized by $\lambda_p$. *$\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.*

$$Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$$

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. it promotes exploration.

## Ponder Parity with GRU

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task, where the input is a vector of `n_elems`. Each element of the vector is either `0`, `1` or `-1`, and the output is the parity: a binary value that is true if the number of `1`s is odd and false otherwise. The prediction of the model is the log probability of the parity being $1$.

`n_elems` is the number of elements in the input vector, `n_hidden` is the state vector size of the GRU, and `max_steps` is the maximum number of steps $N$.
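The parity task described above can be sketched as a small data generator. This is a hypothetical generator written for illustration; the original post's dataset code may differ in details such as how many elements are non-zero:

```python
import random

def make_parity_sample(n_elems: int):
    """Generate one parity-task sample: a vector of 0/1/-1 values and a
    target that is 1 if the count of 1s is odd, else 0.
    (Illustrative; not the original post's dataset code.)"""
    # Pick how many positions are non-zero, then fill them with +/-1
    n_nonzero = random.randint(1, n_elems)
    x = [random.choice([1, -1]) for _ in range(n_nonzero)]
    x += [0] * (n_elems - n_nonzero)
    random.shuffle(x)
    # Target: parity of the number of 1s
    y = sum(1 for v in x if v == 1) % 2
    return x, y

x, y = make_parity_sample(8)
```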

The step function uses a GRU cell for the state update, $h_{n+1} = s_h(x, h_n)$, an output head for the prediction, $\hat{y}_n = s_y(h_n)$, and another head for the halting probability, $\lambda_n = s_\lambda(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we went with this for simplicity. There is also an option to set during inference so that computation is actually halted at inference time.

`x` is the input of shape `[batch_size, n_elems]`.

The forward pass outputs a tuple of four tensors:

- $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`
- $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size]` - the log probabilities of the parity being $1$
- $p_m$ of shape `[batch_size]`
- $\hat{y}_m$ of shape `[batch_size]`, where the computation was halted at step $m$

The forward computation proceeds as follows. We get the initial state $h_1 = s_h(x)$ and set up lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$. We also maintain the running product $\prod_{j=1}^{n-1} (1 - \lambda_j)$, a vector tracking which samples have halted their computation, and $p_m$ and $\hat{y}_m$ for the step $m$ at which computation was halted.

We then iterate for up to $N$ steps, forcing the halting probability $\lambda_N = 1$ at the last step. At each step $n$ we:

- compute $\lambda_n = s_\lambda(h_n)$ and $\hat{y}_n = s_y(h_n)$
- compute $p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$ and update the running product $\prod_{j=1}^{n-1} (1 - \lambda_j)$
- halt based on the halting probability $\lambda_n$
- collect $p_n$ and $\hat{y}_n$, and update $p_m$ and $\hat{y}_m$ for samples that halted at the current step $n$
- update the set of halted samples and get the next state $h_{n+1} = s_h(x, h_n)$
- stop the computation early if all samples have halted
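Putting the walkthrough above together, here is a sketch of how the model might look in PyTorch. The class name `ParityPonderGRU` and the `is_halt` flag are illustrative; `n_elems`, `n_hidden`, and `max_steps` follow the names used in the text:

```python
import torch
from torch import nn

class ParityPonderGRU(nn.Module):
    """Sketch of the PonderNet parity model: a GRU cell as the step
    function s, plus output and halting heads. Illustrative, not the
    original post's exact code."""

    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()
        self.max_steps = max_steps
        self.n_hidden = n_hidden
        self.gru = nn.GRUCell(n_elems, n_hidden)        # h_{n+1} = s_h(x, h_n)
        self.output_layer = nn.Linear(n_hidden, 1)      # y_hat_n = s_y(h_n)
        self.lambda_layer = nn.Linear(n_hidden, 1)      # lambda_n = s_lambda(h_n)
        self.lambda_prob = nn.Sigmoid()
        # Whether to actually stop the loop at inference time
        self.is_halt = False

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # Initial state h_1 = s_h(x)
        h = self.gru(x, x.new_zeros((batch_size, self.n_hidden)))

        p, y = [], []                                   # p_1..p_N, y_hat_1..y_hat_N
        un_halted_prob = h.new_ones((batch_size,))      # prod_{j<n} (1 - lambda_j)
        halted = h.new_zeros((batch_size,))             # which samples have halted
        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))

        for n in range(1, self.max_steps + 1):
            # Force lambda_N = 1 at the last step so the p_n sum to 1
            if n == self.max_steps:
                lambda_n = h.new_ones((batch_size,))
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
            y_n = self.output_layer(h)[:, 0]
            # p_n = lambda_n * prod_{j<n} (1 - lambda_j)
            p_n = un_halted_prob * lambda_n
            un_halted_prob = un_halted_prob * (1 - lambda_n)
            # Sample halting for samples that have not halted yet
            halt = torch.bernoulli(lambda_n) * (1 - halted)
            p.append(p_n)
            y.append(y_n)
            # Record p_m and y_hat_m for samples halting at this step
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt
            halted = halted + halt
            # Next state h_{n+1} = s_h(x, h_n)
            h = self.gru(x, h)
            # Optionally stop once every sample has halted
            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m
```

During training `is_halt` stays `False` so every step's prediction and halting probability are available for the weighted loss.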

## Loss functions

Reconstruction loss

$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

`loss_func` is the loss function $\mathcal{L}$, `p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`, `y_hat` is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape `[N, batch_size, ...]`, and `y` is the target of shape `[batch_size, ...]`.

We accumulate the total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$ by iterating up to $N$: at each step we compute $p_n \mathcal{L}(y, \hat{y}_n)$ for each sample, take the mean over the batch, and add it to the total loss.
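A sketch of how the reconstruction loss might be implemented, assuming `loss_func` is any per-sample PyTorch loss constructed with `reduction='none'`:

```python
import torch
from torch import nn

class ReconstructionLoss(nn.Module):
    """L_Rec = sum_n p_n * L(y, y_hat_n): the per-step losses weighted by
    the halting probabilities. Illustrative sketch."""

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        # loss_func must return a per-sample loss (reduction='none')
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        total_loss = p.new_tensor(0.)
        # Iterate up to N
        for n in range(p.shape[0]):
            # p_n * L(y, y_hat_n) for each sample, then the batch mean
            total_loss = total_loss + (p[n] * self.loss_func(y_hat[n], y)).mean()
        return total_loss
```

For the parity task the natural choice of `loss_func` would be `nn.BCEWithLogitsLoss(reduction='none')`, since $\hat{y}_n$ is a log probability.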

Regularization loss
$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$
$\mathop{KL}$ is the Kullback–Leibler divergence and
$p_G$ is the geometric distribution parameterized by
$\lambda_p$. *$\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper.*
$Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.
The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities
for all steps; i.e. it promotes exploration.

`lambda_p` is $\lambda_p$, the success probability of the geometric distribution, and `max_steps` is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$.

To build the prior, we start with an empty vector for $p_G(\lambda_p)$ and a running value $(1 - \lambda_p)^k$, then iterate up to `max_steps`: at each step we compute $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$, update $(1 - \lambda_p)^k$, and save $Pr_{p_G(\lambda_p)}$.

KL-divergence loss

`p` is $p_1 \dots p_N$ in a tensor of shape `[N, batch_size]`. We transpose `p` to `[batch_size, N]`, get $Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension, and then calculate the KL-divergence.
*The PyTorch KL-divergence implementation accepts log probabilities.*
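A possible implementation of the regularization loss, pre-computing the geometric prior as described above (the class and buffer names are illustrative):

```python
import torch
from torch import nn

class RegularizationLoss(nn.Module):
    """L_Reg = KL(p || p_G(lambda_p)): push the halting distribution
    towards a geometric prior. Illustrative sketch."""

    def __init__(self, lambda_p: float, max_steps: int = 1000):
        super().__init__()
        # Pre-compute the geometric distribution p_G(lambda_p)
        p_g = torch.zeros((max_steps,))
        not_halted = 1.0  # running (1 - lambda_p)^k
        for k in range(max_steps):
            # Pr(X = k) = (1 - lambda_p)^k * lambda_p
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)
        self.register_buffer('p_g', p_g)
        # PyTorch's KL-divergence expects log probabilities as the input
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        # Transpose p from [N, batch_size] to [batch_size, N]
        p = p.transpose(0, 1)
        # Geometric prior up to N, expanded across the batch dimension
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
        return self.kl_div(p.log(), p_g)
```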

## Training
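The original training code is not reproduced here. Below is a minimal, hypothetical training step that wires the pieces together on random parity data, with the losses computed inline; the hyperparameter values (`lambda_p`, `beta`, learning rate, sizes) are made up for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)
n_elems, n_hidden, max_steps = 8, 16, 5
lambda_p, beta = 0.2, 0.01  # illustrative values

# Step function pieces: GRU state update, output head, halting head
gru = nn.GRUCell(n_elems, n_hidden)
out_head = nn.Linear(n_hidden, 1)
lam_head = nn.Linear(n_hidden, 1)
params = [*gru.parameters(), *out_head.parameters(), *lam_head.parameters()]
opt = torch.optim.Adam(params, lr=1e-3)
bce = nn.BCEWithLogitsLoss(reduction='none')

# Geometric prior p_G(lambda_p) up to max_steps
p_g = torch.tensor([(1 - lambda_p) ** k * lambda_p for k in range(max_steps)])

# One random batch of parity data
x = torch.randint(0, 2, (32, n_elems)).float()
y = x.sum(-1) % 2

# Unrolled forward pass, keeping every step's p_n and y_hat_n
h = gru(x)
un_halted = torch.ones(32)
p_list, y_list = [], []
for n in range(max_steps):
    # Force lambda_N = 1 at the last step
    lam = torch.sigmoid(lam_head(h))[:, 0] if n < max_steps - 1 else torch.ones(32)
    p_list.append(un_halted * lam)
    y_list.append(out_head(h)[:, 0])
    un_halted = un_halted * (1 - lam)
    h = gru(x, h)

p = torch.stack(p_list)  # [N, batch]
# L_Rec: probability-weighted per-step losses
l_rec = sum((p[n] * bce(y_list[n], y)).mean() for n in range(max_steps))
# L_Reg: KL(p || p_G), PyTorch's KLDivLoss takes log probabilities
l_reg = nn.KLDivLoss(reduction='batchmean')(p.t().log(), p_g[None, :].expand(32, -1))
loss = l_rec + beta * l_reg

opt.zero_grad()
loss.backward()
opt.step()
```

In a real run this step would be repeated over many batches, and accuracy would be evaluated with the sampled-halting forward pass rather than the full unroll.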

# References

- PonderNet in PyTorch
- Annotated Deep Learning Implementation
- DeepMind unveils PonderNet, just please don’t call it ‘pondering’

Cited as:

```
@article{mcateer2021dpe,
title = "Exploring DeepMind's PonderNet in JAX",
author = "McAteer, Matthew",
journal = "matthewmcateer.me",
year = "2021",
url = "https://matthewmcateer.me/blog/deepmind-pondernet/"
}
```

*If you notice mistakes and errors in this post, don’t hesitate to contact me at [`contact at matthewmcateer dot me`] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there. See you in the next post 😄*