Optimal Brain Damage

Cutting away 80% of neurons in a neural network with no impact on accuracy

This project was a result of work I recently did with FOR.ai, a multi-disciplinary team of scientists and engineers who like doing machine learning research for fun. You can read more about their projects on their blog, or follow them on Twitter.

There are multiple ways of optimizing neural-network-based machine learning models. One of them is removing connections between neurons and layers, which speeds up computation by reducing the overall number of parameters.

Pruning example

Networks generally look like the one on the left: every neuron in the layer below has a connection to the layer above; but this means that we have to multiply a lot of floats together. Ideally, we’d only connect each neuron to a few others and save on doing some of the multiplications; this is called a “sparse” network.
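
To make the savings concrete, here is a minimal sketch (using NumPy and SciPy; not part of the original notebook) comparing the work done by a dense 784 × 1000 layer with one where 80% of the weights have been zeroed out and stored in a compressed sparse format:

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
W = rng.standard_normal((784, 1000)).astype(np.float32)  # dense layer weights

# Zero out 80% of the weights to simulate a pruned ("sparse") layer
mask = rng.random(W.shape) < 0.8
W_sparse = np.where(mask, 0.0, W)

dense_mults = W.shape[0] * W.shape[1]        # multiplies per example in the dense layer
sparse_mults = np.count_nonzero(W_sparse)    # only nonzero weights need a multiply
print('dense multiplies: ', dense_mults)     # 784,000
print('sparse multiplies:', sparse_mults)    # roughly 20% of that

# Stored in compressed sparse row (CSR) format, the pruned layer is also smaller
W_csr = sparse.csr_matrix(W_sparse)
print('dense bytes: ', W.nbytes)
print('sparse bytes:', W_csr.data.nbytes + W_csr.indices.nbytes + W_csr.indptr.nbytes)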

Given a layer of a neural network $\text{ReLU}(xW)$, there are two well-known ways to prune it:

  • Weight pruning: set individual weights in the weight matrix to zero. This corresponds to deleting connections as in the figure above.

    • Here, to achieve sparsity of $k\%$ we rank the individual weights in the weight matrix $W$ according to their magnitude (absolute value) $|w_{i,j}|$, and then set the smallest $k\%$ to zero.
  • Unit/Neuron pruning: set entire columns of the weight matrix to zero, in effect deleting the corresponding output neuron.

    • Here, to achieve sparsity of $k\%$ we rank the columns of the weight matrix according to their L2 norm $|w| = \sqrt{\sum_{i=1}^{N}(w_i)^2}$ and delete the smallest $k\%$ (a toy NumPy sketch of both criteria follows this list).
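
To make the two rankings concrete, here is a toy NumPy sketch (illustrative only; the actual pruning functions used in this post come later) that applies both criteria to a small random weight matrix at 50% sparsity:

import numpy as np

rng = np.random.default_rng(42)
W = rng.standard_normal((4, 6))   # toy "layer": 4 inputs, 6 output neurons
k = 0.5                           # prune 50%

# Weight pruning: zero the individually smallest |w_ij|
flat_order = np.argsort(np.abs(W), axis=None)      # flat indices ranked by magnitude
n_prune = int(W.size * k)
W_weight_pruned = W.copy()
W_weight_pruned[np.unravel_index(flat_order[:n_prune], W.shape)] = 0.0

# Unit pruning: zero the columns (output neurons) with the smallest L2 norm
col_order = np.argsort(np.linalg.norm(W, axis=0))  # columns ranked by L2 norm
W_unit_pruned = W.copy()
W_unit_pruned[:, col_order[:int(W.shape[1] * k)]] = 0.0

print(np.count_nonzero(W_weight_pruned), 'weights survive weight pruning')  # 12
print(np.count_nonzero(W_unit_pruned), 'weights survive unit pruning')      # 12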

Naturally, as you increase the sparsity and delete more of the network, the task performance will progressively degrade. Here, we will implement both weight and unit pruning and compare their performance on both the MNIST and FMNIST datasets.

If you want to follow along, you can find the full notebook here.

Training our model (first without pruning)

First we’re going to set up our dataset. For the sake of running experiments on multiple datasets in parallel, we create a dataset-loading function.

# Imports used throughout the notebook
import tempfile

import numpy as np
from numpy import linalg as LA
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model


def load_dataset(dataset='mnist'):
    # input image dimensions (equal for both MNIST and FMNIST)
    img_rows, img_cols = 28, 28
    if dataset=='mnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    elif dataset=='fmnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    else:
        print('dataset name does not match available options \n( mnist | fmnist )')
    x_train = x_train.reshape(x_train.shape[0], img_rows*img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows*img_cols)
    input_shape = (img_rows*img_cols*1,)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    return x_train, x_test, y_train, y_test, num_classes, input_shape

# We will load separate tensors for the MNIST and FMNIST data. 
mnist_x_train, mnist_x_test, mnist_y_train, mnist_y_test, num_classes, input_shape = load_dataset(dataset='mnist')
fmnist_x_train, fmnist_x_test, fmnist_y_train, fmnist_y_test, num_classes, input_shape = load_dataset(dataset='fmnist')

We will construct a $\text{ReLU}$-activated neural network with four hidden layers. These will be dense, fully-connected layers with sizes 1000, 1000, 500, and 200. We’ll also have a fifth layer for the 10 output classes (note: since these weights connect directly to the output, they will be spared from any and all pruning).

For the sake of simplicity, we will also omit Dropout layers, Convolutional layers, Batch Normalization Layers, and Avg Pooling Layers.

l = tf.keras.layers

def build_model_arch(input_shape, num_classes, sparsity=0.0):
    model = tf.keras.Sequential()

    model.add(l.Dense(int(1000-(1000*sparsity)), activation='relu',
                      input_shape=input_shape)),
    model.add(l.Dense(int(1000-(1000*sparsity)), activation='relu'))
    model.add(l.Dense(int(500-(500*sparsity)), activation='relu'))
    model.add(l.Dense(int(200-(200*sparsity)), activation='relu'))
    model.add(l.Dense(num_classes, activation='softmax'))

    return model


# The architectures are the same, but we are initializing 2 different sequential
# models. One is for MNIST, and one is for FMNIST

mnist_model_base = build_model_arch(input_shape, num_classes)
fmnist_model_base = build_model_arch(input_shape, num_classes)

We can also set up TensorBoard to monitor the training process.

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)

Now that we have our dataset feeder and logging set up, we can build the non-sparse model that we will prune post-training. Setting up our model definitions in a function like this allows us to create multiple graphs that we can test in experiments like this (this is less a machine learning fact and more a coding practice I’d love to encourage).

def make_nosparse_model(model, x_train, y_train, batch_size, 
                         epochs, x_test, y_test):
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam', metrics=['accuracy'])

    model.fit(x_train, y_train, batch_size=batch_size,
              epochs=epochs, verbose=1,
              callbacks=callbacks,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
    
    return model, score

batch_size = 128
epochs = 10

mnist_model, mnist_score = make_nosparse_model(mnist_model_base,
                                               mnist_x_train,
                                               mnist_y_train,
                                               batch_size,
                                               epochs,
                                               mnist_x_test,
                                               mnist_y_test)
print(mnist_model.summary())

fmnist_model, fmnist_score = make_nosparse_model(fmnist_model_base,
                                                 fmnist_x_train,
                                                 fmnist_y_train,
                                                 batch_size,
                                                 epochs,
                                                 fmnist_x_test,
                                                 fmnist_y_test)
print(fmnist_model.summary())

So how do our non-pruned models perform?

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 3s 46us/sample - loss: 0.2092 - acc: 0.9361 - val_loss: 0.1327 - val_acc: 0.9551
Epoch 2/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0857 - acc: 0.9740 - val_loss: 0.0882 - val_acc: 0.9738
Epoch 3/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0626 - acc: 0.9804 - val_loss: 0.0906 - val_acc: 0.9738
Epoch 4/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0443 - acc: 0.9862 - val_loss: 0.0633 - val_acc: 0.9814
Epoch 5/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0361 - acc: 0.9887 - val_loss: 0.0962 - val_acc: 0.9738
Epoch 6/10
60000/60000 [==============================] - 2s 28us/sample - loss: 0.0331 - acc: 0.9900 - val_loss: 0.0813 - val_acc: 0.9791
Epoch 7/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0270 - acc: 0.9920 - val_loss: 0.0864 - val_acc: 0.9783
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0226 - acc: 0.9924 - val_loss: 0.1024 - val_acc: 0.9751
Epoch 9/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0206 - acc: 0.9939 - val_loss: 0.0785 - val_acc: 0.9814
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0192 - acc: 0.9942 - val_loss: 0.0762 - val_acc: 0.9820
Test loss: 0.07615731712152446
Test accuracy: 0.982
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1000)              785000    
_________________________________________________________________
dense_1 (Dense)              (None, 1000)              1001000   
_________________________________________________________________
dense_2 (Dense)              (None, 500)               500500    
_________________________________________________________________
dense_3 (Dense)              (None, 200)               100200    
_________________________________________________________________
dense_4 (Dense)              (None, 10)                2010      
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 2s 35us/sample - loss: 0.4828 - acc: 0.8260 - val_loss: 0.4383 - val_acc: 0.8391
Epoch 2/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3591 - acc: 0.8671 - val_loss: 0.3768 - val_acc: 0.8594
Epoch 3/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3217 - acc: 0.8805 - val_loss: 0.3598 - val_acc: 0.8722
Epoch 4/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2959 - acc: 0.8906 - val_loss: 0.3294 - val_acc: 0.8814
Epoch 5/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2843 - acc: 0.8936 - val_loss: 0.3463 - val_acc: 0.8761
Epoch 6/10
60000/60000 [==============================] - 2s 31us/sample - loss: 0.2657 - acc: 0.9000 - val_loss: 0.3267 - val_acc: 0.8815
Epoch 7/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2528 - acc: 0.9049 - val_loss: 0.3491 - val_acc: 0.8763
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2383 - acc: 0.9091 - val_loss: 0.3388 - val_acc: 0.8828
Epoch 9/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2273 - acc: 0.9126 - val_loss: 0.3320 - val_acc: 0.8855
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2207 - acc: 0.9154 - val_loss: 0.3231 - val_acc: 0.8885
Test loss: 0.3230560186743736
Test accuracy: 0.8885
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_5 (Dense)              (None, 1000)              785000    
_________________________________________________________________
dense_6 (Dense)              (None, 1000)              1001000   
_________________________________________________________________
dense_7 (Dense)              (None, 500)               500500    
_________________________________________________________________
dense_8 (Dense)              (None, 200)               100200    
_________________________________________________________________
dense_9 (Dense)              (None, 10)                2010      
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None

So our models do pretty well. On MNIST we get 98.2% test accuracy with a cross-entropy loss of 0.076, and on FMNIST we get 88.85% test accuracy with a loss of 0.323.

Not bad, but this is done with 2,388,710 trainable parameters. We used this many parameters on what is effectively the Hello world! of machine learning. If we’re running an application that relies on multiple, repeated evaluations, this could get computationally expensive. And this is ignoring what would happen if we wanted to train convolutional networks, or expand to multi-channel classification tasks.
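
(For the record, that parameter count is easy to verify by hand: each dense layer contributes inputs × outputs kernel weights plus one bias per output.)

# Parameter count of the 784 -> 1000 -> 1000 -> 500 -> 200 -> 10 MLP used above
layer_sizes = [784, 1000, 1000, 500, 200, 10]
params = sum(n_in * n_out + n_out              # kernel + bias for each dense layer
             for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))
print(params)  # 2388710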

Pruning our Keras Model

How many parameters are actually needed at minimum to solve this task? We can’t quite get a provable solution for a dataset like this, but we can do the next best thing: reducing subsets of weights to 0.0 and seeing how that impacts performance (aka ‘pruning’). Two pruning methods will be used for this:

  • Weight pruning: set individual weights in the weight matrix to zero. This corresponds to deleting connections as in the figure above.

    • Here, to achieve sparsity of $k\%$ we rank the individual weights in the weight matrix $W$ according to their magnitude (absolute value) $|w_{i,j}|$, and then set the smallest $k\%$ to zero.
  • Unit/Neuron pruning: set entire columns of the weight matrix to zero, in effect deleting the corresponding output neuron.

    • Here, to achieve sparsity of $k\%$ we rank the columns of the weight matrix according to their L2 norm $|w| = \sqrt{\sum_{i=1}^{N}(w_i)^2}$ and delete the smallest $k\%$.

We will prune $k\%$ of the weights using both weight and unit pruning, for $k$ in $[0, 25, 50, 60, 70, 80, 90, 95, 97, 99]$. Neither of these pruning methods will prune the weights leading to the softmax layer.

First, we define our function for weight-pruning, which takes in matrices of kernel and bias weights (for a dense layer) and returns the weight-pruned versions of each:

def weight_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get ranked indices of their absolute values
    kernel_weights = np.copy(k_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(kernel_weights),
            axis=None),
        kernel_weights.shape)
        
    # Number of indexes to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indexes in the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff], ind[1][0:cutoff])
    kernel_weights[sparse_cutoff_inds] = 0.
        
    # Copy the bias weights and get ranked indices of their absolute values
    bias_weights = np.copy(b_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(bias_weights), 
            axis=None), 
        bias_weights.shape)
        
    # Number of indexes to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indexes in the 1D bias weight matrix to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff])
    bias_weights[sparse_cutoff_inds] = 0.
    
    return kernel_weights, bias_weights

We can also set up a similar function for unit/neuron-pruning. This will take in matrices of kernel and bias weights (again, for a dense layer) and return the unit/neuron-pruned versions of each:

def unit_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get ranked indices of the
    # column-wise L2 Norms
    kernel_weights = np.copy(k_weights)
    ind = np.argsort(LA.norm(kernel_weights, axis=0))
        
    # Number of indexes to set to 0
    cutoff = int(len(ind)*k_sparsity)
    # The indexes in the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = ind[0:cutoff]
    kernel_weights[:,sparse_cutoff_inds] = 0.
        
    # Copy the bias weights; the bias entries to zero out are the same
    # indices as the pruned columns (sparse_cutoff_inds)
    bias_weights = np.copy(b_weights)
    bias_weights[sparse_cutoff_inds] = 0.
    
    return kernel_weights, bias_weights

Pruning across the entire model

The previous functions only took in individual layers and pruned them. Let’s create a function that takes the entire architecture and applies our pruning algorithm of choice:

def sparsify_model(model, x_test, y_test, k_sparsity, pruning='weight'):

    # Copying a temporary sparse model from our original
    sparse_model = tf.keras.models.clone_model(model)
    sparse_model.set_weights(model.get_weights())
    
    # Getting a list of the names of each component (w + b) of each layer
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]
    # Getting the list of the weights for each component (w + b) of each layer
    weights = sparse_model.get_weights()
    
    # Initializing list that will contain the new sparse weights
    newWeightList = []

    # Iterate over all but the final 2 weight arrays (the softmax layer's kernel & bias)
    for i in range(0, len(weights)-2, 2):
        
        if pruning=='weight':
            kernel_weights, bias_weights = weight_prune_dense_layer(weights[i],
                                                                    weights[i+1],
                                                                    k_sparsity)
        elif pruning=='unit':
            kernel_weights, bias_weights = unit_prune_dense_layer(weights[i],
                                                                  weights[i+1],
                                                                  k_sparsity)
        else:
            print('does not match available pruning methods ( weight | unit )')
        
        # Append the new weight list with our sparsified kernel weights
        newWeightList.append(kernel_weights)
        
        # Append the new weight list with our sparsified bias weights
        newWeightList.append(bias_weights)

    # Adding the unchanged weights of the final 2 layers
    for i in range(len(weights)-2, len(weights)):
        unmodified_weight = np.copy(weights[i])
        newWeightList.append(unmodified_weight)

    # Setting the weights of our model to the new ones
    sparse_model.set_weights(newWeightList)
    
    # Re-compiling the Keras model (necessary for using `evaluate()`)
    sparse_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'])
    
    # Printing the associated loss & accuracy for the k% sparsity
    score = sparse_model.evaluate(x_test, y_test, verbose=0)
    print('k% weight sparsity: ', k_sparsity,
          '\tTest loss: {:07.5f}'.format(score[0]),
          '\tTest accuracy: {:05.2f} %%'.format(score[1]*100.))
    
    return sparse_model, score

Weight and unit pruning across all $k\%$ sparsities

Now for our full set of experiments:

# list of sparsities
k_sparsities = [0.0, 0.25, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.97, 0.99]

# The empty lists where we will store our training results
mnist_model_loss_weight = []
mnist_model_accs_weight = []
mnist_model_loss_unit = []
mnist_model_accs_unit = []
fmnist_model_loss_weight = []
fmnist_model_accs_weight = []
fmnist_model_loss_unit = []
fmnist_model_accs_unit = []

dataset = 'mnist'
pruning = 'weight'
print('\n MNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test,
                                         k_sparsity=k_sparsity, 
                                         pruning=pruning)
    mnist_model_loss_weight.append(score[0])
    mnist_model_accs_weight.append(score[1])
    
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model


pruning='unit'
print('\n MNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test, 
                                         k_sparsity=k_sparsity, 
                                         pruning=pruning)
    mnist_model_loss_unit.append(score[0])
    mnist_model_accs_unit.append(score[1])
    
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

dataset = 'fmnist'
pruning = 'weight'
print('\n FMNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test,
                                         k_sparsity=k_sparsity, 
                                         pruning=pruning)
    fmnist_model_loss_weight.append(score[0])
    fmnist_model_accs_weight.append(score[1])
    
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model


pruning='unit'
print('\n FMNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test, 
                                         k_sparsity=k_sparsity, 
                                         pruning=pruning)
    fmnist_model_loss_unit.append(score[0])
    fmnist_model_accs_unit.append(score[1])
    
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model
 MNIST Weight-pruning

k% weight sparsity:  0.0 	Test loss: 0.07616 	Test accuracy: 98.20 %%
k% weight sparsity:  0.25 	Test loss: 0.07391 	Test accuracy: 98.26 %%
k% weight sparsity:  0.5 	Test loss: 0.06690 	Test accuracy: 98.29 %%
k% weight sparsity:  0.6 	Test loss: 0.06573 	Test accuracy: 98.22 %%
k% weight sparsity:  0.7 	Test loss: 0.07903 	Test accuracy: 98.20 %%
k% weight sparsity:  0.8 	Test loss: 0.19978 	Test accuracy: 97.72 %%
k% weight sparsity:  0.9 	Test loss: 1.22491 	Test accuracy: 90.99 %%
k% weight sparsity:  0.95 	Test loss: 2.09516 	Test accuracy: 37.24 %%
k% weight sparsity:  0.97 	Test loss: 2.26086 	Test accuracy: 10.60 %%
k% weight sparsity:  0.99 	Test loss: 2.30334 	Test accuracy: 09.74 %%

 MNIST Unit-pruning

k% weight sparsity:  0.0 	Test loss: 0.07616 	Test accuracy: 98.20 %%
k% weight sparsity:  0.25 	Test loss: 0.07003 	Test accuracy: 98.20 %%
k% weight sparsity:  0.5 	Test loss: 0.10493 	Test accuracy: 98.09 %%
k% weight sparsity:  0.6 	Test loss: 0.30216 	Test accuracy: 97.51 %%
k% weight sparsity:  0.7 	Test loss: 0.85557 	Test accuracy: 95.71 %%
k% weight sparsity:  0.8 	Test loss: 1.83173 	Test accuracy: 69.72 %%
k% weight sparsity:  0.9 	Test loss: 2.25009 	Test accuracy: 25.85 %%
k% weight sparsity:  0.95 	Test loss: 2.30324 	Test accuracy: 09.74 %%
k% weight sparsity:  0.97 	Test loss: 2.30484 	Test accuracy: 09.74 %%
k% weight sparsity:  0.99 	Test loss: 2.30486 	Test accuracy: 09.74 %%

 FMNIST Weight-pruning

k% weight sparsity:  0.0 	Test loss: 0.32306 	Test accuracy: 88.85 %%
k% weight sparsity:  0.25 	Test loss: 0.31822 	Test accuracy: 89.18 %%
k% weight sparsity:  0.5 	Test loss: 0.31417 	Test accuracy: 88.97 %%
k% weight sparsity:  0.6 	Test loss: 0.31821 	Test accuracy: 88.78 %%
k% weight sparsity:  0.7 	Test loss: 0.35468 	Test accuracy: 88.09 %%
k% weight sparsity:  0.8 	Test loss: 0.47173 	Test accuracy: 86.76 %%
k% weight sparsity:  0.9 	Test loss: 1.06918 	Test accuracy: 74.13 %%
k% weight sparsity:  0.95 	Test loss: 1.78654 	Test accuracy: 46.28 %%
k% weight sparsity:  0.97 	Test loss: 2.14777 	Test accuracy: 26.22 %%
k% weight sparsity:  0.99 	Test loss: 2.30555 	Test accuracy: 11.23 %%

 FMNIST Unit-pruning

k% weight sparsity:  0.0 	Test loss: 0.32306 	Test accuracy: 88.85 %%
k% weight sparsity:  0.25 	Test loss: 0.32645 	Test accuracy: 88.65 %%
k% weight sparsity:  0.5 	Test loss: 0.41696 	Test accuracy: 85.76 %%
k% weight sparsity:  0.6 	Test loss: 0.56718 	Test accuracy: 87.20 %%
k% weight sparsity:  0.7 	Test loss: 1.01705 	Test accuracy: 78.30 %%
k% weight sparsity:  0.8 	Test loss: 1.60518 	Test accuracy: 45.68 %%
k% weight sparsity:  0.9 	Test loss: 2.23773 	Test accuracy: 19.64 %%
k% weight sparsity:  0.95 	Test loss: 2.29817 	Test accuracy: 10.05 %%
k% weight sparsity:  0.97 	Test loss: 2.30521 	Test accuracy: 09.97 %%
k% weight sparsity:  0.99 	Test loss: 2.30528 	Test accuracy: 10.00 %%

Nice. We can also turn this into a dataframe of results that we can use for plotting and visualization later:

# Convert the lists to numpy arrays
k_sparsities = np.asarray(k_sparsities)
mnist_model_loss_weight = np.asarray(mnist_model_loss_weight)
mnist_model_accs_weight = np.asarray(mnist_model_accs_weight)
mnist_model_loss_unit = np.asarray(mnist_model_loss_unit)
mnist_model_accs_unit = np.asarray(mnist_model_accs_unit)
fmnist_model_loss_weight = np.asarray(fmnist_model_loss_weight)
fmnist_model_accs_weight = np.asarray(fmnist_model_accs_weight)
fmnist_model_loss_unit = np.asarray(fmnist_model_loss_unit)
fmnist_model_accs_unit = np.asarray(fmnist_model_accs_unit)

# Stack the arrays so they can be used in the DataFrame
sparsity_data = np.stack([k_sparsities,
                          mnist_model_loss_weight,
                          mnist_model_accs_weight,
                          mnist_model_loss_unit,
                          mnist_model_accs_unit,
                          fmnist_model_loss_weight,
                          fmnist_model_accs_weight,
                          fmnist_model_loss_unit,
                          fmnist_model_accs_unit])

# Defining the Pandas DataFrame
sparsity_summary = pd.DataFrame(data=sparsity_data.T,    # values
                                columns=['k_sparsity',   # Column names
                                         'mnist_loss_weight',
                                         'mnist_acc_weight',
                                         'mnist_loss_unit',
                                         'mnist_acc_unit',
                                         'fmnist_loss_weight',
                                         'fmnist_acc_weight',
                                         'fmnist_loss_unit',
                                         'fmnist_acc_unit'])
sparsity_summary.to_csv('sparsity_summary.csv')
sparsity_summary
   k_sparsity  mnist_loss_weight  mnist_acc_weight  mnist_loss_unit  mnist_acc_unit  fmnist_loss_weight  fmnist_acc_weight  fmnist_loss_unit  fmnist_acc_unit
0        0.00           0.076157            0.9820         0.076157          0.9820            0.323056             0.8885          0.323056           0.8885
1        0.25           0.073905            0.9826         0.070034          0.9820            0.318219             0.8918          0.326450           0.8865
2        0.50           0.066900            0.9829         0.104928          0.9809            0.314172             0.8897          0.416962           0.8576
3        0.60           0.065726            0.9822         0.302163          0.9751            0.318206             0.8878          0.567175           0.8720
4        0.70           0.079034            0.9820         0.855566          0.9571            0.354680             0.8809          1.017053           0.7830
5        0.80           0.199783            0.9772         1.831728          0.6972            0.471729             0.8676          1.605178           0.4568
6        0.90           1.224915            0.9099         2.250092          0.2585            1.069181             0.7413          2.237728           0.1964
7        0.95           2.095160            0.3724         2.303235          0.0974            1.786535             0.4628          2.298169           0.1005
8        0.97           2.260861            0.1060         2.304837          0.0974            2.147765             0.2622          2.305212           0.0997
9        0.99           2.303342            0.0974         2.304857          0.0974            2.305550             0.1123          2.305276           0.1000

Visualizing sparsity

Before we visualize the performance of our newly sparse models, let’s check to make sure the neurons were properly pruned. Going index-by-index through $784 \times 1000$ weight matrices is obviously going to be time-consuming, but there’s a better way. Since we’re still working with 2-dimensional weight matrices and 1-dimensional bias arrays, we can color-code the values of the matrices. For values that are at 0.0 (not close, like 1e-16, but actually equal to 0.0), we can color them in a way that breaks with the colormap scheme of the rest of the weights (in this case we can just color them white).

Let’s set up our function to take in an arbitrary keras model and visualize the weights:

def visualize_model_weights(sparse_model):
    weights = sparse_model.get_weights()
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]
    
    my_cmap = matplotlib.cm.get_cmap('rainbow')
    my_cmap.set_under('w')
    
    # Iterate over all the weight matrices in the model and visualize them
    for i in range(len(weights)):
        weight_matrix = weights[i]
        layer_name = names[i]
        if weight_matrix.ndim == 1: # If Bias or softmax
            weight_matrix = np.resize(weight_matrix,
                                      (1,weight_matrix.size))
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       aspect = "auto",
                       cmap=my_cmap,
                       vmin=1e-26); # lower bound is set close to but not at 0
            plt.colorbar()
            plt.title(layer_name)
            plt.show()
        else: # all other 2D matrices
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       cmap=my_cmap,
                       vmin=1e-26);
            plt.colorbar()
            plt.title(layer_name)
            plt.show()

If you’re running this code in the Google Colab version of this post, you can use the interactive form to see how the weights change: the code itself becomes interactive, and you can select which sparse model you want to retrieve. All weights with values of 0.0 will be color-coded white. 1D bias layers will be auto-scaled to the dimensions of the 2D plots.

It also helps that the Colab saves the weight files of all the sparse models we trained earlier.

So how does our sparse model look?

dataset = 'mnist'
sparsity = "0.5"
pruning = 'weight'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)

As we can see, the model becomes more and more white as the smallest individual weights are set to 0.0. What about unit-pruning?

dataset = 'mnist'
sparsity = "0.5"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)

Now that we’ve confirmed that the models are pruning exactly the way we want them to, let’s see how they perform. How much can we prune away from a model before performance starts to fail?

# Visualizing performance on MNIST
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
plt.grid(b=None)
ax2 = ax1.twinx()
plt.grid(b=None)
plt.title('Test Accuracy as a function of k% Sparsity\nfor 4-hidden-layer MLP trained on MNIST')
ax1.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_acc_weight'].values,
         '#008fd5', linestyle=':', label='Weight-pruning Acc')
ax1.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_acc_unit'].values,
         '#008fd5', linestyle='-', label='Unit-pruning Acc')
ax2.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_loss_weight'].values,
         '#fc4f30', linestyle=':', label='Weight-pruning Loss')
ax2.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_loss_unit'].values,
         '#fc4f30', linestyle='-', label='Unit-pruning Loss')

ax1.set_ylabel('Accuracy (%)', color='#008fd5')
ax2.set_ylabel('Loss (categorical crossentropy)', color='#fc4f30')
ax1.set_xlabel('k% Sparsity')
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), shadow=True, ncol=2);
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), shadow=True, ncol=2);
plt.savefig('images/MNIST_sparsity_comparisons.png')

Results of our tests on MNIST (turns out you can remove a lot before performance starts to falter)

The performance curves for unit and weight pruning differ quite a lot on MNIST. Weight-pruning the dense layers does not result in dramatic drops in accuracy or increases in loss until around $k=80\%$. Even then, the accuracy does not begin to noticeably decrease until $k=90\%$.

For unit-pruning, accuracy begins to fall earlier, around $k=70\%$ (with loss beginning to increase around $k=60\%$). Even so, both methods are able to effectively remove more than half of the network’s weights without any dramatic differences in test classification performance.

How does this work on FMNIST?

Looks like it’s more than just MNIST that this technique works on

Even on FMNIST, which had a lower initial accuracy and higher initial loss, the same pattern emerges.

Again, weight-pruning the dense layers does not result in dramatic drops in accuracy or increases in loss until around $k=80\%$. Even then, the accuracy does not begin to noticeably decrease until $k=90\%$.

For unit-pruning, the differences come much earlier for FMNIST than in MNIST. Accuracy begins to fall around $k=60\%$ (with loss beginning to increase around $k=60\%$).

Why are the models behaving this way?

The two methods use different strategies for finding the least useful weights. Weight-pruning ranks the absolute values ($|w|$) of individual weights within the weight matrices, while unit-pruning ranks the L2 norms of entire columns of the weight matrices. This is, in part, a difference between fine-grained and coarse-grained pruning.

Both weight-pruning and unit-pruning are forms of saliency mapping. The pruning functions go through the weight matrices and identify the weights that would have minimal impact when multiplied with the input data being fed into the layer. Given that some weights are orders of magnitude smaller than the largest ones, the impact of removing them is minimal. What is surprising is how much of the weight matrices are made up of these low-saliency weights.
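
One quick way to see this is to look at the distribution of weight magnitudes in a trained layer. The snippet below (a quick check, assuming the trained mnist_model from earlier is still in scope; the exact numbers will depend on your training run) prints a few percentiles of |w| for the first hidden layer:

# Distribution of weight magnitudes in the first hidden layer of the trained model
first_kernel = mnist_model.get_weights()[0]   # the 784 x 1000 kernel matrix
mags = np.abs(first_kernel).ravel()
for p in [1, 10, 50, 90, 99]:
    print('{:2d}th percentile of |w|: {:.5f}'.format(p, np.percentile(mags, p)))
print('largest |w|: {:.5f}'.format(mags.max()))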

Why are we able to delete so much of the network without hurting performance?

This demonstration shows that the parameter counts of trained neural networks can be decreased by over $80\%$ without hurting performance, and other researchers have demonstrated pruning techniques that decrease parameter counts by over $90\%$. Exactly why this works so well is still an open research question in machine learning.

This bears some similarity to the “optimal brain damage” hypothesis from Yann LeCun. In addition to being based on practical observations of differences in weight activations in neural networks, it echoes ideas of how biological neural networks (like those in the neocortex) specialize.

[Frankle & Carbin, 2019] proposed the “Lottery Ticket Hypothesis”. This hypothesis articulates that dense, randomly-initialized, feed-forward networks contain subnetworks (“winning tickets”) that, when trained in isolation, reach test accuracy comparable to the original network in a similar number of iterations. It is named for the fact that these subnetworks just happened to get initial weights that make training particularly effective (i.e., they have won the “training lottery”). For the task of MNIST, both weight-pruning and unit-pruning suggest that 2,388,710 trainable parameters are superfluous compared to the truly optimal subnetwork that captures the concise classification function.

In ranking the weights (or columns of weights) and removing the $k\%$ closest to $0$, we are trying to filter out all but the weights that make the input features maximally separable. [Zhang et al., 2018] compared this kind of separation to subspace packing on Grassmannian manifolds. Given that there is no universal formula for finding provably optimal subspace packings, the Grassmannian model of neural networks frames them as approximators of such solutions. Because the learned packing is less than optimal, some of the subspaces of the weights (e.g., those represented either as individual weights or as vectors of weights) can be removed without impacting classification performance. At some point, there are no more values to remove save for those satisfying the subspace packing needed for maximum separability of the classes.

The latent-variable model of neural network function would suggest that neural networks work by trying to learn the requisite latent variables (David Ha’s blog post is a great example of this) needed to capture the epistemological essence of the given class. As David Ha’s example shows, very high-resolution images of MNIST digits can be produced by a generator from a vector containing just 32 real numbers.

Whether framed as finding the minimal subnetwork (the “lottery-ticket hypothesis”) or the optimal subspace packing (the Grassmannian view), it is also possible that there is more room for improvement beyond weight-pruning. Weight pruning looks at individual weight values in isolation. While unit pruning seems to suggest that looking at weights in combination does not produce better results, magnitude is only one possible filter for choosing which weights to set to $0$. Uber AI tested out multiple masking criteria for finding the “lottery-ticket” subnetworks.

Figure from Uber AI’s ICML 2019 poster “Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask”; https://eng.uber.com/deconstructing-lottery-tickets/

It is also important to note that this pruning method only involves setting weight values below a certain threshold to $0$ and then testing immediately afterwards. The experiments shown in the code do not demonstrate successive rounds of pruning interleaved with further training (like what TF 2.0’s pruning API allows for); a rough sketch of what that would look like follows.
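
For illustration only, here is what such an iterative schedule could look like, reusing sparsify_model() and the data from above (a sketch under those assumptions; it is not TF's pruning API and not a tuned recipe):

# Iterative magnitude pruning: alternate pruning with a little fine-tuning.
model_iter = mnist_model   # start from the trained dense model
for k_sparsity in [0.5, 0.7, 0.8, 0.9]:
    # Prune the current weights (sparsify_model clones, prunes, compiles, evaluates)...
    model_iter, _ = sparsify_model(model_iter, mnist_x_test, mnist_y_test,
                                   k_sparsity, pruning='weight')
    # ...then fine-tune briefly so the surviving weights can compensate.
    # Note: without re-applying a mask, pruned weights are free to grow back
    # during fine-tuning; a real implementation would mask weights or gradients.
    model_iter.fit(mnist_x_train, mnist_y_train, batch_size=128,
                   epochs=1, verbose=0)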

Reducing the size and runtimes of our models

One of the added benefits of reducing weights to 0 is that we can actually reduce the overall size of the network. With unit/neuron-pruning specifically, we can clear out entire columns and rows of our weight matrices. Let’s create a function that, given a list of the layers’ weight matrices, compresses all of the sparse matrices save for the last two layers:

def compress_sparse_weigths(sparse_model_weight_list, unique_columns=None):
    compressed_weight_list = []
    for i in range(0,len(sparse_model_weight_list)-2,2):
        # If this is just after the input layer, the matrix with input
        # dimensions will be added on unchanged. Otherwise the matrix will only
        # retain the columns that are not entirely 0s
        if unique_columns is None:
            kernel_weights = sparse_model_weight_list[i]
        else:
            kernel_weights = sparse_model_weight_list[i][unique_columns,:]
        bias_weights = sparse_model_weight_list[i+1]
        
        # a tuple of two arrays: 0th is row indices, 1st is cols
        indices = np.nonzero(kernel_weights)
        columns_non_unique = indices[1]
        unique_columns = sorted(set(columns_non_unique))
        kernel_weights = kernel_weights[:,unique_columns]
        bias_weights = bias_weights[unique_columns]
        
        # Adding the new kernel and bias weights to the new list
        compressed_weight_list.append(kernel_weights)
        compressed_weight_list.append(bias_weights)
    
    # Adding the softmax and modifying the pre-softmax weight matrices
    compressed_weight_list.append(sparse_model_weight_list[-2][unique_columns,:])
    compressed_weight_list.append(sparse_model_weight_list[-1])
    
    return compressed_weight_list

Now, let’s load the newly compressed list of weights into a Keras model, compile it, and see how it compares to the original. Let’s try this out on the $95\%$ sparse (unit-pruned) model. Again, if you look at the Google Colab version of this, you can find an interactive version of the code below:

dataset = 'mnist'
sparsity = "0.95"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))

# List of weights from the loaded file
sparse_weight_list = sparse_model.get_weights()
# Creating a list of dense weights from the sparse weight list
compressed_weight_list = compress_sparse_weigths(sparse_weight_list)
# Getting a 4-layer neural network with layer dimensions that match those of
# the new dense weight matrices (e.g., at 50% sparsity, layer size of 1000 is 
# reduced to 500)
compressed_model = build_model_arch(input_shape, num_classes, sparsity=float(sparsity))
compressed_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer='adam', metrics=['accuracy'])
compressed_model.set_weights(compressed_weight_list)

# Printing the model summaries and comparing the weights of the original sparse
# model and the dense model
print(compressed_model.summary())
print('\nVISUALIZING ORIGINAL SPARSE MODEL')
visualize_model_weights(sparse_model)
print('\nVISUALIZING COMPRESSED MODEL')
visualize_model_weights(compressed_model)
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_10 (Dense)             (None, 50)                39250     
_________________________________________________________________
dense_11 (Dense)             (None, 50)                2550      
_________________________________________________________________
dense_12 (Dense)             (None, 25)                1275      
_________________________________________________________________
dense_13 (Dense)             (None, 10)                260       
_________________________________________________________________
dense_14 (Dense)             (None, 10)                110       
=================================================================
Total params: 43,445
Trainable params: 43,445
Non-trainable params: 0
_________________________________________________________________
None

VISUALIZING ORIGINAL SPARSE MODEL

Our $95\%$ sparse model contains only 43,445 parameters, versus the more than 2,000,000 of the original model. Granted, this is one of the more poorly performing configurations, but it demonstrates how much we can reduce the memory cost of an already-trained model.
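
We can make that comparison explicit by checking the parameter counts (and, if we like, rough inference timings) of the two models. Keep in mind that the zero-filled sparse model still uses dense kernels, so any speedup comes from the smaller layer sizes of the compressed model, and exact timings will vary with hardware and backend:

import time

# Compare the zero-padded sparse model with its compressed equivalent
print('sparse model params:    ', sparse_model.count_params())      # ~2.39M, mostly zeros
print('compressed model params:', compressed_model.count_params())  # 43,445

start = time.time()
sparse_model.predict(mnist_x_test, batch_size=128)
print('sparse model inference:     {:.3f}s'.format(time.time() - start))

start = time.time()
compressed_model.predict(mnist_x_test, batch_size=128)
print('compressed model inference: {:.3f}s'.format(time.time() - start))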


If you’d like to learn more about neural network pruning, you can look at some of the resources below.

References/Resources/Further Reading
