Getting started with Attention for Classification

A quick guide on how to start using Attention in your NLP models

UPDATE 05/23/2020: If you’re looking to add Attention-based models like Transformers or even BERT, a recent Keras update has added more support for libraries from HuggingFace 🤗. The Keras documentation has a tutorial covering this. That being said, I highly recommend becoming familiar with how you would put together an attention mechanism from scratch, just as I recommend you do for any new machine learning tool.

With all the hype around attention, I realized there were far too few decent guides on how to get started. With that in mind, I present to you the “Hello World” of attention models: building text classification models in Keras that use an attention mechanism.

Step 1: Preparing the Dataset

For this guide we’ll use the standard IMDB dataset that contains the text of 50,000 movie reviews from the Internet Movie Database (basically Jeff Bezos’ Rotten Tomatoes competitor). The IMDB dataset usually comes pre-packaged with Keras. If we download it this way, we get a version that has already been preprocessed so that the sequences of words have been converted to sequences of integers, where each integer represents a specific word in a dictionary. However, for our purposes, we’re going to take the data directly from files of unprocessed reviews.

import re
import warnings

import nltk
import tensorflow as tf
from tensorflow import keras
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from tensorflow.keras.layers import (Bidirectional, Concatenate, Dense, Dropout,
                                     Embedding, Input, LSTM)
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

warnings.filterwarnings("ignore")
# Download the NLTK resources used for stop-word removal and lemmatization
nltk.download("stopwords")
nltk.download("wordnet")

It’s rarely the case that we get preprocessed numerical sequences representing text data. As such, we should pay attention to how our text is converted to the numbers that our attention model can understand.

For our raw text, we need to do some filtering before building the model. In addition to stripping out punctuation characters, we want to make everything lowercase and drop common stop words (such as “the” and “is”) that carry little sentiment. We also want to reduce words to their root form through lemmatization (e.g., so words like “jump” and “jumping” aren’t given completely different encodings).

import pandas as pd
df1 = pd.read_csv('labeledTrainData.tsv', delimiter="\t")
df1 = df1.drop(['id'], axis=1)
df2 = pd.read_csv('imdb_master.csv',encoding="latin-1")
df2 = df2.drop(['Unnamed: 0','type','file'],axis=1)
df2.columns = ["review","sentiment"]
df2 = df2[df2.sentiment != 'unsup']
df2['sentiment'] = df2['sentiment'].map({'pos': 1, 'neg': 0})
df = pd.concat([df1, df2]).reset_index(drop=True)
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
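def clean_text(text):
    # Strip punctuation (anything that is not a word character or whitespace)
    text = re.sub(r"[^\w\s]", "", text, flags=re.UNICODE)
    text = text.lower()
    # Lemmatize each token, first as a noun (the default), then as a verb
    text = [lemmatizer.lemmatize(token) for token in text.split(" ")]
    text = [lemmatizer.lemmatize(token, "v") for token in text]
    # Drop stop words
    text = [word for word in text if word not in stop_words]
    return " ".join(text)
df["Processed_Reviews"] = df["review"].apply(clean_text)
df.head()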
|   | review | sentiment | Processed_Reviews |
|---|---|---|---|
| 0 | With all this stuff going down at the moment w… | 1 | stuff go moment mj ive start listen music watc… |
| 1 | \"The Classic War of the Worlds\" by Timothy Hi… | 1 | classic war world timothy hines entertain film… |
| 2 | The film starts with a manager (Nicholas Bell)… | 0 | film start manager nicholas bell give welcome … |
| 3 | It must be assumed that those who praised this… | 0 | must assume praise film greatest film opera ev… |
| 4 | Superbly trashy and wondrously unpretentious 8… | 1 | superbly trashy wondrously unpretentious 80 ex… |
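The filtering steps in clean_text are easier to see on a single sentence. Below is a simplified sketch that keeps the punctuation stripping, lowercasing and stop-word removal but skips lemmatization so it runs without NLTK. The helper name simple_clean and the tiny TOY_STOP_WORDS set are hypothetical stand-ins for illustration, not NLTK’s actual English list.

```python
import re

# Hypothetical, tiny stop-word set (the real code uses NLTK's English list)
TOY_STOP_WORDS = {"this", "was", "the", "a", "an", "is", "and", "of"}

def simple_clean(text):
    """Simplified sketch of clean_text without lemmatization."""
    text = re.sub(r"[^\w\s]", "", text)  # strip punctuation
    text = text.lower()                  # normalize case
    # drop stop words
    tokens = [t for t in text.split() if t not in TOY_STOP_WORDS]
    return " ".join(tokens)

print(simple_clean("This movie, honestly, was GREAT!"))  # movie honestly great
```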

If you’ve done coding challenges with strings outside of machine learning, you can probably see how much easier this makes our task. Beyond processing and encoding words, we also need to make sure our sequences are properly padded: the model expects inputs of a fixed, pre-defined length, even though the reviews it will be fed vary in length. Our padding method fills a review with placeholder zeros if it is too short and cuts it off if it is too long. One reasonable choice for that fixed length is the mean sequence length.

df.Processed_Reviews.apply(lambda x: len(x.split(" "))).mean()

The built-in pad_sequences() function takes care of padding sequences. All that’s needed now is the maxlen argument (set here to our MAX_LEN constant), which determines the length of the output arrays. As mentioned before, sequences shorter than this length are padded and longer ones are trimmed. By default, pad_sequences does both at the start of the sequence (padding="pre", truncating="pre").
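To make the default behaviour concrete, here is a pure-Python sketch of what pad_sequences does with padding="pre" and truncating="pre". The function pad_sequences_sketch is a hypothetical illustration, not the Keras implementation (which also handles dtypes and returns a NumPy array).

```python
def pad_sequences_sketch(sequences, maxlen, value=0):
    """Mimics Keras pad_sequences defaults: padding="pre", truncating="pre"."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > maxlen:
            # truncating="pre": drop tokens from the start, keep the last maxlen
            seq = seq[len(seq) - maxlen:]
        # padding="pre": put the filler value in front of the tokens
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded

print(pad_sequences_sketch([[1, 2], [1, 2, 3, 4, 5]], maxlen=4))
# [[0, 0, 1, 2], [2, 3, 4, 5]]
```

Keeping the end of a long review (rather than the beginning) is the Keras default; pass truncating="post" to pad_sequences if you would rather keep the opening words.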

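MAX_FEATURES = 6000  # vocabulary size kept by the tokenizer (assumed value; the text does not state it)
tokenizer = Tokenizer(num_words=MAX_FEATURES)
# Build the word index from the training reviews before converting them to integers
tokenizer.fit_on_texts(df['Processed_Reviews'])
list_tokenized_train = tokenizer.texts_to_sequences(df['Processed_Reviews'])
MAX_LEN = 130  # Since our mean length is 128.5
X_train = pad_sequences(list_tokenized_train, maxlen=MAX_LEN)
y_train = df['sentiment']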

Step 2: Creating the Attention Layer

Our use of an attention layer solves a conundrum with using RNNs. We could simply use the final encoded state of a recurrent neural network for a prediction task. However, given the tendency of RNNs to forget relevant information from earlier steps of the sequence, this could lose some of the useful information encoded there. To keep that information, we could instead average all the encoded states the RNN outputs. But not all of these encoded states are equally valuable, so we use a weighted sum of them with learned weights (i.e., our Attention mechanism) to make our prediction.
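The three options in that paragraph (last state, plain average, weighted sum) can be compared on a toy example. This NumPy sketch uses made-up encoder outputs and hypothetical hand-picked relevance scores; in the real model the scores are learned. Note that uniform scores reduce the weighted sum to the plain average.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

# Made-up encoder outputs: 4 time steps, hidden size 3
states = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0],
                   [1.0, 1.0, 1.0]])

last_state = states[-1]                  # option 1: final state only
mean_state = states.mean(axis=0)         # option 2: every step weighted equally
scores = np.array([2.0, 0.0, 0.0, 0.0])  # hypothetical relevance scores
weights = softmax(scores)                # non-negative, sum to 1
weighted_state = weights @ states        # option 3: attention-style weighted sum

print(weights.round(3), weighted_state.round(3))
```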

class Attention(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying score to self.V
        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
        score = tf.nn.tanh(
            self.W1(features) + self.W2(hidden_with_time_axis))
        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

If this concept of using a weighted sum of a bunch of encodings sounds familiar, it should. We compute the attention weights by building a small fully connected neural network on top of each encoded state. This network has a single-unit final output layer that produces a score for each time step; a softmax across the time steps then turns these scores into the attention weights we assign.

An overview of the specific attention mechanism that we’re using, which is additive attention
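To check the shapes noted in the comments of the Attention class, here is a NumPy sketch of the same forward pass: two projections added together, a tanh, a single-unit projection, a softmax over the time axis, and a weighted sum. The random weight matrices are hypothetical stand-ins for the trained W1, W2 and V Dense layers, and the toy sizes are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, max_len, hidden_size, units = 2, 5, 8, 10

features = rng.normal(size=(batch, max_len, hidden_size))  # RNN output at every step
hidden = rng.normal(size=(batch, hidden_size))              # final hidden state

# Hypothetical random parameters standing in for the W1, W2 and V Dense layers
W1, b1 = rng.normal(size=(hidden_size, units)), np.zeros(units)
W2, b2 = rng.normal(size=(hidden_size, units)), np.zeros(units)
V, bv = rng.normal(size=(units, 1)), np.zeros(1)

def additive_attention(features, hidden):
    hidden_with_time_axis = hidden[:, np.newaxis, :]  # (batch, 1, hidden_size)
    # broadcasting adds the hidden projection to every time step
    score = np.tanh(features @ W1 + b1 + hidden_with_time_axis @ W2 + b2)  # (batch, max_len, units)
    logits = score @ V + bv                                      # (batch, max_len, 1)
    logits = logits - logits.max(axis=1, keepdims=True)          # stabilize the softmax
    weights = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over time
    context = (weights * features).sum(axis=1)                   # (batch, hidden_size)
    return context, weights

context_vector, attention_weights = additive_attention(features, hidden)
print(context_vector.shape, attention_weights.shape)  # (2, 8) (2, 5, 1)
```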

Our Attention layer is very simple: two dense projections (one of the encoded states, one of the final hidden state) are added together, passed through a tanh, and scored by a final single-unit dense layer. Of course, Attention is used in many applications in NLP (and beyond), and there are plenty of other, more specialized types. Lilian Weng gave a very concise overview of the history and different types of attention out there (summarized in the table below):

| Name | Alignment score function | Citation | Notes |
|---|---|---|---|
| Content-based attention | $\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \text{cosine}[\boldsymbol{s}_t, \boldsymbol{h}_i]$ | Graves, 2014 | |
| Additive | $\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \mathbf{v}_a^\top \tanh(\mathbf{W}_a[\boldsymbol{s}_t; \boldsymbol{h}_i])$ | Bahdanau, 2015 | Referred to as “concat” in Luong et al., 2015 and as “additive attention” in Vaswani et al., 2017. |
| Location-based | $\alpha_{t,i} = \text{softmax}(\mathbf{W}_a \boldsymbol{s}_t)$ | Luong, 2015 | Simplifies the softmax alignment to depend only on the target position. |
| General | $\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \boldsymbol{s}_t^\top\mathbf{W}_a\boldsymbol{h}_i$ | Luong, 2015 | $\mathbf{W}_a$ is a trainable weight matrix in the attention layer. |
| Dot-product | $\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \boldsymbol{s}_t^\top\boldsymbol{h}_i$ | Luong, 2015 | |
| Scaled dot-product | $\text{score}(\boldsymbol{s}_t, \boldsymbol{h}_i) = \frac{\boldsymbol{s}_t^\top\boldsymbol{h}_i}{\sqrt{n}}$ | Vaswani, 2017 | Very similar to dot-product attention except for the scaling factor $1/\sqrt{n}$, where $n$ is the dimension of the source hidden state. The scaling is motivated by the concern that when the input is large, the softmax function may have an extremely small gradient, which makes learning inefficient. |
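The score functions in the table are short enough to write out directly. This NumPy sketch implements each one for a single target state s and source state h; the function names are hypothetical. Our Attention class corresponds to the additive row: applying $\mathbf{W}_a$ to the concatenation $[\boldsymbol{s}_t; \boldsymbol{h}_i]$ is the same as applying two separate matrices to each part and adding the results, which is what W1(features) + W2(hidden) does (plus biases).

```python
import numpy as np

def dot_score(s, h):
    return float(s @ h)

def scaled_dot_score(s, h):
    return float(s @ h) / np.sqrt(h.shape[-1])  # n = dimension of the source hidden state

def general_score(s, h, W_a):
    return float(s @ W_a @ h)  # W_a: trainable (d, d) matrix

def additive_score(s, h, W_a, v_a):
    # W_a: (units, 2d) applied to the concatenation [s; h]; v_a: (units,)
    return float(v_a @ np.tanh(W_a @ np.concatenate([s, h])))

def cosine_score(s, h):
    return float(s @ h) / (np.linalg.norm(s) * np.linalg.norm(h))

s = np.array([1.0, 0.0, 1.0, 0.0])
h = np.array([1.0, 1.0, 1.0, 1.0])
print(dot_score(s, h), scaled_dot_score(s, h), round(cosine_score(s, h), 4))  # 2.0 1.0 0.7071
```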

But for now, all we need is the simple attention layer above, which is little more than a 3-layer multilayer perceptron (i.e., Bahdanau Attention).

Step 3: The Embedding Layer

Neural networks are compositions of linear-algebra operations and non-linear activation functions. To perform these computations on our input sentences, we must first embed each word as a vector of numbers. The main options are to use pre-trained embeddings such as Word2Vec or GloVe, or to initialize the embeddings randomly and learn them. For the sake of simplicity, we’re going to stick with random initialization.

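To perform this embedding, we use the Embedding layer from tensorflow.keras.layers, which holds a weight matrix with one row (of length EMBED_SIZE) for each of the MAX_FEATURES words in our vocabulary. The parameters of this matrix are then trained along with the rest of the graph.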

EMBED_SIZE = 128  # embedding dimension (assumed value; the text does not state it)
sequence_input = Input(shape=(MAX_LEN,), dtype="int32")
embedded_sequences = Embedding(MAX_FEATURES, EMBED_SIZE)(sequence_input)
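Under the hood, an embedding layer is just a lookup table: each integer word index selects one row of a trainable matrix. This NumPy sketch uses toy sizes (hypothetical, far smaller than a real vocabulary) and a small uniform initialization, similar in spirit to Keras’s default uniform initializer for Embedding.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_size = 50, 4  # toy sizes for illustration
# Randomly initialized embedding matrix: one trainable row per word index
embedding_matrix = rng.uniform(-0.05, 0.05, size=(vocab_size, embed_size))

batch = np.array([[0, 0, 3, 17],   # two pre-padded sequences of word indices
                  [5, 9, 9, 2]])
embedded = embedding_matrix[batch]  # lookup = row indexing
print(embedded.shape)  # (2, 4, 4): (batch, sequence length, embed_size)
```

The parameter count of the layer is simply vocabulary size times embedding size, which is why the embedding usually dominates the size of a small text model.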

Step 4: Our Bi-directional RNN

We will be using a bi-directional RNN instead of a vanilla unidirectional RNN. Despite the fancy name, this is simply the concatenation of two RNNs. One RNN processes the sequence from left to right (the “forward” RNN), while the other processes the sequence from right to left (the “backward” RNN). By using both directions, we get a more reliable encoding, as each word can be given the context of its neighbors on both sides (rather than just earlier in the sequence).

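RNN_CELL_SIZE = 32  # LSTM units per direction (assumed value; the text does not state it)
lstm = Bidirectional(LSTM(RNN_CELL_SIZE, return_sequences=True), name="bi_lstm_0")(embedded_sequences)

# Getting our LSTM outputs, plus the final hidden (h) and cell (c) states of each direction
(lstm, forward_h, forward_c, backward_h, backward_c) = Bidirectional(
    LSTM(RNN_CELL_SIZE, return_sequences=True, return_state=True), name="bi_lstm_1"
)(lstm)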
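The bidirectional idea can be sketched without Keras. Below, a toy tanh RNN (a deliberately simpler stand-in for the LSTM, with hypothetical random weights) is run once forward and once over the reversed sequence; the backward outputs are flipped back into the original time order (as Keras’s Bidirectional wrapper does for return_sequences=True) and concatenated with the forward ones, which matches the wrapper’s default merge_mode="concat".

```python
import numpy as np

rng = np.random.default_rng(0)
max_len, input_dim, cell_size = 6, 4, 3
x = rng.normal(size=(max_len, input_dim))  # one embedded sequence

def simple_rnn(x, W, U):
    """Toy tanh RNN (stand-in for an LSTM): hidden state at every step."""
    h = np.zeros(U.shape[0])
    outputs = []
    for x_t in x:
        h = np.tanh(W @ x_t + U @ h)
        outputs.append(h)
    return np.stack(outputs)

# Separate (hypothetical, randomly drawn) weights for each direction
W_f, U_f = rng.normal(size=(cell_size, input_dim)), rng.normal(size=(cell_size, cell_size))
W_b, U_b = rng.normal(size=(cell_size, input_dim)), rng.normal(size=(cell_size, cell_size))

forward = simple_rnn(x, W_f, U_f)               # reads left to right
backward = simple_rnn(x[::-1], W_b, U_b)[::-1]  # reads right to left, re-aligned in time
bi_outputs = np.concatenate([forward, backward], axis=-1)  # (max_len, 2 * cell_size)

forward_h = forward[-1]   # forward RNN's final state (after the last word)
backward_h = backward[0]  # backward RNN's final state (after the first word)
state_h = np.concatenate([forward_h, backward_h])
print(bi_outputs.shape, state_h.shape)  # (6, 6) (6,)
```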

Since our model uses a bi-directional RNN, we first concatenate the final hidden states of the two directions before computing the attention weights and applying the weighted sum.

state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])
context_vector, attention_weights = Attention(10)(lstm, state_h)
dense1 = Dense(20, activation="relu")(context_vector)
dropout = Dropout(0.05)(dense1)
output = Dense(1, activation="sigmoid")(dropout)
model = keras.Model(inputs=sequence_input, outputs=output)

The last layer is densely connected with a single output node. Using the sigmoid activation function, this value is a float between 0 and 1, representing a probability, or confidence level. We can easily print out a list of our layers in Keras.

# summarize layers
model.summary()

Less than 1 million parameters for an Attention-based text classifier. This will probably be on the far low end of most Attention models you see.
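The parameter count can be checked by hand. This sketch applies the standard formulas for Keras LSTM and Dense layers to the architecture above. The attention size (10) and the hidden Dense size (20) come from the code, but MAX_FEATURES = 6000, EMBED_SIZE = 128 and RNN_CELL_SIZE = 32 are assumed values; with other sizes the total will differ.

```python
def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias vector
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units  # weights + biases

# Assumed hyperparameters (not stated in the text)
MAX_FEATURES, EMBED_SIZE, RNN_CELL_SIZE = 6000, 128, 32
ATTN_UNITS = 10  # from Attention(10)

embedding = MAX_FEATURES * EMBED_SIZE                          # 768,000
bi_lstm_0 = 2 * lstm_params(EMBED_SIZE, RNN_CELL_SIZE)         # 41,216
bi_lstm_1 = 2 * lstm_params(2 * RNN_CELL_SIZE, RNN_CELL_SIZE)  # 24,832
# W1 and W2 both see 64-dim inputs (concatenated directions), V maps 10 -> 1
attention = 2 * dense_params(2 * RNN_CELL_SIZE, ATTN_UNITS) + dense_params(ATTN_UNITS, 1)
head = dense_params(2 * RNN_CELL_SIZE, 20) + dense_params(20, 1)
total = embedding + bi_lstm_0 + bi_lstm_1 + attention + head
print(total)  # 836680
```

Under these assumptions, over 90% of the parameters sit in the embedding matrix; the attention layer itself adds only about 1,300.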

Alternatively, we can build a much more aesthetically appealing graph of the connected layers.

keras.utils.plot_model(model, show_shapes=True, dpi=90)

Much easier than trying to set up Tensorboard within your notebook

Step 5: Compiling the Model

To actually train our model, we need to give it a loss function and an optimizer. We’ll use an out-of-the-box Adam optimizer. Since this is a binary classification problem and the model outputs a probability, we’ll use the standard binary_crossentropy loss function.
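Binary cross-entropy is the average negative log-likelihood of the true labels under the predicted probabilities, so confident wrong predictions are penalized much more heavily than confident right ones. A minimal NumPy sketch (the function name and epsilon clipping value are our own choices for illustration):

```python
import numpy as np

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob)))

print(round(binary_crossentropy([1, 0], [0.9, 0.1]), 4))  # 0.1054
print(round(binary_crossentropy([1], [0.1]), 4))          # 2.3026
```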

Of course, we can also go beyond just accuracy and loss. Here is a comprehensive look at all the metrics we can track during training:

  • T.P. (True Positives) - The raw count of positive items that are correctly classified as positive.
  • F.P. (False Positives) - The raw count of negative items that are incorrectly classified as positive.
  • T.N. (True Negatives) - The raw count of negative items that are correctly classified as negative.
  • F.N. (False Negatives) - The raw count of positive items that are incorrectly classified as negative.
  • Binary Accuracy - How often the predictions match the labels (out of two possible options, 0 or 1): (TP + TN) / (TP + FP + TN + FN).
  • Precision - The fraction of positive predictions that are actually positive: TP / (TP + FP).
  • Recall - The fraction of actual positives that the model catches: TP / (TP + FN).
  • AUC - The approximate area under the ROC curve, computed via a Riemann sum from the true positive, false positive, true negative, and false negative counts at a range of thresholds.
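The count-based metrics in the list above reduce to a few lines of NumPy. This sketch (binary_metrics is a hypothetical helper, using the same 0.5 threshold as the predictions later in this post) shows how each one is derived from predicted probabilities and labels.

```python
import numpy as np

def binary_metrics(y_true, y_prob, threshold=0.5):
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) > threshold
    tp = int(np.sum(y_pred & y_true))    # predicted positive, actually positive
    fp = int(np.sum(y_pred & ~y_true))   # predicted positive, actually negative
    tn = int(np.sum(~y_pred & ~y_true))  # predicted negative, actually negative
    fn = int(np.sum(~y_pred & y_true))   # predicted negative, actually positive
    return {
        "tp": tp, "fp": fp, "tn": tn, "fn": fn,
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }

print(binary_metrics([1, 1, 0, 0, 1], [0.9, 0.4, 0.2, 0.7, 0.8]))
```

As a sanity check against the training log in Step 6, the first epoch’s counts give precision 19882 / (19882 + 3309) ≈ 0.8573 and recall 19882 / (19882 + 5118) ≈ 0.7953, matching the logged values.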

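Note that the optimizer itself only minimizes the loss (binary cross-entropy). The metrics above are computed and reported alongside it, on both the training and validation data, but they do not drive the weight updates.

METRICS = [
    keras.metrics.TruePositives(name="tp"),
    keras.metrics.FalsePositives(name="fp"),
    keras.metrics.TrueNegatives(name="tn"),
    keras.metrics.FalseNegatives(name="fn"),
    keras.metrics.BinaryAccuracy(name="accuracy"),
    keras.metrics.Precision(name="precision"),
    keras.metrics.Recall(name="recall"),
    keras.metrics.AUC(name="auc"),
]
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=METRICS)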

Step 6: Training the Model

We’ll train our Attention model for 5 epochs in mini-batches of 100 samples. You read that right. Not 50 or 250, just 5 is all we need. Each epoch is one iteration over all samples in the X_train and y_train tensors. While training, we monitor the model’s loss and metrics on the 20 percent of samples held out as a validation set.

history = model.fit(X_train, y_train,
                    batch_size=100,
                    epochs=5,
                    validation_split=0.2)
Train on 60000 samples, validate on 15000 samples
Epoch 1/5
60000/60000 [==============================] - 409s 7ms/sample - loss: 0.3186 - tp: 19882.0000 - fp: 3309.0000 - tn: 31691.0000 - fn: 5118.0000 - accuracy: 0.8595 - precision: 0.8573 - recall: 0.7953 - auc: 0.9358 - val_loss: 0.2049 - val_tp: 11501.0000 - val_fp: 214.0000 - val_tn: 2286.0000 - val_fn: 999.0000 - val_accuracy: 0.9191 - val_precision: 0.9817 - val_recall: 0.9201 - val_auc: 0.9732
Epoch 2/5
60000/60000 [==============================] - 406s 7ms/sample - loss: 0.2153 - tp: 22233.0000 - fp: 2318.0000 - tn: 32682.0000 - fn: 2767.0000 - accuracy: 0.9153 - precision: 0.9056 - recall: 0.8893 - auc: 0.9703 - val_loss: 0.1705 - val_tp: 11641.0000 - val_fp: 164.0000 - val_tn: 2336.0000 - val_fn: 859.0000 - val_accuracy: 0.9318 - val_precision: 0.9861 - val_recall: 0.9313 - val_auc: 0.9818
Epoch 3/5
60000/60000 [==============================] - 405s 7ms/sample - loss: 0.1716 - tp: 22853.0000 - fp: 1755.0000 - tn: 33245.0000 - fn: 2147.0000 - accuracy: 0.9350 - precision: 0.9287 - recall: 0.9141 - auc: 0.9807 - val_loss: 0.1785 - val_tp: 11547.0000 - val_fp: 102.0000 - val_tn: 2398.0000 - val_fn: 953.0000 - val_accuracy: 0.9297 - val_precision: 0.9912 - val_recall: 0.9238 - val_auc: 0.9866
Epoch 4/5
60000/60000 [==============================] - 406s 7ms/sample - loss: 0.1378 - tp: 23333.0000 - fp: 1378.0000 - tn: 33622.0000 - fn: 1667.0000 - accuracy: 0.9492 - precision: 0.9442 - recall: 0.9333 - auc: 0.9871 - val_loss: 0.1635 - val_tp: 11635.0000 - val_fp: 46.0000 - val_tn: 2454.0000 - val_fn: 865.0000 - val_accuracy: 0.9393 - val_precision: 0.9961 - val_recall: 0.9308 - val_auc: 0.9919
Epoch 5/5
60000/60000 [==============================] - 407s 7ms/sample - loss: 0.1049 - tp: 23816.0000 - fp: 986.0000 - tn: 34014.0000 - fn: 1184.0000 - accuracy: 0.9638 - precision: 0.9602 - recall: 0.9526 - auc: 0.9916 - val_loss: 0.0723 - val_tp: 12289.0000 - val_fp: 109.0000 - val_tn: 2391.0000 - val_fn: 211.0000 - val_accuracy: 0.9787 - val_precision: 0.9912 - val_recall: 0.9831 - val_auc: 0.9947
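The sample counts in the log above follow directly from the training settings described at the start of this step. This small calculation spells them out; note that Keras takes the validation_split samples from the end of the arrays, before any shuffling.

```python
train_samples, val_samples = 60000, 15000  # from the log above
total = train_samples + val_samples        # 75,000 labeled reviews in X_train
val_fraction = val_samples / total         # matches validation_split=0.2

batch_size, epochs = 100, 5
steps_per_epoch = train_samples // batch_size  # gradient updates per epoch
total_updates = steps_per_epoch * epochs       # gradient updates over the whole run
print(total, val_fraction, steps_per_epoch, total_updates)  # 75000 0.2 600 3000
```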

Step 7: Evaluating the Model

Our model seems to have gotten some impressive results, and after only 5 training epochs, no less. Let’s take a closer look at our training progress and performance. Normally when collecting a training history, we would be fine with just two values: loss and accuracy. However, we went far beyond just those.

For evaluation purposes, we can generate a list of predictions on previously unseen test data (we can also reuse the preprocessing functions from earlier to easily convert the test data to a usable form).

# Loading the test dataset, and repeating the processing steps
df_test=pd.read_csv("testData.tsv",header=0, delimiter="\t", quoting=3)
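df_test["review"] = df_test["review"].apply(clean_text)
# The rating follows the underscore in the quoted id; ratings of 5 or more count as positive
df_test["sentiment"] = df_test["id"].map(lambda x: 1 if int(x.strip('"').split("_")[1]) >= 5 else 0)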
y_test = df_test["sentiment"]
list_sentences_test = df_test["review"]
list_tokenized_test = tokenizer.texts_to_sequences(list_sentences_test)
X_test = pad_sequences(list_tokenized_test, maxlen=MAX_LEN)
## Making predictions on our model
prediction = model.predict(X_test)
y_pred = (prediction > 0.5).astype("int32").ravel()

With predictions in hand, we can print a classification report and plot a confusion matrix:

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
report = classification_report(y_test, y_pred)
print(report)
def plot_cm(labels, predictions):
    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(5, 5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title("Confusion matrix (non-normalized)")
    plt.ylabel("Actual label")
    plt.xlabel("Predicted label")
plot_cm(y_test, y_pred)

Now that’s a fine-looking report

Overall our classifier is doing pretty well, with a 0.97 F1-score on the test dataset.

Let’s take a closer look at how our laundry list of metrics fared over the entire training process. model.fit() returns a handy History object that contains a dictionary with everything that happened during training:

# Plot each training metric against its validation counterpart, epoch by epoch
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
mpl.rcParams["figure.figsize"] = (12, 18)
def plot_metrics(history):
    metrics = [
        "loss", "tp", "fp", "tn", "fn",
        "accuracy", "precision", "recall", "auc",
    ]
    for n, metric in enumerate(metrics):
        name = metric.replace("_", " ").capitalize()
        plt.subplot(5, 2, n + 1)
        plt.plot(history.epoch, history.history[metric], color=colors[0], label="Train")
        plt.plot(history.epoch, history.history["val_" + metric],
                 color=colors[0], linestyle="--", label="Val")
        plt.xlabel("Epoch")
        plt.ylabel(name)
        if metric == "loss":
            plt.ylim([0, plt.ylim()[1] * 1.2])
        elif metric in ("tp", "fp", "tn", "fn"):
            plt.ylim([0, plt.ylim()[1]])
        elif metric in ("accuracy", "recall"):
            plt.ylim([0.4, 1])
        else:  # precision, auc
            plt.ylim([0, 1])
        plt.legend()
plot_metrics(history)

That’s right. All this progress was made in just 5 epochs

True Negatives and True Positives rise, False Positives and False Negatives fall, Accuracy and AUC steadily rise, and Precision and Recall tend towards 1.0. This is pretty good for a first try.

We also want to plot the Receiver Operating Characteristic (ROC) curve as another test of the classifier’s output quality. ROC plots usually show the false positive rate on the X axis and the true positive rate on the Y axis. Since we are aiming for a false positive rate of zero and a true positive rate of one, our ideal point is the top-left corner of the plot. In practice, reaching that point is essentially never possible, but the curve at least gives us metrics like the area under the curve (AUC) that we can steadily improve. We can also judge classifier quality by the ‘steepness’ of the ROC curve, another manifestation of a high true positive rate with a minimal false positive rate.

import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve
mpl.rcParams["figure.figsize"] = (6, 6)
# Use the predicted probabilities rather than the thresholded labels, so the
# curve traces every decision threshold instead of a single operating point
fpr, tpr, _ = roc_curve(y_test.to_numpy(), prediction.ravel())
roc_auc = auc(fpr, tpr)
lw = 2
plt.plot(fpr, tpr, color="darkorange",
         lw=lw, label="ROC curve (area = %0.2f)" % roc_auc)
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic")
plt.legend(loc="lower right")
plt.show()

Almost to the corner
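To see where the ROC curve and its area come from, here is a from-scratch NumPy sketch: sort the examples by predicted score, sweep the threshold downwards, and accumulate true and false positive rates, then integrate with the trapezoidal rule. The helper names are hypothetical, and the sketch assumes there are no tied scores (scikit-learn’s roc_curve handles ties properly).

```python
import numpy as np

def roc_points(y_true, scores):
    """ROC points from scratch, assuming no tied scores."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(scores))  # highest score first
    y_sorted = y_true[order]
    tps = np.cumsum(y_sorted)      # true positives as the threshold drops
    fps = np.cumsum(1 - y_sorted)  # false positives as the threshold drops
    tpr = np.concatenate([[0.0], tps / y_true.sum()])
    fpr = np.concatenate([[0.0], fps / (len(y_true) - y_true.sum())])
    return fpr, tpr

def trapezoid_auc(fpr, tpr):
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(trapezoid_auc(fpr, tpr))  # 0.75
```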

So our final model has exceptionally low false positive and false negative rates. This is great, though we’re still a long way from a Roger Ebert replacement. So far, I’ve only shown you how to use Attention mechanisms for a basic classification task. For tasks like Natural Language Generation, Translation, or Named Entity Recognition, we’ll need to go into more of the nuances of attention, namely the Encoder-Decoder framework. But that’s a topic for another post.


Cited as:

  title   = "Getting started with Attention for Classification",
  author  = "McAteer, Matthew",
  journal = "",
  year    = "2018",
  url     = ""

If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
