Optimal Brain Damage
Cutting away 80% of neurons in a neural network with no impact on accuracy
 UPDATED
This project was the result of work I recently did with FOR.ai, a multidisciplinary team of scientists and engineers who like doing machine learning research for fun. You can read more about their projects on their blog, or follow them on Twitter.
There are multiple ways of optimizing neural-network-based machine learning algorithms. One of these optimizations is the removal of connections between neurons and layers, which speeds up computation by reducing the overall number of parameters.
Networks generally look like the one on the left: every neuron in the layer below has a connection to the layer above; but this means that we have to multiply a lot of floats together. Ideally, we’d only connect each neuron to a few others and save on doing some of the multiplications; this is called a “sparse” network.
Given a layer of a neural network $\text{ReLU}(xW)$, there are two well-known ways to prune it:

- Weight pruning: set individual weights in the weight matrix to zero. This corresponds to deleting connections as in the figure above.
  Here, to achieve sparsity of $k\%$, we rank the individual weights in weight matrix $W$ according to their magnitude (absolute value) $|w_{i,j}|$, and then set the smallest $k\%$ to zero.

- Unit/Neuron pruning: set entire columns of the weight matrix to zero, in effect deleting the corresponding output neuron.
  Here, to achieve sparsity of $k\%$, we rank the columns of a weight matrix according to their L2-norm $\|w\|_2 = \sqrt{\sum_{i=1}^{N} w_i^2}$ and delete the smallest $k\%$.
Naturally, as you increase the sparsity and delete more of the network, the task performance will progressively degrade. Here, we will be implementing both weight and unit pruning and compare the performance across both the MNIST and FMNIST datasets.
If you want to follow along, you can find the full notebook here.
Training our model (first without pruning)
First, we're going to set up our dataset. For the sake of running experiments on multiple datasets in parallel, we create a dataset-loading function.
```python
import tensorflow as tf

def load_dataset(dataset='mnist'):
    # input image dimensions (equal for both MNIST and FMNIST)
    img_rows, img_cols = 28, 28

    if dataset == 'mnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    elif dataset == 'fmnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    else:
        # Fail loudly instead of continuing with undefined variables
        raise ValueError('dataset name does not match available options (mnist | fmnist)')

    # Flatten each 28x28 image into a 784-dimensional vector
    x_train = x_train.reshape(x_train.shape[0], img_rows*img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows*img_cols)
    input_shape = (img_rows*img_cols*1,)

    # Scale pixel values from [0, 255] to [0, 1]
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return x_train, x_test, y_train, y_test, num_classes, input_shape

# We will load separate tensors for the MNIST and FMNIST data.
mnist_x_train, mnist_x_test, mnist_y_train, mnist_y_test, num_classes, input_shape = load_dataset(dataset='mnist')
fmnist_x_train, fmnist_x_test, fmnist_y_train, fmnist_y_test, num_classes, input_shape = load_dataset(dataset='fmnist')
```
We will construct a $\text{ReLU}$-activated neural network with four hidden layers. These layers will be dense, fully-connected layers with sizes $1000$, $1000$, $500$, and $200$. We'll also have a fifth layer with $10$ softmax units for the output. (Note: since this layer connects directly to the output, its weights will be spared from any and all pruning.)
For the sake of simplicity, we will also omit Dropout layers, Convolutional layers, Batch Normalization Layers, and Avg Pooling Layers.
```python
l = tf.keras.layers

def build_model_arch(input_shape, num_classes, sparsity=0.0):
    # `sparsity` optionally shrinks every hidden layer by that fraction
    model = tf.keras.Sequential()
    model.add(l.Dense(int(1000 - (1000*sparsity)), activation='relu',
                      input_shape=input_shape))
    model.add(l.Dense(int(1000 - (1000*sparsity)), activation='relu'))
    model.add(l.Dense(int(500 - (500*sparsity)), activation='relu'))
    model.add(l.Dense(int(200 - (200*sparsity)), activation='relu'))
    model.add(l.Dense(num_classes, activation='softmax'))
    return model

# The architectures are the same, but we are initializing 2 different sequential
# models. One is for MNIST, and one is for FMNIST
mnist_model_base = build_model_arch(input_shape, num_classes)
fmnist_model_base = build_model_arch(input_shape, num_classes)
```
We can also set up TensorBoard to monitor the training process.
```python
import tempfile

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
```
Now that we have our dataset feeder and logging set up, we can build the non-sparse model that we will prune post-training. Setting up our model definitions in a function like this allows us to create multiple graphs that we can test in experiments like this one (this is less a machine learning fact and more a coding practice I'd love to encourage).
```python
def make_nosparse_model(model, x_train, y_train, batch_size,
                        epochs, x_test, y_test):
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam', metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=batch_size,
              epochs=epochs, verbose=1,
              callbacks=callbacks,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
    return model, score

batch_size = 128
epochs = 10

mnist_model, mnist_score = make_nosparse_model(mnist_model_base,
                                               mnist_x_train,
                                               mnist_y_train,
                                               batch_size,
                                               epochs,
                                               mnist_x_test,
                                               mnist_y_test)
print(mnist_model.summary())

fmnist_model, fmnist_score = make_nosparse_model(fmnist_model_base,
                                                 fmnist_x_train,
                                                 fmnist_y_train,
                                                 batch_size,
                                                 epochs,
                                                 fmnist_x_test,
                                                 fmnist_y_test)
print(fmnist_model.summary())
```
So how do our non-pruned models perform?
```
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 3s 46us/sample - loss: 0.2092 - acc: 0.9361 - val_loss: 0.1327 - val_acc: 0.9551
Epoch 2/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0857 - acc: 0.9740 - val_loss: 0.0882 - val_acc: 0.9738
Epoch 3/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0626 - acc: 0.9804 - val_loss: 0.0906 - val_acc: 0.9738
Epoch 4/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0443 - acc: 0.9862 - val_loss: 0.0633 - val_acc: 0.9814
Epoch 5/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0361 - acc: 0.9887 - val_loss: 0.0962 - val_acc: 0.9738
Epoch 6/10
60000/60000 [==============================] - 2s 28us/sample - loss: 0.0331 - acc: 0.9900 - val_loss: 0.0813 - val_acc: 0.9791
Epoch 7/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0270 - acc: 0.9920 - val_loss: 0.0864 - val_acc: 0.9783
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0226 - acc: 0.9924 - val_loss: 0.1024 - val_acc: 0.9751
Epoch 9/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0206 - acc: 0.9939 - val_loss: 0.0785 - val_acc: 0.9814
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0192 - acc: 0.9942 - val_loss: 0.0762 - val_acc: 0.9820
Test loss: 0.07615731712152446
Test accuracy: 0.982
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 1000)              785000
_________________________________________________________________
dense_1 (Dense)              (None, 1000)              1001000
_________________________________________________________________
dense_2 (Dense)              (None, 500)               500500
_________________________________________________________________
dense_3 (Dense)              (None, 200)               100200
_________________________________________________________________
dense_4 (Dense)              (None, 10)                2010
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 2s 35us/sample - loss: 0.4828 - acc: 0.8260 - val_loss: 0.4383 - val_acc: 0.8391
Epoch 2/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3591 - acc: 0.8671 - val_loss: 0.3768 - val_acc: 0.8594
Epoch 3/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3217 - acc: 0.8805 - val_loss: 0.3598 - val_acc: 0.8722
Epoch 4/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2959 - acc: 0.8906 - val_loss: 0.3294 - val_acc: 0.8814
Epoch 5/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2843 - acc: 0.8936 - val_loss: 0.3463 - val_acc: 0.8761
Epoch 6/10
60000/60000 [==============================] - 2s 31us/sample - loss: 0.2657 - acc: 0.9000 - val_loss: 0.3267 - val_acc: 0.8815
Epoch 7/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2528 - acc: 0.9049 - val_loss: 0.3491 - val_acc: 0.8763
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2383 - acc: 0.9091 - val_loss: 0.3388 - val_acc: 0.8828
Epoch 9/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2273 - acc: 0.9126 - val_loss: 0.3320 - val_acc: 0.8855
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2207 - acc: 0.9154 - val_loss: 0.3231 - val_acc: 0.8885
Test loss: 0.3230560186743736
Test accuracy: 0.8885
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_5 (Dense)              (None, 1000)              785000
_________________________________________________________________
dense_6 (Dense)              (None, 1000)              1001000
_________________________________________________________________
dense_7 (Dense)              (None, 500)               500500
_________________________________________________________________
dense_8 (Dense)              (None, 200)               100200
_________________________________________________________________
dense_9 (Dense)              (None, 10)                2010
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None
```
So our model does pretty well. On MNIST we get $98.2\%$ test accuracy with a cross-entropy loss of $0.076$, and on FMNIST we get $88.85\%$ accuracy with a loss of $0.323$.
Not bad, but this is done with 2,388,710 trainable parameters, on what is effectively the "Hello world!" of machine learning. If we're running an application that relies on multiple, repeated evaluations, this could get computationally expensive, and that is before considering what would happen if we wanted to train convolutional networks or expand to multi-channel classification tasks.
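Where does 2,388,710 come from? A dense layer with $n_{in}$ inputs and $n_{out}$ outputs has $n_{in} \cdot n_{out}$ kernel weights plus $n_{out}$ biases. The sketch below (the helper names `dense_param_count` and `unit_pruned_sizes` are our own, not from the notebook) reproduces the total. It also estimates how small the network could get if unit pruning at sparsity $k$ removed `int(n * k)` units from every hidden layer and those dead units were physically deleted, which assumes a pruned unit (zero column and zero bias) always outputs $0$ after the ReLU, so the matching row of the next layer can go too.

```python
def dense_param_count(layer_sizes):
    """Total weights + biases for a stack of fully-connected layers.

    layer_sizes lists the layer widths in order, starting with the input size.
    """
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


def unit_pruned_sizes(hidden_sizes, k_sparsity):
    """Hidden-layer widths left after removing int(n * k) units from each layer."""
    return [n - int(n * k_sparsity) for n in hidden_sizes]


hidden = [1000, 1000, 500, 200]
full = dense_param_count([784] + hidden + [10])
pruned = dense_param_count([784] + unit_pruned_sizes(hidden, 0.5) + [10])
print(full)    # 2388710
print(pruned)  # 794360
```

So at $k = 50\%$, a compacted unit-pruned network would need roughly a third of the original parameters (794,360 vs. 2,388,710).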
Pruning our Keras Model
How many parameters are actually needed at minimum to solve this task? We can't quite get a provable answer for a dataset like this, but we can do the next best thing: set subsets of the weights to 0.0 and see how that impacts performance (aka 'pruning'). We will use the two pruning methods defined at the start of this post: weight pruning and unit/neuron pruning.
We will prune $k\%$ of the weights with both weight and unit pruning, for $k \in \{0, 25, 50, 60, 70, 80, 90, 95, 97, 99\}$. Neither pruning method will prune the weights leading to the softmax layer.
First, we define our function for weight pruning, which takes in the kernel and bias weight matrices of a dense layer and returns the weight-pruned version of each:
```python
import numpy as np

def weight_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get ranked indices of the abs
    kernel_weights = np.copy(k_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(kernel_weights),
            axis=None),
        kernel_weights.shape)

    # Number of indexes to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indexes in the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff], ind[1][0:cutoff])
    kernel_weights[sparse_cutoff_inds] = 0.

    # Copy the bias weights and get ranked indices of the abs
    bias_weights = np.copy(b_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(bias_weights),
            axis=None),
        bias_weights.shape)

    # Number of indexes to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indexes in the 1D bias weight matrix to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff])
    bias_weights[sparse_cutoff_inds] = 0.

    return kernel_weights, bias_weights
```
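Because `np.argsort(..., axis=None)` already returns indices into the flattened array, the `unravel_index` step can be skipped by writing the zeros back with `np.put`, which also uses flat indices. The sketch below is a compact alternative of our own (the name `weight_prune_flat` is hypothetical), applied to one array at a time so the same function covers both kernels and biases; it is checked on a random toy matrix, not the trained model.

```python
import numpy as np


def weight_prune_flat(weights, k_sparsity):
    """Zero the k_sparsity fraction of entries with the smallest magnitude."""
    pruned = np.copy(weights)
    cutoff = int(pruned.size * k_sparsity)          # how many entries to zero
    order = np.argsort(np.abs(pruned), axis=None)   # flat indices, smallest |w| first
    np.put(pruned, order[:cutoff], 0.0)             # np.put also takes flat indices
    return pruned


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 5))
W_pruned = weight_prune_flat(W, 0.5)
print(np.count_nonzero(W_pruned == 0))  # 20 of the 40 entries
```

Every surviving entry is at least as large in magnitude as every pruned one, which is exactly the ranking rule described in the weight-pruning bullet at the top of the post.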
We can also set up a similar function for unit/neuron pruning. It takes in the kernel and bias weights (again, for a dense layer) and returns the unit/neuron-pruned version of each:
```python
from numpy import linalg as LA

def unit_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get ranked indices of the
    # column-wise L2 norms
    kernel_weights = np.copy(k_weights)
    ind = np.argsort(LA.norm(kernel_weights, axis=0))

    # Number of indexes to set to 0
    cutoff = int(len(ind)*k_sparsity)
    # The indexes in the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = ind[0:cutoff]
    kernel_weights[:, sparse_cutoff_inds] = 0.

    # Copy the bias weights. The indexes in the 1D bias array to set to 0
    # are the same as the indexes of the kernel columns that were removed
    bias_weights = np.copy(b_weights)
    bias_weights[sparse_cutoff_inds] = 0.

    return kernel_weights, bias_weights
```
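Because unit pruning zeroes both a unit's kernel column and its bias, the unit's ReLU output is exactly $0$ for every input, which is what later makes it safe to delete the unit outright. The standalone sketch below checks that property on a random toy layer; `unit_prune_sketch` is a hypothetical helper that mirrors the ranking logic of `unit_prune_dense_layer` and also returns the pruned indices.

```python
import numpy as np


def unit_prune_sketch(kernel, bias, k_sparsity):
    """Zero the k_sparsity fraction of output units (columns) with the smallest L2 norm."""
    kernel, bias = np.copy(kernel), np.copy(bias)
    norms = np.linalg.norm(kernel, axis=0)   # one L2 norm per output unit
    cutoff = int(norms.size * k_sparsity)
    dead = np.argsort(norms)[:cutoff]        # indices of the weakest units
    kernel[:, dead] = 0.0
    bias[dead] = 0.0
    return kernel, bias, dead


rng = np.random.default_rng(0)
W = rng.normal(size=(6, 10))
b = rng.normal(size=10)
W_p, b_p, dead = unit_prune_sketch(W, b, 0.3)

x = rng.normal(size=(4, 6))            # a toy batch of 4 inputs
h = np.maximum(x @ W_p + b_p, 0.0)     # ReLU(xW + b)
print(sorted(dead.tolist()))           # indices of the 3 pruned units
print(np.abs(h[:, dead]).max())        # 0.0: pruned units never fire
```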
Pruning across the entire model
The previous functions only took in individual layers and pruned them. Let's create a function that takes an entire model and automatically applies our pruning algorithm of choice:
```python
def sparsify_model(model, x_test, y_test, k_sparsity, pruning='weight'):
    # Copying a temporary sparse model from our original
    sparse_model = tf.keras.models.clone_model(model)
    sparse_model.set_weights(model.get_weights())

    # Getting a list of the names of each component (w + b) of each layer
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]
    # Getting the list of the weights for each component (w + b) of each layer
    weights = sparse_model.get_weights()

    # Initializing list that will contain the new sparse weights
    newWeightList = []

    # Iterate over all but the final 2 weight arrays (the softmax kernel + bias)
    for i in range(0, len(weights)-2, 2):
        if pruning == 'weight':
            kernel_weights, bias_weights = weight_prune_dense_layer(weights[i],
                                                                    weights[i+1],
                                                                    k_sparsity)
        elif pruning == 'unit':
            kernel_weights, bias_weights = unit_prune_dense_layer(weights[i],
                                                                  weights[i+1],
                                                                  k_sparsity)
        else:
            raise ValueError('does not match available pruning methods (weight | unit)')

        # Append the new weight list with our sparsified kernel weights
        newWeightList.append(kernel_weights)
        # Append the new weight list with our sparsified bias weights
        newWeightList.append(bias_weights)

    # Adding the unchanged weights of the final layer (kernel + bias)
    for i in range(len(weights)-2, len(weights)):
        unmodified_weight = np.copy(weights[i])
        newWeightList.append(unmodified_weight)

    # Setting the weights of our model to the new ones
    sparse_model.set_weights(newWeightList)

    # Recompiling the Keras model (necessary for using `evaluate()`)
    sparse_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'])

    # Printing the associated loss & accuracy for the k% sparsity
    score = sparse_model.evaluate(x_test, y_test, verbose=0)
    print('k% weight sparsity: ', k_sparsity,
          '\tTest loss: {:07.5f}'.format(score[0]),
          '\tTest accuracy: {:05.2f} %%'.format(score[1]*100.))

    return sparse_model, score
```
Weight-and-unit pruning across all $k\%$ sparsities
Now for our full set of experiments:
```python
import os

# Directory for the saved sparse models
os.makedirs('models', exist_ok=True)

# list of sparsities
k_sparsities = [0.0, 0.25, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.97, 0.99]

# The empty lists where we will store our training results
mnist_model_loss_weight = []
mnist_model_accs_weight = []
mnist_model_loss_unit = []
mnist_model_accs_unit = []
fmnist_model_loss_weight = []
fmnist_model_accs_weight = []
fmnist_model_loss_unit = []
fmnist_model_accs_unit = []

dataset = 'mnist'
pruning = 'weight'
print('\n MNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    mnist_model_loss_weight.append(score[0])
    mnist_model_accs_weight.append(score[1])
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

pruning = 'unit'
print('\n MNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    mnist_model_loss_unit.append(score[0])
    mnist_model_accs_unit.append(score[1])
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

dataset = 'fmnist'
pruning = 'weight'
print('\n FMNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    fmnist_model_loss_weight.append(score[0])
    fmnist_model_accs_weight.append(score[1])
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

pruning = 'unit'
print('\n FMNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    fmnist_model_loss_unit.append(score[0])
    fmnist_model_accs_unit.append(score[1])
    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model
```
```
 MNIST Weight-pruning

k% weight sparsity:  0.0 	Test loss: 0.07616 	Test accuracy: 98.20 %%
k% weight sparsity:  0.25 	Test loss: 0.07391 	Test accuracy: 98.26 %%
k% weight sparsity:  0.5 	Test loss: 0.06690 	Test accuracy: 98.29 %%
k% weight sparsity:  0.6 	Test loss: 0.06573 	Test accuracy: 98.22 %%
k% weight sparsity:  0.7 	Test loss: 0.07903 	Test accuracy: 98.20 %%
k% weight sparsity:  0.8 	Test loss: 0.19978 	Test accuracy: 97.72 %%
k% weight sparsity:  0.9 	Test loss: 1.22491 	Test accuracy: 90.99 %%
k% weight sparsity:  0.95 	Test loss: 2.09516 	Test accuracy: 37.24 %%
k% weight sparsity:  0.97 	Test loss: 2.26086 	Test accuracy: 10.60 %%
k% weight sparsity:  0.99 	Test loss: 2.30334 	Test accuracy: 09.74 %%

 MNIST Unit-pruning

k% weight sparsity:  0.0 	Test loss: 0.07616 	Test accuracy: 98.20 %%
k% weight sparsity:  0.25 	Test loss: 0.07003 	Test accuracy: 98.20 %%
k% weight sparsity:  0.5 	Test loss: 0.10493 	Test accuracy: 98.09 %%
k% weight sparsity:  0.6 	Test loss: 0.30216 	Test accuracy: 97.51 %%
k% weight sparsity:  0.7 	Test loss: 0.85557 	Test accuracy: 95.71 %%
k% weight sparsity:  0.8 	Test loss: 1.83173 	Test accuracy: 69.72 %%
k% weight sparsity:  0.9 	Test loss: 2.25009 	Test accuracy: 25.85 %%
k% weight sparsity:  0.95 	Test loss: 2.30324 	Test accuracy: 09.74 %%
k% weight sparsity:  0.97 	Test loss: 2.30484 	Test accuracy: 09.74 %%
k% weight sparsity:  0.99 	Test loss: 2.30486 	Test accuracy: 09.74 %%

 FMNIST Weight-pruning

k% weight sparsity:  0.0 	Test loss: 0.32306 	Test accuracy: 88.85 %%
k% weight sparsity:  0.25 	Test loss: 0.31822 	Test accuracy: 89.18 %%
k% weight sparsity:  0.5 	Test loss: 0.31417 	Test accuracy: 88.97 %%
k% weight sparsity:  0.6 	Test loss: 0.31821 	Test accuracy: 88.78 %%
k% weight sparsity:  0.7 	Test loss: 0.35468 	Test accuracy: 88.09 %%
k% weight sparsity:  0.8 	Test loss: 0.47173 	Test accuracy: 86.76 %%
k% weight sparsity:  0.9 	Test loss: 1.06918 	Test accuracy: 74.13 %%
k% weight sparsity:  0.95 	Test loss: 1.78654 	Test accuracy: 46.28 %%
k% weight sparsity:  0.97 	Test loss: 2.14777 	Test accuracy: 26.22 %%
k% weight sparsity:  0.99 	Test loss: 2.30555 	Test accuracy: 11.23 %%

 FMNIST Unit-pruning

k% weight sparsity:  0.0 	Test loss: 0.32306 	Test accuracy: 88.85 %%
k% weight sparsity:  0.25 	Test loss: 0.32645 	Test accuracy: 88.65 %%
k% weight sparsity:  0.5 	Test loss: 0.41696 	Test accuracy: 85.76 %%
k% weight sparsity:  0.6 	Test loss: 0.56718 	Test accuracy: 87.20 %%
k% weight sparsity:  0.7 	Test loss: 1.01705 	Test accuracy: 78.30 %%
k% weight sparsity:  0.8 	Test loss: 1.60518 	Test accuracy: 45.68 %%
k% weight sparsity:  0.9 	Test loss: 2.23773 	Test accuracy: 19.64 %%
k% weight sparsity:  0.95 	Test loss: 2.29817 	Test accuracy: 10.05 %%
k% weight sparsity:  0.97 	Test loss: 2.30521 	Test accuracy: 09.97 %%
k% weight sparsity:  0.99 	Test loss: 2.30528 	Test accuracy: 10.00 %%
```
Nice. We can also turn this into a dataframe of results that we can use for plotting and visualization later:
```python
import pandas as pd

# Convert the lists to numpy arrays
k_sparsities = np.asarray(k_sparsities)
mnist_model_loss_weight = np.asarray(mnist_model_loss_weight)
mnist_model_accs_weight = np.asarray(mnist_model_accs_weight)
mnist_model_loss_unit = np.asarray(mnist_model_loss_unit)
mnist_model_accs_unit = np.asarray(mnist_model_accs_unit)
fmnist_model_loss_weight = np.asarray(fmnist_model_loss_weight)
fmnist_model_accs_weight = np.asarray(fmnist_model_accs_weight)
fmnist_model_loss_unit = np.asarray(fmnist_model_loss_unit)
fmnist_model_accs_unit = np.asarray(fmnist_model_accs_unit)

# Stack the arrays so they can be used in the DataFrame
sparsity_data = np.stack([k_sparsities,
                          mnist_model_loss_weight,
                          mnist_model_accs_weight,
                          mnist_model_loss_unit,
                          mnist_model_accs_unit,
                          fmnist_model_loss_weight,
                          fmnist_model_accs_weight,
                          fmnist_model_loss_unit,
                          fmnist_model_accs_unit])

# Defining the Pandas DataFrame (one row per sparsity level)
sparsity_summary = pd.DataFrame(data=sparsity_data.T,  # values
                                columns=['k_sparsity',  # Column names
                                         'mnist_loss_weight',
                                         'mnist_acc_weight',
                                         'mnist_loss_unit',
                                         'mnist_acc_unit',
                                         'fmnist_loss_weight',
                                         'fmnist_acc_weight',
                                         'fmnist_loss_unit',
                                         'fmnist_acc_unit'])

sparsity_summary.to_csv('sparsity_summary.csv')
sparsity_summary
```
| index | k_sparsity | mnist_loss_weight | mnist_acc_weight | mnist_loss_unit | mnist_acc_unit | fmnist_loss_weight | fmnist_acc_weight | fmnist_loss_unit | fmnist_acc_unit |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.00 | 0.076157 | 0.9820 | 0.076157 | 0.9820 | 0.323056 | 0.8885 | 0.323056 | 0.8885 |
| 1 | 0.25 | 0.073905 | 0.9826 | 0.070034 | 0.9820 | 0.318219 | 0.8918 | 0.326450 | 0.8865 |
| 2 | 0.50 | 0.066900 | 0.9829 | 0.104928 | 0.9809 | 0.314172 | 0.8897 | 0.416962 | 0.8576 |
| 3 | 0.60 | 0.065726 | 0.9822 | 0.302163 | 0.9751 | 0.318206 | 0.8878 | 0.567175 | 0.8720 |
| 4 | 0.70 | 0.079034 | 0.9820 | 0.855566 | 0.9571 | 0.354680 | 0.8809 | 1.017053 | 0.7830 |
| 5 | 0.80 | 0.199783 | 0.9772 | 1.831728 | 0.6972 | 0.471729 | 0.8676 | 1.605178 | 0.4568 |
| 6 | 0.90 | 1.224915 | 0.9099 | 2.250092 | 0.2585 | 1.069181 | 0.7413 | 2.237728 | 0.1964 |
| 7 | 0.95 | 2.095160 | 0.3724 | 2.303235 | 0.0974 | 1.786535 | 0.4628 | 2.298169 | 0.1005 |
| 8 | 0.97 | 2.260861 | 0.1060 | 2.304837 | 0.0974 | 2.147765 | 0.2622 | 2.305212 | 0.0997 |
| 9 | 0.99 | 2.303342 | 0.0974 | 2.304857 | 0.0974 | 2.305550 | 0.1123 | 2.305276 | 0.1000 |
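To put a number on "how much can we prune before performance starts to fail", one option is to find the largest $k$ reached before test accuracy first drops more than one percentage point below the unpruned model. The one-point tolerance and the helper name `max_safe_sparsity` are our own assumptions; the accuracy values are copied from the table above (in the notebook you could pass `sparsity_summary` directly).

```python
import pandas as pd

# Accuracy columns copied from the sparsity summary table (k = 0 is the unpruned model)
acc = pd.DataFrame({
    "k_sparsity":        [0.0, 0.25, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.97, 0.99],
    "mnist_acc_weight":  [0.9820, 0.9826, 0.9829, 0.9822, 0.9820, 0.9772, 0.9099, 0.3724, 0.1060, 0.0974],
    "mnist_acc_unit":    [0.9820, 0.9820, 0.9809, 0.9751, 0.9571, 0.6972, 0.2585, 0.0974, 0.0974, 0.0974],
    "fmnist_acc_weight": [0.8885, 0.8918, 0.8897, 0.8878, 0.8809, 0.8676, 0.7413, 0.4628, 0.2622, 0.1123],
    "fmnist_acc_unit":   [0.8885, 0.8865, 0.8576, 0.8720, 0.7830, 0.4568, 0.1964, 0.1005, 0.0997, 0.1000],
})


def max_safe_sparsity(df, column, tolerance=0.01):
    """Largest k reached before accuracy first falls more than `tolerance` below k = 0."""
    df = df.sort_values("k_sparsity")
    baseline = df[column].iloc[0]
    within = (df[column] >= baseline - tolerance).astype(int)
    # cumprod becomes 0 at the first failing k and stays 0 afterwards
    still_safe = within.cumprod().astype(bool)
    return float(df.loc[still_safe, "k_sparsity"].max())


for col in acc.columns[1:]:
    print(col, max_safe_sparsity(acc, col))
```

With this tolerance, weight pruning survives up to $k=80\%$ on MNIST and $70\%$ on FMNIST, while unit pruning stops at $60\%$ and $25\%$ respectively. Stopping at the first failure (rather than taking any passing $k$) matters for noisy curves such as FMNIST unit pruning, where accuracy at $k=60\%$ is higher than at $k=50\%$.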
Visualizing sparsity
Before we visualize the performance of our newly sparse models, let's check to make sure the neurons were properly pruned. Going index-by-index through $784 \times 1000$ weight matrices is obviously going to be time-consuming, but there's a better way. Since we're still working with 2-dimensional weight matrices and 1-dimensional bias arrays, we can color-code the values of the matrices. Values that are exactly 0.0 (not merely close, like 1E-16) can be colored in a way that breaks with the colormap scheme of the rest of the weights (in this case, we can just color them white).
Let's set up our function to take in an arbitrary Keras model and visualize the weights:
```python
import matplotlib
import matplotlib.pyplot as plt

def visualize_model_weights(sparse_model):
    weights = sparse_model.get_weights()
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]

    # Work on a copy so set_under() doesn't alter the shared colormap
    # (matplotlib.colormaps requires matplotlib >= 3.5)
    my_cmap = matplotlib.colormaps['rainbow'].copy()
    # Values below vmin (i.e. exact zeros) are drawn in white
    my_cmap.set_under('w')

    # Iterate over all the weight matrices in the model and visualize them
    for i in range(len(weights)):
        weight_matrix = weights[i]
        layer_name = names[i]
        if weight_matrix.ndim == 1:  # bias vectors
            # Turn the 1D array into a single-row 2D array for imshow
            weight_matrix = np.resize(weight_matrix,
                                      (1, weight_matrix.size))
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       aspect='auto',
                       cmap=my_cmap,
                       vmin=1e-26)  # lower bound is set close to but not at 0
            plt.colorbar()
            plt.title(layer_name)
            plt.show()
        else:  # all other 2D matrices
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       cmap=my_cmap,
                       vmin=1e-26)
            plt.colorbar()
            plt.title(layer_name)
            plt.show()
```
If you're running this code in the Google Colab version of this post, you can use the interactive form to see how the weights change. In Google Colab, the code cell itself becomes an interactive form, and you can select the sparse model you want to retrieve. All weights with values of exactly 0.0 will be color-coded white. 1D bias layers will be auto-scaled to the dimensions of the 2D plots.
It also helps that the Colab saves the weight files of all the sparse models we trained earlier.
So how does our sparse model look?
```python
from tensorflow.keras.models import load_model

dataset = 'mnist'
sparsity = "0.5"
pruning = 'weight'
sparse_model = load_model('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)
```
As we can see, the weight matrices become increasingly white as more of the smallest individual weights are set to 0.0. What about unit pruning?
```python
dataset = 'mnist'
sparsity = "0.5"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)
```
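If you'd rather confirm the pruning numerically than by eye, a small helper can count exact zeros in each weight array and, for 2D kernels, the number of columns that are entirely zero (dead units). The sketch below is our own addition (`sparsity_report` is a hypothetical name, not from the notebook) and is demonstrated on a toy kernel and bias rather than the saved models.

```python
import numpy as np


def sparsity_report(weight_list):
    """Per-array zero fraction and all-zero column count, plus the overall zero fraction."""
    rows = []
    total_zeros = total_size = 0
    for i, w in enumerate(weight_list):
        zeros = int(np.count_nonzero(w == 0))
        # Only 2D kernels have columns (one column per output unit)
        dead_cols = int(np.sum(np.all(w == 0, axis=0))) if w.ndim == 2 else None
        rows.append((i, w.shape, zeros / w.size, dead_cols))
        total_zeros += zeros
        total_size += w.size
    return rows, total_zeros / total_size


# Toy 4x3 kernel whose middle unit was pruned, plus its bias
kernel = np.array([[0.5, 0.0, -1.0],
                   [0.2, 0.0,  0.3],
                   [0.0, 0.0,  0.7],
                   [0.9, 0.0, -0.4]])
bias = np.array([0.1, 0.0, -0.2])

rows, overall = sparsity_report([kernel, bias])
for row in rows:
    print(row)
print(overall)  # 0.4  (6 zeros out of 15 values)
```

On a real model you would call `sparsity_report(sparse_model.get_weights())`: unit-pruned kernels should show whole dead columns, while weight-pruned kernels should show scattered zeros. Because the final softmax layer is never pruned, the overall zero fraction lands slightly below $k$.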
Now that we’ve confirmed that the models are pruning exactly the way we want them to, let’s see how they perform. How much can we prune away from a model before performance starts to fail?
```python
import os

os.makedirs('images', exist_ok=True)

# Visualizing performance on MNIST
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
ax1.grid(False)
ax2 = ax1.twinx()
ax2.grid(False)
plt.title('Test Accuracy as a function of k% Sparsity\nfor 4-hidden-layer MLP trained on MNIST')

# Plot sparsity and accuracy as percentages so they match the axis labels
k_percent = sparsity_summary['k_sparsity'].values * 100
ax1.plot(k_percent,
         sparsity_summary['mnist_acc_weight'].values * 100,
         '#008fd5', linestyle=':', label='Weight-pruning Acc')
ax1.plot(k_percent,
         sparsity_summary['mnist_acc_unit'].values * 100,
         '#008fd5', linestyle='-', label='Unit-pruning Acc')
ax2.plot(k_percent,
         sparsity_summary['mnist_loss_weight'].values,
         '#fc4f30', linestyle=':', label='Weight-pruning Loss')
ax2.plot(k_percent,
         sparsity_summary['mnist_loss_unit'].values,
         '#fc4f30', linestyle='-', label='Unit-pruning Loss')

ax1.set_ylabel('Accuracy (%)', color='#008fd5')
ax2.set_ylabel('Loss (categorical cross-entropy)', color='#fc4f30')
ax1.set_xlabel('k% Sparsity')

# Place both legends below the axes
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), shadow=True, ncol=2)
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), shadow=True, ncol=2)
plt.savefig('images/MNIST_sparsity_comparisons.png', bbox_inches='tight')
```
The performance curves for unit and weight pruning differ quite a lot on MNIST. Pruning the weight matrices of the dense layers does not cause dramatic drops in accuracy or increases in loss until around $k=80$. Even then, accuracy does not noticeably decrease until $k=90$.
For unit pruning, accuracy begins to fall earlier, around $k=70$ (with loss beginning to increase around $k=60$). Even so, both methods can effectively remove at least half of the network's weights without any dramatic difference in test classification performance.
How does this work on FMNIST?
Even on FMNIST, which had a lower initial accuracy and higher initial loss, the same pattern emerges.
Again, pruning individual weights does not result in dramatic drops in accuracy or increases in loss until around $k=80$, and accuracy does not fall sharply until $k=90$.
For unit pruning, degradation starts much earlier on FMNIST than on MNIST: loss begins to rise around $k=50$, and accuracy drops sharply beyond $k=60$ (from $87.20\%$ at $k=60$ to $78.30\%$ at $k=70$).
Why are the models behaving this way?
The two methods use different strategies for finding the least useful weights. Weight pruning ranks individual weights within the weight matrices by their absolute values ($|w_{i,j}|$). Unit pruning ranks entire columns of the weight matrices by their L2 norms. This is, in part, the difference between fine-grained and coarse-grained pruning.
Both weight pruning and unit pruning are forms of saliency mapping. The pruning functions go through the weight matrices and identify the weights that would have minimal impact on the layer's output when multiplied with the input data fed into the layer. Because some weights are orders of magnitude smaller than the largest ones, removing them has little effect. What is surprising is how much of the weight matrices is made up of these low-saliency weights.
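To make "minimal impact" a little more precise for a single layer: with the $\text{ReLU}(xW)$ convention used above, output unit $j$ has pre-activation $z_j = \sum_i x_i w_{i,j}$. Zeroing a set $S$ of that unit's incoming weights changes it by exactly the removed terms, which the Cauchy-Schwarz inequality bounds:

$$
\Delta z_j = -\sum_{i \in S} x_i w_{i,j}, \qquad
|\Delta z_j| \le \Big(\sum_{i \in S} x_i^2\Big)^{1/2} \Big(\sum_{i \in S} w_{i,j}^2\Big)^{1/2}.
$$

Magnitude-based pruning keeps the second factor small, so each pre-activation changes little unless the inputs feeding the pruned weights are large. This is only a per-layer sketch; it says nothing about how such errors compound across layers.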
Why are we able to delete so much of the network without hurting performance?
This demonstration shows that the parameter count of a trained network can be reduced by up to $80\%$ (here, with weight pruning) with little loss in performance. Other researchers have demonstrated pruning techniques that can decrease parameter counts by over $90\%$. Why this works is still an open research question in machine learning.
This bears some similarity to the "Optimal Brain Damage" approach of Yann LeCun and colleagues, which prunes the weights whose removal is estimated to change the training objective the least. In addition to being based on practical observations of how differently individual weights contribute in neural networks, this idea resembles theories of how biological neural networks (like those in the neocortex) specialize.
[Frankle & Carbin, 2019] proposed the "Lottery Ticket hypothesis". This hypothesis states that dense, randomly-initialized, feed-forward networks contain subnetworks ("winning tickets") that, when trained in isolation, reach test accuracy comparable to the original network in a similar number of iterations. It is named for the fact that these subnetworks, at random initialization, just happened to get initial weights that make training particularly effective (i.e., they have won the "training lottery"). For MNIST, both weight pruning and unit pruning suggest that most of the $2,388,710$ trainable parameters are superfluous compared to the truly optimal subnetwork that captures the concise classification function.
In scaling the weights (or columns of weights) and removing the $k\%$ closest to $0$, we are trying to filter out all but the weights that make the input features maximally separable. [Zhang et al., 2018] compared this kind of separation to Grassmannian manifolds. Since there is no universal formula for finding provably optimal subspace packings, the Grassmannian model frames neural networks as approximators of those solutions. Because the learned packing is less than optimal, some of the subspaces of the weights (represented either as individual weights or as vectors of weights) can be removed without hurting classification performance. At some point, however, nothing is left to remove except the weights needed for the subspace packing that maximally separates the classes.
The latent variable model of neural network function suggests that neural networks work by learning the latent variables needed to capture the epistemological essence of a given class (David Ha’s blog post is a great example of this). As David Ha’s example shows, very high-resolution images of MNIST digits can be produced by a generator from a vector of just $32$ real numbers.
Whether the goal is framed as finding the minimum subnetwork (as in the “lottery ticket hypothesis”) or the optimal subspace packing (as in the Grassmannian model), there may be more room for improvement beyond weight pruning. Weight pruning looks at individual weight values in isolation. While the unit-pruning results suggest that looking at weights in combination (whole columns ranked by their L2-norm) does not produce better results, this is only one possible filter for choosing which weights to set to $0$. Uber AI tested multiple masking criteria for finding “lottery ticket” subnetworks.
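To make the idea of alternative masks concrete, here is a minimal NumPy sketch of a few scoring rules of the kind Uber AI compared, each using both the initial weights `w_init` and the final weights `w_final`: final magnitude (the criterion weight pruning uses here), the increase in magnitude during training, and the total distance a weight moved. The criterion names, the `make_mask` helper, and the toy weight values are illustrative assumptions, not Uber AI's code.

```python
import numpy as np

# Scoring rules for choosing which weights to keep (higher score = keep).
# Names are illustrative; each takes initial and final weights.
MASK_CRITERIA = {
    "large_final": lambda w_i, w_f: np.abs(w_f),
    "magnitude_increase": lambda w_i, w_f: np.abs(w_f) - np.abs(w_i),
    "movement": lambda w_i, w_f: np.abs(w_f - w_i),
}

def make_mask(w_init, w_final, k, criterion="large_final"):
    """Keep the (1 - k) fraction of weights with the highest score."""
    scores = MASK_CRITERIA[criterion](w_init, w_final)
    n_keep = int(round((1 - k) * scores.size))
    keep_idx = np.argsort(scores, axis=None)[::-1][:n_keep]  # highest scores first
    mask = np.zeros(scores.size, dtype=bool)
    mask[keep_idx] = True
    return mask.reshape(scores.shape)

w_init = np.array([0.9, -0.2, 0.1, 0.6])
w_final = np.array([1.0, -0.7, 0.4, -0.5])
for name in MASK_CRITERIA:
    print(name, make_mask(w_init, w_final, k=0.5, criterion=name).astype(int))
```

On these four toy weights each rule keeps a different pair, which is why the choice of criterion can change which subnetwork is found.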
It is also important to note that this pruning method only sets weight values below a certain threshold to $0$ and then tests the network immediately afterwards. The experiments in the code do not perform iterative pruning, i.e., successive rounds of pruning interleaved with retraining (which TF 2.0’s pruning API supports).
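Iterative pruning removes a fraction $p$ of the remaining weights in each round and retrains in between, so after $n$ rounds the sparsity is $1-(1-p)^n$. Below is a minimal NumPy sketch of that loop; `iterative_magnitude_prune` is a hypothetical helper, and its `retrain` argument is a placeholder callback standing in for a few epochs of masked training (it defaults to doing nothing, which keeps the example self-contained). It is not how TF 2.0's pruning API is implemented.

```python
import numpy as np

def iterative_magnitude_prune(weights, rate, rounds, retrain=None):
    """Prune `rate` of the remaining weights each round (sketch).

    `retrain` is a hypothetical callback (weights, mask) -> weights standing
    in for training with the mask held fixed; by default it is skipped.
    """
    mask = np.ones_like(weights, dtype=bool)
    for _ in range(rounds):
        alive = np.abs(weights[mask])          # only weights that survived so far
        n_prune = int(round(rate * alive.size))
        if n_prune == 0:
            break
        threshold = np.sort(alive)[n_prune - 1]
        mask &= np.abs(weights) > threshold    # drop the n_prune smallest survivors
        weights = weights * mask
        if retrain is not None:
            weights = retrain(weights, mask) * mask
    return weights, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100))
pruned_w, mask = iterative_magnitude_prune(w, rate=0.2, rounds=3)
print(f"sparsity after 3 rounds: {1 - mask.mean():.3f}")
print(f"closed form 1 - (1 - 0.2)**3: {1 - 0.8 ** 3:.3f}")
```

With $p = 0.2$ and three rounds on 10,000 random weights, the sketch ends at $1 - 0.8^3 = 48.8\%$ sparsity, and each round ranks only the weights that survived the previous one.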
Reducing the size and runtimes of our models
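One of the added benefits of setting weights to 0 is that we can actually reduce the overall size of the network. With unit/neuron pruning specifically, we can clear out entire columns and rows of our weight matrices. Let’s create a function that, given the list of weight matrices of the layers, compresses every sparse layer by removing the all-zero columns of pruned units (and the matching rows of the next layer's matrix). The last two entries in the list, the softmax layer's kernel and bias, only lose the rows that came from pruned units, since the output layer still needs one column per class.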
```python
import numpy as np

def compress_sparse_weigths(sparse_model_weight_list, unique_columns=None):
    """Strip the all-zero columns (pruned units) out of a unit-pruned model.

    `sparse_model_weight_list` is the flat [kernel, bias, kernel, bias, ...]
    list returned by `model.get_weights()`.
    """
    compressed_weight_list = []
    # Every layer except the final (softmax) one, whose kernel and bias are
    # the last two entries of the list
    for i in range(0, len(sparse_model_weight_list) - 2, 2):
        # For the first layer all input rows are kept. For later layers, keep
        # only the rows belonging to units that survived in the previous layer
        if unique_columns is None:
            kernel_weights = sparse_model_weight_list[i]
        else:
            kernel_weights = sparse_model_weight_list[i][unique_columns, :]
        bias_weights = sparse_model_weight_list[i + 1]

        # np.nonzero returns a tuple of two arrays: 0th is row indices, 1st is cols
        indices = np.nonzero(kernel_weights)
        columns_non_unique = indices[1]
        unique_columns = sorted(set(columns_non_unique))

        # Keep only the columns (units) with at least one nonzero weight
        kernel_weights = kernel_weights[:, unique_columns]
        bias_weights = bias_weights[unique_columns]

        # Adding the new kernel and bias weights to the new list
        compressed_weight_list.append(kernel_weights)
        compressed_weight_list.append(bias_weights)

    # Softmax layer: drop the rows fed by pruned units of the previous layer;
    # its output columns (one per class) and its bias are kept unchanged
    compressed_weight_list.append(sparse_model_weight_list[-2][unique_columns, :])
    compressed_weight_list.append(sparse_model_weight_list[-1])
    return compressed_weight_list
```
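Why is it safe to cut these rows and columns out? A unit whose kernel column is entirely zero computes $\text{ReLU}(b_j)$ for every input, so if its bias $b_j$ is also zero (or negative) it always outputs $0$, and the matching row of the next layer's kernel never contributes anything. The NumPy sketch below checks this on a tiny two-layer network; the sizes, the random weights, and the choice to zero the pruned units' biases are assumptions for illustration. If pruning leaves positive biases on pruned units, the compressed model's predictions may not match exactly, which is worth checking by comparing `sparse_model.predict` and `compressed_model.predict` on the test set.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def forward(x, params):
    """Dense ReLU layers followed by a linear output layer (softmax omitted)."""
    *hidden, (w_out, b_out) = params
    for w, b in hidden:
        x = relu(x @ w + b)
    return x @ w_out + b_out

rng = np.random.default_rng(1)
w1, b1 = rng.normal(size=(8, 6)), rng.normal(size=6)
w2, b2 = rng.normal(size=(6, 3)), rng.normal(size=3)

# Unit-prune hidden neurons 1 and 4: zero their kernel columns and (by assumption) their biases
pruned_units = [1, 4]
w1[:, pruned_units] = 0.0
b1[pruned_units] = 0.0

# Compress: keep the surviving columns of w1 / entries of b1, and the matching rows of w2
keep = np.flatnonzero(np.any(w1 != 0, axis=0))
compressed = [(w1[:, keep], b1[keep]), (w2[keep, :], b2)]

x = rng.normal(size=(5, 8))
sparse_out = forward(x, [(w1, b1), (w2, b2)])
dense_out = forward(x, compressed)
print(keep, np.allclose(sparse_out, dense_out))
```

Dropping the zero columns of `w1` and the matching rows of `w2` mirrors what `compress_sparse_weigths` does layer by layer, and the two forward passes agree.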
Now, let’s load the newly compressed list of weights into a Keras model, compile it, and see how it compares to the original. Let’s try this out on the $95\%$ sparse model. As before, the Google Colab version of this post has an interactive version of the code below:
```python
import tensorflow as tf
from tensorflow.keras.models import load_model

# build_model_arch, visualize_model_weights, input_shape and num_classes are
# defined elsewhere in the notebook; compress_sparse_weigths is defined above
dataset = 'mnist'
sparsity = "0.95"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}model_k{}_{}pruned.h5'.format(dataset, sparsity, pruning))

# List of weights from the loaded file
sparse_weight_list = sparse_model.get_weights()

# Creating a list of dense weights from the sparse weight list
compressed_weight_list = compress_sparse_weigths(sparse_weight_list)

# Building a network with 4 hidden layers whose widths match those of the new
# dense weight matrices (e.g., at 50% sparsity, a layer of 1000 units is
# reduced to 500)
compressed_model = build_model_arch(input_shape, num_classes, sparsity=float(sparsity))
compressed_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer='adam', metrics=['accuracy'])
compressed_model.set_weights(compressed_weight_list)

# Printing the model summary and comparing the weights of the original sparse
# model and the dense model
print(compressed_model.summary())
print('\nVISUALIZING ORIGINAL SPARSE MODEL')
visualize_model_weights(sparse_model)
print('\nVISUALIZING COMPRESSED MODEL')
visualize_model_weights(compressed_model)
```
```
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_10 (Dense)             (None, 50)                39250
_________________________________________________________________
dense_11 (Dense)             (None, 50)                2550
_________________________________________________________________
dense_12 (Dense)             (None, 25)                1275
_________________________________________________________________
dense_13 (Dense)             (None, 10)                260
_________________________________________________________________
dense_14 (Dense)             (None, 10)                110
=================================================================
Total params: 43,445
Trainable params: 43,445
Non-trainable params: 0
_________________________________________________________________
None

VISUALIZING ORIGINAL SPARSE MODEL
```
Our $95\%$ sparse model contains only $43,445$ parameters versus the $2,388,710$ of the original model. Granted, it is one of the more poorly performing models, but this demonstrates how much we can reduce the memory cost of an already-trained model.
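The parameter counts above follow directly from the layer widths: a dense layer with $n_{in}$ inputs and $n_{out}$ units has $n_{in} n_{out} + n_{out}$ parameters. The hidden widths in the sketch below (1000, 1000, 500, 200 on a 784-pixel input with 10 classes) are an assumption, since this section does not list them, but they reproduce both the $2,388,710$ figure and the summary printed above when each hidden layer keeps $5\%$ of its units. The helper names are hypothetical.

```python
def dense_param_count(layer_sizes):
    """Total weights + biases for a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

def unit_pruned_sizes(hidden_sizes, sparsity):
    """Hidden-layer widths left after unit pruning removes `sparsity` of each layer's units."""
    return [round(n * (1 - sparsity)) for n in hidden_sizes]

# Assumed architecture: 784 inputs, hidden widths 1000/1000/500/200, 10 classes
hidden = [1000, 1000, 500, 200]
original = dense_param_count([784, *hidden, 10])
compressed = dense_param_count([784, *unit_pruned_sizes(hidden, 0.95), 10])
print(original, compressed, f"{compressed / original:.2%}")
```

The compressed network keeps roughly 1.8% of the original parameters, about 55 times fewer.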
If you’d like to learn more about neural network pruning, you can look at some of the resources below.
References/Resources/Further Reading
“Pruning deep neural networks to make them fast and small.” Jacob’s Computer Vision and Machine Learning blog.
Molchanov, Pavlo, et al. “Pruning convolutional neural networks for resource efficient inference.” arXiv preprint arXiv:1611.06440 (2016).
Li, Hao, et al. “Pruning filters for efficient convnets.” arXiv preprint arXiv:1608.08710 (2016).
“Saliency map.” Wikipedia. https://en.wikipedia.org/wiki/Saliency_map
LeCun, Yann, John S. Denker, and Sara A. Solla. “Optimal brain damage.” Advances in Neural Information Processing Systems 2 (1990). http://yann.lecun.com/exdb/publis/pdf/lecun90b.pdf
Frankle, Jonathan, and Michael Carbin. “The lottery ticket hypothesis: Finding sparse, trainable neural networks.” arXiv preprint arXiv:1803.03635 (2018).
Zhang, Jiayao, Guangxu Zhu, and Robert W. Heath Jr. “Grassmannian Learning: Embedding Geometry Awareness in Shallow and Deep Learning.” arXiv preprint arXiv:1808.02229 (2018).
Ha, David. “Generating Large Images From Latent Vectors.” David Ha’s blog.