Exploring DeepMind's PonderNet

How to get a neural network to pause and think more about an answer

DeepMind recently released PonderNet, a technique that allows a neural network to adapt its computational complexity to the complexity of the input sample.

What is PonderNet?

What is going on under the Hood?

There has been plenty of research (including by DeepMind itself) on the versatility of RNNs for building networks that can perform complex tasks; this was the core of the Neural Turing Machine paper. The main advance of PonderNet is to learn just how many RNN steps are needed for a given input: for simpler problems, the network can take fewer steps through the RNN. All of this is trained with end-to-end gradient descent.

PonderNet has a step function of the form

$$\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$$

where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step. $s$ can be any neural network (e.g. LSTM, MLP, GRU, attention layer).

The unconditioned probability of halting at step $n$ is then

$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$$

that is, the probability of not having halted at any of the previous steps and halting at step $n$.

During inference, we halt by sampling based on the halting probability $\lambda_n$ and take the prediction at the halting step, $\hat{y}_n$, as the final output. During training, we get the predictions from all the steps, calculate a loss for each of them, and then take the weighted average of these losses with the probabilities of halting at each step, $p_n$, as the weights. The step function is applied up to a maximum number of steps, denoted by $N$.

The overall loss of PonderNet is

$$L = L_{Rec} + \beta L_{Reg}$$
$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$
$$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

$\mathop{KL}$ is the Kullback–Leibler divergence. $p_G$ is the geometric distribution parameterized by $\lambda_p$; $\lambda_p$ has nothing to do with $\lambda_n$, we are just sticking to the same notation as the paper.

$$\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$$

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. it promotes exploration.
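To make the halting distribution concrete, here is a minimal sketch (the variable names and example values are mine) of how the $p_n$ follow from per-step $\lambda_n$ values:

```python
import torch

# Hypothetical per-step halting probabilities lambda_n for N = 3 steps;
# lambda_N is forced to 1 so that the p_n sum to 1
lambdas = torch.tensor([0.2, 0.5, 1.0])

# prod_{j=1}^{n-1} (1 - lambda_j): probability of not having halted before step n
not_halted_before = torch.cat([torch.ones(1), torch.cumprod(1 - lambdas, dim=0)[:-1]])

# p_n = lambda_n * prod_{j<n} (1 - lambda_j)
p = lambdas * not_halted_before

print(p)        # tensor([0.2000, 0.4000, 0.4000])
print(p.sum())  # tensor(1.)
```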

Ponder Parity with GRU

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task, where the input is a vector of n_elems elements.

Each element of the vector is either 0, 1 or -1, and the output is the parity: a binary value that is true if the number of 1s is odd and false otherwise. The prediction of the model is the log probability of the parity being $1$.
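As a concrete illustration, here is a sketch of how a random sample for this task could be generated (the function is my own, not the post's dataset code):

```python
import torch

def make_parity_sample(n_elems: int):
    """Random vector with elements in {-1, 0, 1} and its parity label."""
    x = torch.zeros(n_elems)
    # Pick how many positions are non-zero, fill them with random +/-1, then shuffle
    n_non_zero = torch.randint(1, n_elems + 1, (1,)).item()
    x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)).float() * 2 - 1
    x = x[torch.randperm(n_elems)]
    # Parity is 1 if the number of 1s is odd, else 0
    y = (x == 1.).sum() % 2
    return x, y
```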

  • n_elems is the number of elements in the input vector
  • n_hidden is the state vector size of the GRU
  • max_steps is the maximum number of steps $N$

The model consists of:

  • A GRU cell as the step function for the state update: $h_{n+1} = s_h(x, h_n)$
  • An output layer for the prediction: $\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we went with this for simplicity.
  • A layer for the halting probability: $\lambda_n = s_\lambda(h_n)$
  • An option to set during inference so that computation is actually halted at inference time
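A minimal sketch of how these components could be set up in PyTorch (the class and attribute names below are my own choices, not necessarily those of the original implementation):

```python
import torch
from torch import nn

class ParityPonderGRU(nn.Module):
    """PonderNet with a GRU cell as the step function, for the parity task (sketch)."""

    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
        super().__init__()
        self.max_steps = max_steps
        self.n_hidden = n_hidden

        # Step function s_h: h_{n+1} = s_h(x, h_n)
        self.gru = nn.GRUCell(n_elems, n_hidden)
        # Prediction head s_y: y_hat_n = s_y(h_n), the log-odds of the parity being 1
        self.output_layer = nn.Linear(n_hidden, 1)
        # Halting head s_lambda: lambda_n = sigmoid(s_lambda(h_n))
        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()
        # If True, actually stop iterating at inference time once every sample has halted
        self.is_halt = False
```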

  • x is the input of shape [batch_size, n_elems]

This outputs a tuple of four tensors:

  1. $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size] - the log probabilities of the parity being $1$
  3. $p_m$ of shape [batch_size]
  4. $\hat{y}_m$ of shape [batch_size], where the computation was halted at step $m$

The forward pass proceeds as follows. We get the initial state $h_1 = s_h(x)$, create lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$, keep a running product $\prod_{j=1}^{n-1} (1 - \lambda_j)$ (the probability of not having halted at any of the previous steps), a vector to track which samples have halted computation, and $p_m$ and $\hat{y}_m$ for the step $m$ at which each sample's computation was halted.

We then iterate for up to $N$ steps. At each step $n$:

  • Compute the halting probability $\lambda_n = s_\lambda(h_n)$, with $\lambda_N = 1$ forced for the last step
  • Compute the prediction $\hat{y}_n = s_y(h_n)$
  • Compute $p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$ and update the running product $\prod_{j=1}^{n} (1 - \lambda_j)$
  • Halt based on the halting probability $\lambda_n$
  • Collect $p_n$ and $\hat{y}_n$
  • Update $p_m$ and $\hat{y}_m$ for the samples that halted at the current step $n$, and update the halted-samples vector
  • Get the next state $h_{n+1} = s_h(x, h_n)$
  • Stop the computation if all samples have halted (when the inference-time halting option is enabled)
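Continuing the ParityPonderGRU sketch above, the forward pass could be implemented roughly as follows:

```python
    def forward(self, x: torch.Tensor):
        # x: [batch_size, n_elems]
        batch_size = x.shape[0]

        # Initial state h_1 = s_h(x), starting from a zero state
        h = x.new_zeros((batch_size, self.n_hidden))
        h = self.gru(x, h)

        p = []  # p_1 ... p_N
        y = []  # y_hat_1 ... y_hat_N
        # prod_{j=1}^{n-1} (1 - lambda_j): probability of not having halted yet
        un_halted_prob = h.new_ones((batch_size,))

        # Which samples have halted, and their p_m and y_hat_m
        halted = h.new_zeros((batch_size,))
        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))

        for n in range(1, self.max_steps + 1):
            # lambda_N = 1 for the last step
            if n == self.max_steps:
                lambda_n = h.new_ones((batch_size,))
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
            # y_hat_n = s_y(h_n)
            y_n = self.output_layer(h)[:, 0]

            # p_n = lambda_n * prod_{j<n} (1 - lambda_j), then update the running product
            p_n = un_halted_prob * lambda_n
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Sample whether each (not yet halted) sample halts at this step
            halt = torch.bernoulli(lambda_n) * (1 - halted)

            # Collect p_n and y_hat_n
            p.append(p_n)
            y.append(y_n)

            # Record p_m and y_hat_m for samples that halted at this step
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt
            halted = halted + halt

            # Next state h_{n+1} = s_h(x, h_n)
            h = self.gru(x, h)

            # Optionally stop once all samples have halted
            if self.is_halt and halted.sum() == batch_size:
                break

        return torch.stack(p), torch.stack(y), p_m, y_m
```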

Loss functions

Reconstruction loss

$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$

$\mathcal{L}$ is the normal loss function between target $y$ and prediction $\hat{y}_n$.

  • loss_func is the loss function $\mathcal{L}$
  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  • y_hat is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size, ...]
  • y is the target of shape [batch_size, ...]

We accumulate the total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$ by iterating up to $N$: at each step we compute $p_n \mathcal{L}(y, \hat{y}_n)$ for each sample, take the mean over the batch, and add it to the total loss.
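A minimal sketch of this reconstruction loss in PyTorch (names are mine; loss_func is assumed to return per-sample losses, i.e. reduction='none'):

```python
import torch
from torch import nn

class ReconstructionLoss(nn.Module):
    """L_rec = sum_n p_n * L(y, y_hat_n) (sketch)."""

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        # Per-step loss L, e.g. nn.BCEWithLogitsLoss(reduction='none')
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
        # p: [N, batch_size], y_hat: [N, batch_size, ...], y: [batch_size, ...]
        total_loss = p.new_tensor(0.)
        for n in range(p.shape[0]):
            # p_n * L(y, y_hat_n) per sample, averaged over the batch
            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
            total_loss = total_loss + loss
        return total_loss
```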

Regularization loss

$$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p) \Big)$$

$\mathop{KL}$ is the Kullback–Leibler divergence. $p_G$ is the geometric distribution parameterized by $\lambda_p$; $\lambda_p$ has nothing to do with $\lambda_n$, we are just sticking to the same notation as the paper.

$$\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$$

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e. it promotes exploration.

  • lambda_p is $\lambda_p$ - the success probability of the geometric distribution
  • max_steps is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$

To pre-compute $p_G(\lambda_p)$, we create an empty vector, keep a running value of $(1 - \lambda_p)^k$, and iterate up to max_steps: at each step $k$ we set $\Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$ and update the running $(1 - \lambda_p)^k$. Finally, we save the resulting $\Pr_{p_G(\lambda_p)}$ vector.
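A minimal sketch of this pre-computation (again, the class and attribute names are my own):

```python
import torch
from torch import nn

class RegularizationLoss(nn.Module):
    """L_reg = KL(p || p_G(lambda_p)) (sketch)."""

    def __init__(self, lambda_p: float, max_steps: int = 1_000):
        super().__init__()
        # Pre-compute the geometric distribution Pr(X = k) = (1 - lambda_p)^k * lambda_p
        p_g = torch.zeros((max_steps,))
        not_halted = 1.
        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)
        # Store as a non-trainable parameter so it moves with the module (e.g. to GPU)
        self.p_g = nn.Parameter(p_g, requires_grad=False)
        # PyTorch's KL-divergence takes log probabilities as its first argument
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
```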

KL-divergence loss

  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]

We transpose p to [batch_size, N], take $\Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension, and then calculate the KL-divergence. The PyTorch KL-divergence implementation accepts log probabilities.
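Continuing the RegularizationLoss sketch above, the KL-divergence computation could look like this:

```python
    def forward(self, p: torch.Tensor):
        # p: [N, batch_size] -> [batch_size, N]
        p = p.transpose(0, 1)
        # Pr_{p_G(lambda_p)} up to N, expanded across the batch dimension
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
        # nn.KLDivLoss expects log probabilities as its first argument
        return self.kl_div(p.log(), p_g)
```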

Training
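The post's training code isn't reproduced here, but a hypothetical training step tying the sketches above together might look like this (the hyperparameters and the use of BCEWithLogitsLoss are assumptions for illustration):

```python
import torch
from torch import nn

# Assumed hyperparameters, for illustration only
n_elems, n_hidden, max_steps = 8, 64, 20
lambda_p, beta = 0.2, 0.01

model = ParityPonderGRU(n_elems, n_hidden, max_steps)
loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none'))
loss_reg = RegularizationLoss(lambda_p, max_steps)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    """One gradient step; x is [batch_size, n_elems], y is [batch_size] of 0./1. floats."""
    p, y_hat, _, _ = model(x)
    # L = L_rec + beta * L_reg
    loss = loss_rec(p, y_hat, y) + beta * loss_reg(p)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```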

References

  • Andrea Banino, Jan Balaguer, Charles Blundell. "PonderNet: Learning to Ponder." arXiv:2107.05407, 2021.
  • Alex Graves, Greg Wayne, Ivo Danihelka. "Neural Turing Machines." arXiv:1410.5401, 2014.

Cited as:

@article{mcateer2021dpe,
    title = "Exploring DeepMind's PonderNet in JAX",
    author = "McAteer, Matthew",
    journal = "matthewmcateer.me",
    year = "2021",
    url = "https://matthewmcateer.me/blog/deepmind-pondernet/"
}

If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
