Hierarchical Temporal Memory (with Keras example)

Making ANNs that resemble actual neural networks in the brain

Machine learning has many different schools of thought, ranging from the connectionists (neural networks) to the Bayesians (Markov logic networks and Bayesian inference). While the current AI craze has been driven by the success of deep neural networks, both schools have made a lot of progress. Still, there are some areas where modern machine learning algorithms do poorly:

  • Unsupervised learning – this still hasn’t progressed as much as supervised learning
  • Understanding context – NLP systems can calculate the probability of a word appearing before or after another, but they still struggle to understand context.
  • “High-priority” data – neural networks don’t have “reflexes”, i.e., mechanisms that can shorten the time needed to make a decision
  • Continuously streamed data – AI is still bad at live learning, i.e., learning from data that is being continuously ingested
  • One-shot learning – neural networks still require very large datasets for each category to extract meaningful information

One possible reason behind this is that artificial neural networks do not actually resemble biological neurons all that much.

What are HTMs?

Hierarchical Temporal Memory (HTM) is a theoretical framework developed by Numenta, a company founded in 2005 by Jeff Hawkins (of Palm Pilot fame). Numenta’s goal is artificial intelligence research inspired by neuroscience, specifically by the functional behavior of the neocortex itself, and HTM is the machine learning framework that has come out of this neurology-inspired research.

HTMs differ from regular neural networks in many ways. HTMs are better than ANNs at unsupervised learning, they can learn continuously rather than relying on batchwise learning, and HTM “neurons” can occupy three states instead of two (a “predictive” state in addition to “active” and “inactive”). HTMs can also work with much smaller datasets than ANNs.

According to Numenta, HTMs can be expected to outperform regular neural networks when a problem has one or more of the following qualities:

  1. The model needs to learn continuously
  2. The data has a high velocity
  3. The data is temporal
  4. There are multiple data sources and inputs (e.g., image and audio together)
  5. The data is unlabelled

With this in mind, Numenta has created several products based on HTMs with the intent of proving their effectiveness in the real world.

Inner workings of HTMs

When it comes to how HTMs actually work, there are a few important concepts to understand:

  1. Sparse Distributed Representation (SDR)
  2. Semantic Encoding
  3. Spatial Pooling
  4. Hebbian Learning
  5. Boosting
  6. Temporal Memory
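To make the first of these concepts a little more concrete, here is a minimal NumPy sketch of an SDR: a long binary vector with only a small fraction of its bits active, where similarity between two SDRs is measured by their overlap (the number of active bits they share). The sizes (2,048 bits with 40 active, about 2%) are illustrative assumptions chosen to roughly match the spatial pooler settings used later, and random_sdr and overlap are hypothetical helper names, not part of nupic.

```python
import numpy as np

def random_sdr(size, num_active, rng):
    """Hypothetical helper: binary vector of length `size` with exactly `num_active` ones."""
    sdr = np.zeros(size, dtype=np.int32)
    sdr[rng.choice(size, num_active, replace=False)] = 1
    return sdr

def overlap(a, b):
    """Number of active bits two SDRs share (the basic SDR similarity measure)."""
    return int(np.dot(a, b))

rng = np.random.default_rng(0)
size, num_active = 2048, 40          # ~2% sparsity (illustrative assumption)

a = random_sdr(size, num_active, rng)
b = random_sdr(size, num_active, rng)    # an unrelated SDR

# Noisy copy of `a`: switch off 10 of its 40 active bits and switch on 10 inactive ones.
noisy = a.copy()
on_bits = np.flatnonzero(a)
off_bits = np.flatnonzero(a == 0)
noisy[rng.choice(on_bits, 10, replace=False)] = 0
noisy[rng.choice(off_bits, 10, replace=False)] = 1

# `a` overlaps itself in all 40 bits and the noisy copy in exactly 30.
print(overlap(a, a), overlap(a, noisy), overlap(a, b))
```

Because random SDRs this sparse rarely share more than a bit or two, even a copy with a quarter of its bits moved stays far more similar to the original than an unrelated SDR typically is. That robustness to noise is a large part of why HTMs represent everything this way.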

The entire algorithm can be a little overwhelming without visual simulations, so I strongly recommend checking out the videos Numenta has published online, which include some very cool simulations of the process.

Building an HTM-based model

What I like about these models of cognition is that they translate a theory of how the brain works into actionable, automatable steps. To demonstrate this, I tried building an HTM-based classifier for MNIST digits. Numenta has already produced a well-supported library for using HTMs in Python (nupic), but just using high-level APIs is inadequate for fully mastering a topic. That’s why we’re going to build this ourselves, using TensorFlow and Keras.

The first step was defining the spatial pooling component of the HTM. This was built on what is effectively a slimmed-down clone of the Layer class in Keras.

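# Note: this code targets TensorFlow 1.x (graph mode with tf.placeholder and tf.Session).
import numpy as np
import tensorflow as tf

def one_hot(i, nb_classes):
    """Return a length-nb_classes vector with a 1 at index i."""
    arr = np.zeros(nb_classes)
    arr[i] = 1
    return arr


class Layer:
    """Minimal Keras-style layer: builds itself lazily on first call and collects train ops."""
    def __init__(self):
        self.is_built = False
        self.train_ops = []

    def build(self, input_shape):
        # Subclasses create their variables here, then call super().build()
        assert not self.is_built
        self.is_built = True

    def call(self, x):
        # Forward pass; overridden by subclasses
        pass

    def train(self, x, y):
        # Returns the op(s) that update the layer's state; overridden by subclasses
        pass

    def __call__(self, x):
        if not self.is_built:
            self.build(x.get_shape().as_list())

        y = self.call(x)
        # Register this layer's learning ops so they can be run later with sess.run(...)
        self.train_ops.append(self.train(x, y))
        return y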

From the Layer class, we can build our spatial pooling layer. Its constructor takes several arguments:

  • output_dim: Size of the output, i.e., the number of mini-columns
  • sparsity: The target fraction of mini-columns active for each input
  • lr: The learning rate at which permanences are updated (one part of regular neural networks we’re still keeping)
  • pool_density: The fraction of inputs each mini-column is potentially connected to, on average
  • duty_cycle: The period over which each mini-column’s running average activation is computed during training
  • boost_strength: What it says on the tin, how strongly we boost under-active mini-columns

class SpatialPooler(Layer):
    def __init__(self, output_dim, sparsity=0.02, lr=1e-2, pool_density=0.9,
                 duty_cycle=1000, boost_strength=100, **kwargs):
        self.output_dim = output_dim
        self.sparsity = sparsity
        self.lr = lr
        self.pool_density = pool_density
        self.duty_cycle = duty_cycle
        self.boost_strength = boost_strength
        # Number of mini-columns allowed to win per input, e.g. ceil(0.02 * 2048) = 41
        self.top_k = int(np.ceil(self.sparsity * np.prod(self.output_dim)))
        super().__init__(**kwargs)

When building the SpatialPooler, we create the permanences of the connections between inputs and mini-columns, the potential pool matrix (which randomly masks out connections), and the connection matrix (a connection is established wherever the permanence is > 0.5 within the potential pool). We also create a variable for the time-averaged activation level of each mini-column.

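    def build(self, input_shape):
        # Permanence of every input -> mini-column synapse, initialised uniformly in [0, 1)
        self.p = tf.Variable(tf.random_uniform((input_shape[1], self.output_dim), 0, 1), name='Permanence')

        # Potential pool: randomly keep each possible synapse with probability pool_density
        rand_mask = np.random.binomial(1, self.pool_density, input_shape[1] * self.output_dim)
        pool_mask = tf.constant(np.reshape(rand_mask, [input_shape[1], self.output_dim]), dtype=tf.float32)

        # Connected synapses: inside the pool and permanence > 0.5 (tf.round maps these to 1)
        self.connection = tf.round(self.p) * pool_mask
        # Time-averaged activation of each mini-column, used for boosting
        self.avg_activation = tf.Variable(tf.zeros([1, self.output_dim]))

        super().build(input_shape)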
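As a sanity check on what build() sets up, here is the same initialisation in NumPy. It is a sketch assuming the sizes used later in the post (784 × 4 = 3,136 inputs, 2,048 mini-columns, pool_density=0.9); drawing uniform numbers and comparing against pool_density is used here as an equivalent of np.random.binomial(1, pool_density). Because permanences start uniform on [0, 1), about half of the synapses in the potential pool begin connected, so roughly 0.9 × 0.5 = 45% of all input/mini-column pairs are connected at the start of training.

```python
import numpy as np

rng = np.random.default_rng(0)
num_inputs, num_columns, pool_density = 784 * 4, 2048, 0.9

# Permanence of every potential synapse, uniform on [0, 1) as in build()
permanence = rng.random((num_inputs, num_columns), dtype=np.float32)

# Potential pool: each input/mini-column pair is eligible with probability pool_density
pool_mask = rng.random((num_inputs, num_columns), dtype=np.float32) < pool_density

# Connected = in the pool AND permanence rounds to 1 (i.e. permanence > 0.5)
connection = np.round(permanence) * pool_mask

print(f"connected fraction: {connection.mean():.3f}")   # close to 0.45
```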

For the SpatialPooler’s call method (which returns the activations whenever we call this layer), we first compute boost factors from the average activation and the boost strength, taking into account the recent activity in each mini-column’s (global) neighborhood. We then compute the overlap score between the inputs and each mini-column and select the active mini-columns. Finally, we create a matrix of repeated batch IDs and stack it with the winning column indices to get 2D indices of the activated units, which become our outputs.

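    def call(self, x):
        # Global neighborhood = all other mini-columns: ones everywhere except the diagonal
        neighbor_mask = tf.constant(-np.identity(self.output_dim) + 1, dtype=tf.float32)
        neighbor_activity = tf.matmul(self.avg_activation, neighbor_mask) / (self.output_dim - 1)
        # Columns less active than their neighbors get boost > 1, more active ones get boost < 1
        boost_factor = tf.exp(-self.boost_strength * (self.avg_activation - neighbor_activity))

        # Overlap score: number of active connected inputs per mini-column, scaled by its boost
        overlap = tf.matmul(x, self.connection) * boost_factor

        # Global inhibition: only the top_k mini-columns per example become active
        batch_size = tf.shape(x)[0]
        _, act_indices = tf.nn.top_k(overlap, k=self.top_k, sorted=False)

        # Pair each winning column index with its batch row to get (row, column) indices
        batch_ids = tf.tile(tf.reshape(tf.range(0, batch_size), [-1, 1]), [1, self.top_k])
        act_indices = tf.to_int64(tf.reshape(tf.stack([batch_ids, act_indices], axis=2), [-1, 2]))
        act_vals = tf.ones((batch_size * self.top_k,))
        output_shape = tf.to_int64(tf.shape(overlap))

        # Place ones at the winning positions and densify into a binary (batch, output_dim) tensor
        activation = tf.SparseTensor(act_indices, act_vals, output_shape)
        activation = tf.sparse_tensor_to_dense(activation, validate_indices=False)
        return activation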
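For readers who find the TensorFlow graph code hard to follow, here is the same forward pass re-expressed in plain NumPy: boost factors from each mini-column’s average activity relative to all the others, boosted overlap scores, then global inhibition keeping the top_k columns per example. It is a sketch assuming binary inputs and a precomputed connection matrix; spatial_pool is a hypothetical helper, and np.argpartition stands in for tf.nn.top_k (ties may be broken differently).

```python
import numpy as np

def spatial_pool(x, connection, avg_activation, top_k, boost_strength):
    """NumPy sketch of SpatialPooler.call: boosted overlap + global top-k inhibition.

    x:              (batch, inputs) binary input
    connection:     (inputs, columns) binary connected-synapse matrix
    avg_activation: (1, columns) running average activity of each column
    """
    n = connection.shape[1]
    # Mean activity of all *other* columns, i.e. avg_activation @ (ones - I) / (n - 1)
    neighbor_activity = (avg_activation.sum() - avg_activation) / (n - 1)
    boost = np.exp(-boost_strength * (avg_activation - neighbor_activity))

    overlap = (x @ connection) * boost                         # (batch, columns)
    # Indices of the top_k highest-overlap columns in each row (order irrelevant)
    winners = np.argpartition(-overlap, top_k - 1, axis=1)[:, :top_k]

    out = np.zeros_like(overlap)
    np.put_along_axis(out, winners, 1.0, axis=1)               # dense binary output
    return out

rng = np.random.default_rng(0)
x = rng.binomial(1, 0.25, size=(8, 64)).astype(float)
connection = rng.binomial(1, 0.45, size=(64, 128)).astype(float)
avg = np.zeros((1, 128))          # no history yet, so every boost factor is exp(0) = 1

y = spatial_pool(x, connection, avg, top_k=5, boost_strength=100)
print(y.sum(axis=1))              # every row has exactly 5 active columns
```

Writing the neighbor average as (total - own) / (n - 1) avoids building the n × n mask used in the TensorFlow version, which is the same quantity computed without the matrix multiply.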

When we’re training this network, we update the permanences using the Hebbian learning rule. For each active SP mini-column, we reinforce connections to active inputs by increasing their permanence, and punish connections to inactive inputs by decreasing their permanence. We only modify the permanences of connections in active mini-columns: all non-connections are ignored, and the updated permanences are clipped between 0 and 1.

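    def train(self, x, y):
        # Map inputs {0, 1} -> {-1, +1}: active inputs reinforce, inactive inputs punish
        x_shifted = 2 * x - 1
        batch_size = tf.to_float(tf.shape(x)[0])
        # delta[j, k] = batch mean of x_shifted[j] * y[k], restricted to connected synapses
        delta = tf.einsum('ij,ik,jk->jk', x_shifted, y, self.connection) / batch_size

        # Apply the update and keep permanences in [0, 1]
        new_p = tf.clip_by_value(self.p + self.lr * delta, 0, 1)

        train_op = tf.assign(self.p, new_p)

        # Exponential moving average of each mini-column's activation (window ~ duty_cycle)
        avg_activation = tf.reduce_mean(y, axis=0, keepdims=True)
        new_act_avg = ((self.duty_cycle - 1) * self.avg_activation + avg_activation) / self.duty_cycle
        update_act_op = tf.assign(self.avg_activation, new_act_avg)

        return [train_op, update_act_op]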
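The permanence update and the running activation average can likewise be written in a few lines of NumPy. This sketch mirrors train() under the same assumptions (binary x and y, a given connection matrix); hebbian_update and update_avg_activation are hypothetical names. The tiny example uses one input pattern with input 0 on and input 1 off, and only mini-column 0 active.

```python
import numpy as np

def hebbian_update(p, connection, x, y, lr):
    """Mirror of SpatialPooler.train's permanence update.

    For each connected synapse (j, k) of an active mini-column k, permanence goes up
    if input j was on and down if it was off, averaged over the batch, then clipped.
    """
    x_shifted = 2 * x - 1                                         # {0, 1} -> {-1, +1}
    delta = np.einsum('ij,ik,jk->jk', x_shifted, y, connection) / x.shape[0]
    return np.clip(p + lr * delta, 0.0, 1.0)

def update_avg_activation(avg_activation, y, duty_cycle):
    """Exponential moving average of mini-column activity (time constant ~ duty_cycle)."""
    batch_avg = y.mean(axis=0, keepdims=True)
    return ((duty_cycle - 1) * avg_activation + batch_avg) / duty_cycle

x = np.array([[1.0, 0.0]])          # input 0 on, input 1 off
y = np.array([[1.0, 0.0]])          # mini-column 0 active, mini-column 1 inactive
p = np.full((2, 2), 0.6)            # all synapses start connected (p > 0.5)
connection = np.round(p)

new_p = hebbian_update(p, connection, x, y, lr=0.1)
# Column 0: synapse from input 0 rises to 0.7, synapse from input 1 falls to 0.5.
# Column 1 was inactive, so its synapses keep permanence 0.6.
new_avg = update_avg_activation(np.zeros((1, 2)), y, duty_cycle=1000)
```

In the TensorFlow version the connection matrix is derived from the current permanences, so once a synapse’s permanence falls to 0.5 or below it stops receiving updates.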

Now that we’ve built our SpatialPooler class, we can build the actual model itself.

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

epochs = 100       # Number of training epochs
num_classes = 10   # Number of classes of handwritten digits
num_pixels = 784   # Number of pixels in each MNIST image (28 x 28)
pixel_bits = 4     # Number of bits used to encode each pixel
validation_split = 0.9   # Fraction of the data used for training; the remaining 10% is for validation
input_units = num_pixels * pixel_bits
htm_units = 2048   # Number of mini-columns in our spatial pooler
batch_size = 32

class HTMModel:
    def __init__(self):
        pooler = SpatialPooler(htm_units, lr=1e-2)
        # Model input: a batch of binary-encoded images
        self.x = tf.placeholder(tf.float32, [None, input_units])
        # Spatial pooler output (sparse binary activations); calling the pooler also creates its train ops
        self.y = pooler(self.x)
        self.train_ops = pooler.train_ops

        # Build classifier: a single softmax layer trained on the pooler's outputs
        classifier_in = Input((htm_units,))
        classifier_out = Dense(num_classes, activation='softmax')(classifier_in)
        self.classifier = Model(classifier_in, classifier_out)
        self.classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# Build a model
model = HTMModel()

Now that we have our model, we can load the MNIST dataset and encode it for training (this next part was shamelessly ripped right from the TensorFlow MNIST tutorial).

from tensorflow.examples.tutorials.mnist import input_data
from tqdm import tqdm

# Load MNIST
print('Loading data...')
mnist = input_data.read_data_sets("data/", one_hot=False)

# Process data using a simple greyscale encoder
all_data = []

print('Processing data...')
for img in tqdm(mnist.train.images):
    img_data = []
    for pixel in img:
        # One-hot representation: bucket the [0, 1] intensity into one of pixel_bits bins
        index = min(int(pixel * pixel_bits), pixel_bits - 1)
        img_data += list(one_hot(index, pixel_bits))
    all_data.append(img_data)

all_labels = [one_hot(x, num_classes) for x in mnist.train.labels]

# Split into training (first 90%) and validation (last 10%) sets
num_data = int(len(all_data) * validation_split)
num_validate = len(all_data) - num_data

input_set = np.array(all_data[:num_data])
input_labels = all_labels[:num_data]
val_set = all_data[num_data:num_data+num_validate]
val_labels = all_labels[num_data:num_data+num_validate]
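The nested Python loop above is easy to read but slow over the whole training set. Here is a vectorized NumPy sketch of the same encoder, assuming images is a float array of pixel intensities in [0, 1] (as returned by the tutorial loader); encode_greyscale is a hypothetical helper, not part of the original code.

```python
import numpy as np

def encode_greyscale(images, pixel_bits=4):
    """One-hot encode each pixel intensity into one of pixel_bits buckets.

    images: (n, num_pixels) floats in [0, 1] -> (n, num_pixels * pixel_bits) array of 0s and 1s
    """
    images = np.asarray(images, dtype=np.float32)
    # Same bucketing as the loop: floor(pixel * pixel_bits), with 1.0 clamped into the top bucket
    index = np.minimum((images * pixel_bits).astype(np.int64), pixel_bits - 1)
    # np.eye(pixel_bits)[index] has shape (n, num_pixels, pixel_bits); flattening the last two
    # axes keeps each pixel's bits contiguous, in the same order as the loop's img_data list
    return np.eye(pixel_bits, dtype=np.float32)[index].reshape(len(images), -1)

encoded = encode_greyscale([[0.0, 0.3, 0.6, 1.0]])
# Buckets 0, 1, 2 and 3 -> [1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1]
```

With this helper, the whole training set can be encoded in a single call instead of a Python-level loop over every pixel.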

We then define our training functions. During each epoch, we train the spatial pooler on shuffled mini-batches of the training data (no labels involved), train the softmax classifier on the pooler’s outputs, and then test on the validation data to confirm that the model isn’t overfitting.

def train_htm(sess):
    """One unsupervised pass of the spatial pooler over shuffled mini-batches."""
    print('Training HTM...')
    order = np.random.permutation(len(input_set))

    for i in tqdm(range(0, len(order) + 1 - batch_size, batch_size)):
        batch_indices = order[i:i+batch_size]
        x = input_set[batch_indices]
        sess.run(model.train_ops, feed_dict={ model.x: x })

def train_classifier(sess):
    """Fit the softmax classifier on the pooler's outputs for the training set."""
    print('Training classifier...')
    all_outputs = sess.run(model.y, feed_dict={ model.x: input_set })
    model.classifier.fit(np.array(all_outputs), np.array(input_labels), epochs=10)

def validate(sess):
    """Report classifier accuracy on the held-out validation set."""
    print('Validating...')
    all_outputs = sess.run(model.y, feed_dict={ model.x: val_set })
    loss, accuracy = model.classifier.evaluate(np.array(all_outputs), np.array(val_labels))
    print('Accuracy: {}'.format(accuracy))

Final results

Using the following code block, we put all of this together and train our model:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(epochs):
        print('=== Epoch ' + str(epoch) + ' ===')
        train_htm(sess)
        train_classifier(sess)
        validate(sess)

Our results? With a spatial pooler that never sees the labels (only the softmax classifier on top of it does), we get 94% validation accuracy on MNIST.

Obviously this isn’t a perfect-fidelity HTM. For example, our spatial pooler only implements global inhibition, and implementing a stimulus threshold would also be nice. There are also some coding quirks that make this less efficient than it could be on a GPU (for example, we convert from a sparse to a dense tensor in the spatial pooler’s call method, when keeping the activations sparse would be much more efficient). Computational efficiency aside, it looks like we succeeded in implementing the core of an HTM, its spatial pooler, with TensorFlow and Keras.
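As an illustration of one of those missing pieces, here is a NumPy sketch of global inhibition with a stimulus threshold: a mini-column must be among the top_k and also reach a minimum overlap to become active, so weak inputs can produce fewer than top_k active columns. This is an assumed design, not code from this post; inhibit_with_threshold and stimulus_threshold are hypothetical names (nupic’s spatial pooler has a stimulusThreshold parameter that plays a similar role).

```python
import numpy as np

def inhibit_with_threshold(overlap, top_k, stimulus_threshold):
    """Global top-k inhibition plus a minimum-overlap requirement.

    overlap: (batch, columns) boosted overlap scores.
    Returns a binary (batch, columns) array; a row may have fewer than top_k ones.
    """
    winners = np.argpartition(-overlap, top_k - 1, axis=1)[:, :top_k]
    active = np.zeros_like(overlap, dtype=float)
    np.put_along_axis(active, winners, 1.0, axis=1)
    # Winners whose overlap falls short of the threshold stay inactive
    active[overlap < stimulus_threshold] = 0.0
    return active

overlap = np.array([[5.0, 1.0, 3.0, 0.0]])
active = inhibit_with_threshold(overlap, top_k=2, stimulus_threshold=4.0)
# Columns 0 and 2 win the top-2, but column 2's overlap (3.0) is below 4.0,
# so only column 0 ends up active.
```

In the TensorFlow layer, the equivalent change would be one extra mask multiplied into the dense activation before it is returned.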

If you want to learn even more about HTMs, I recommend reading Numenta’s HTM Cortical Learning whitepaper, or the other papers they have on their website.

Additional Thoughts

So what does this mean? Does this mean that we can model the entire brain using only neurons like this?

Not so fast, it’s not quite that simple…

Projects like OpenWorm have attempted to replicate the entire neural network of simple model organisms like the C. elegans worm. One of the issues is that the dynamics of individual neurons are not yet fully understood. Some researchers have even proposed modelling individual neurons with ODEs and neural networks of their own.

And if it is true that the behavior of neurons is the result of some sort of emergence of quantum-level phenomena, then saying HTMs are closer to biological neural networks than typical ANNs would be like saying the top of the Washington Monument is closer than the top of the Capitol building to the surface of Mercury.

Quantum Brain, with some artistic license, courtesy of davidope at http://davidope.com/

We’ll save that for next time.

Acknowledgements

Thanks to Matt Taylor, Pascal Weinberger, Jeff Hawkins, and Henry Mao for providing the resources, prior HTM work, and inspiration needed to get this code off the ground.

Update

You can now run the code described above in a wonderful new tool called Google Colab, which basically lets you run Jupyter notebooks in the browser. You can explore the code for yourself in this notebook.


Cited as:

@article{mcateer2017htms,
  title   = "Heirarchal Temporal Memory (with Keras example)",
  author  = "McAteer, Matthew",
  journal = "matthewmcateer.me",
  year    = "2017",
  url     = "https://matthewmcateer.me/blog/heirarchal-temporal-memory-in-keras/"
}

If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
