Using TPUs in Google Colab (properly)
A simple way to use TPUs with minimal hardware optimization
Early on in your machine learning journey you might be content with your lone GTX 1070. At some point, you may find out about Google’s Tensor Processing Units (TPUs) and get excited.

Many who reach this stage might then become confused after reading more on the subject of TPUs.
- “Are these just re-marketed GPUs?”
- “How do I actually use them?”
- “Why is my model not training faster than it would with just a GPU or CPU?”
- “How much will this cost me?”
I’d heard this exact train of thought in multiple conversations with other ML engineers, so I figured I’d put together a quick guide to getting from 0 to 1 (on TPUs).
There are plenty of articles out there detailing the differences between TPUs and GPUs. One of the main adjustments that can take some getting used to is that you need to specify details of the hardware yourself. Many parts of TensorFlow can automatically grow to fill available GPU resources, but it’s not that simple with TPUs. As such, I’ve made sure to emphasize which parts of your tf.keras model-building workflow will be different when using one of the free TPUs on Google Colab.
If you’re interested in trying the code for yourself, you can follow along in the full Colab Notebook right here.
First steps
When you first enter the Colab, you want to make sure you specify the runtime environment. Go to Runtime, click “Change Runtime Type”, and set the Hardware accelerator to “TPU”.
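Once the runtime reconnects, it’s worth sanity-checking that Colab actually gave you a TPU before going any further. Colab’s TPU runtime exposes the TPU’s address through the COLAB_TPU_ADDR environment variable, so a minimal check looks like this:
import os

# If this assertion fails, double-check the runtime type you selected above.
assert 'COLAB_TPU_ADDR' in os.environ, 'Not connected to a TPU runtime'
print('TPU address:', os.environ['COLAB_TPU_ADDR'])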

First, let’s set up our model. We follow the usual imports for setting up our tf.keras model training. It’s important to note that TPUs deliver the biggest performance increases with tf.keras, less so with the original keras.
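For reference, the imports for this walkthrough are minimal (this sketch assumes the TensorFlow 1.x environment Colab provided at the time, and that X_tr, Y_tr, X_val, and Y_val already hold your one-hot-encoded training and validation data):
import numpy as np
import tensorflow as tf

# X_tr, Y_tr, X_val, Y_val are assumed to already be loaded as numpy arrays.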
Other than that, we can use just about any model we want, but for demonstration purposes we’ll use a large, out-of-the-box model. VGG16 is not the most up-to-date classification architecture, but it is large enough that we should see some dramatic differences in performance (if there is an improvement).

def make_model(batch_size=None):
    # The (fixed) batch size gets baked into the input layer for the TPU.
    inputs = tf.keras.layers.Input(shape=X_tr.shape[1:], batch_size=batch_size)
    return tf.keras.applications.VGG16(weights=None, input_tensor=inputs, classes=1001)

model = make_model(batch_size=batch_size)  # batch_size is defined in the training step below
model.compile(loss='categorical_crossentropy',
              optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
              metrics=['accuracy'])
You just specified an incredibly deep neural network in under ten lines of code.
Configuring the TPU
Now for the part you came to see: in order to set up the TPUs, you want to use the tf.contrib.tpu.keras_to_tpu_model function to convert your tf.keras model to an equivalent TPU version (again, it’s important that this is a tf.keras model and not an original keras model).
import os

# Colab exposes the TPU's gRPC address through an environment variable.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
strategy = tf.contrib.tpu.TPUDistributionStrategy(
    tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER))
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
tpu_model.summary()
Now you can actually use the TPU to fit the model with the regular .fit() method. It’s important to note that the batch_size passed to .fit() is equal to the per-core batch size times the number of TPU cores (which is 8). This is a crucial step to keep in mind, or else your model training will be very … anticlimactic.
tpu_number = 8  # a Colab TPU has 8 cores
batch_size = 128 * tpu_number
history = tpu_model.fit(X_tr, Y_tr,
                        epochs=100,
                        batch_size=batch_size,
                        validation_split=0.2)
tpu_model.evaluate(X_val, Y_val, batch_size=batch_size)
tpu_model.save_weights('./our_tpu_model.h5')
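One caveat: files written to Colab’s local disk disappear when the session ends. If you want to keep the trained weights around, you can pull them down to your own machine with Colab’s files helper (a small, Colab-specific aside):
from google.colab import files

# Download the saved weights from the Colab VM to your local machine.
files.download('./our_tpu_model.h5')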
Using the model on GPUs and CPUs
Now that the model has been trained and we have saved the weights, we probably want to query the model on hardware that’s less expensive and scarce than a TPU. We can take the make_model() function from before and use it to build a version of our model that can evaluate data that doesn’t necessarily line up with a fixed batch size.
evaluation_model = make_model(batch_size=None)
evaluation_model.load_weights('./our_tpu_model.h5')
evaluation_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Input (InputLayer)           (None, 500)               0
_________________________________________________________________
Embedding (Embedding)        (None, 500, 128)          1280000
_________________________________________________________________
LSTM (LSTM)                  (None, 32)                20608
_________________________________________________________________
Output (Dense)               (None, 1)                 33
=================================================================
=================================================================
Now you’re free to call evaluation_model.evaluate() for evaluation, evaluation_model.fit() for transfer learning and fine-tuning, and even evaluation_model.loss, evaluation_model.input, and evaluation_model.output if you want to use just pieces of the trained Keras model.
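For instance, here’s a minimal sketch of the fine-tuning case (new_X and new_Y are hypothetical stand-ins for your own labeled data):
# Freeze everything except the final classification layer, then fine-tune.
# new_X and new_Y are hypothetical placeholders for your own dataset.
for layer in evaluation_model.layers[:-1]:
    layer.trainable = False
evaluation_model.compile(loss='categorical_crossentropy',
                         optimizer=tf.train.AdamOptimizer(learning_rate=1e-4),
                         metrics=['accuracy'])
evaluation_model.fit(new_X, new_Y, epochs=5, batch_size=32)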
Next steps
This was obviously an incredibly minimal tutorial for TPU use. The free TPUs on Google Colab are pretty exciting when you first find out how to make the most of them. Of course, at some point, you may decide you want to train models outside of Google Colab, or use multiple TPUs.
For some of you reading this, this tutorial might be all you needed. If you’re trying to build much larger models that go beyond the limits of Google Colab’s RAM and Disk caps, I’d recommend the following resources:
- Google Cloud TPU Documentation - This will go far more in-depth than this simple tutorial.
- Google Cloud TPU Performance Guide - Also great if you want to get performance gains like the ones I described, but on non-standard models.
- Google Cloud TPU Troubleshooting guide - You probably don’t want to skip this.
- XLA (Accelerated Linear Algebra) Overview - Just what XLA is, and when’s a good time to use it.
- Using TPUs with Kubernetes - For training many different models (like doing rapid-hyperparameter-optimization), or training just one really big model.
- Google Cloud TPU Pricing - Cost breakdown for TPU usage (pricing differs between on-demand and preemptible TPUs).
- PyTorch/XLA - A repo of PyTorch enthusiasts working on making it possible to train your PyTorch models on TPUs. It’s not quite ready for primetime, but if you’re a PyTorch fan, I’d recommend paying attention to this one.
Cited as:
@article{mcateer2019tpucolab,
  title   = "Using TPUs in Google Colab (properly)",
  author  = "McAteer, Matthew",
  journal = "matthewmcateer.me",
  year    = "2019",
  url     = "https://matthewmcateer.me/blog/using-tpus-in-google-colab/"
}
If you notice mistakes or errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.
See you in the next post 😄