Optimal Brain Damage
Cutting away 80% of a neural network’s weights with little impact on accuracy
| UPDATED
This project was a result of work I recently did with FOR.ai, a multi-disciplinary team of scientists and engineers who like doing machine learning research for fun. You can read more about their projects on their blog, or follow them on Twitter
There are multiple ways of optimizing neural-network-based machine learning algorithms. One of these optimizations is removing connections between neurons and layers, which speeds up computation by reducing the overall number of parameters.
Networks generally look like the one on the left: every neuron in the layer below has a connection to the layer above; but this means that we have to multiply a lot of floats together. Ideally, we’d only connect each neuron to a few others and save on doing some of the multiplications; this is called a “sparse” network.
Given a layer of a neural network, $\mathrm{ReLU}(xW)$, there are two well-known ways to prune it:
- Weight pruning: set individual weights in the weight matrix to zero. This corresponds to deleting connections as in the figure above.
  - Here, to achieve a sparsity of $k\%$, we rank the individual weights in the weight matrix $W$ according to their magnitude (absolute value) $|w_{i,j}|$, and then set the smallest $k\%$ to zero.
- Unit/Neuron pruning: set entire columns of the weight matrix to zero, in effect deleting the corresponding output neuron.
  - Here, to achieve a sparsity of $k\%$, we rank the columns of the weight matrix according to their L2-norm $\|w\|_2 = \sqrt{\sum_i w_i^2}$ and delete the smallest $k\%$.
Naturally, as you increase the sparsity and delete more of the network, the task performance will progressively degrade. Here, we will implement both weight and unit pruning and compare their performance on the MNIST and FMNIST datasets.
If you want to follow along, you can find the full notebook here
Training our model (first without pruning)
First we’re going to set up our dataset. For the sake of running experiments on multiple datasets in parallel, we create a dataset-loading function.
import tensorflow as tf

def load_dataset(dataset='mnist'):
    # input image dimensions (equal for both MNIST and FMNIST)
    img_rows, img_cols = 28, 28

    if dataset == 'mnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    elif dataset == 'fmnist':
        # Number of classes in the data
        num_classes = 10
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    else:
        # Fail loudly instead of continuing with undefined arrays
        raise ValueError('dataset name does not match available options ( mnist | fmnist )')

    # Flatten each 28x28 image into a 784-element vector for the dense layers
    x_train = x_train.reshape(x_train.shape[0], img_rows*img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows*img_cols)
    input_shape = (img_rows*img_cols*1,)

    # Scale pixel values from [0, 255] to [0, 1]
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return x_train, x_test, y_train, y_test, num_classes, input_shape

# We will load separate tensors for the MNIST and FMNIST data.
mnist_x_train, mnist_x_test, mnist_y_train, mnist_y_test, num_classes, input_shape = load_dataset(dataset='mnist')
fmnist_x_train, fmnist_x_test, fmnist_y_train, fmnist_y_test, num_classes, input_shape = load_dataset(dataset='fmnist')
We will construct a ReLU-activated neural network with four hidden layers. These will be dense, fully-connected layers with 1000, 1000, 500, and 200 units. We’ll also have a fifth, softmax output layer with 10 units, one per class. (Note: since this layer connects directly to the output, its weights will be spared from any and all pruning.)
For the sake of simplicity, we will also omit Dropout layers, Convolutional layers, Batch Normalization Layers, and Avg Pooling Layers.
l = tf.keras.layers

def build_model_arch(input_shape, num_classes, sparsity=0.0):
    # `sparsity` shrinks every hidden layer's width by that fraction
    model = tf.keras.Sequential()
    model.add(l.Dense(int(1000-(1000*sparsity)), activation='relu',
                      input_shape=input_shape))
    model.add(l.Dense(int(1000-(1000*sparsity)), activation='relu'))
    model.add(l.Dense(int(500-(500*sparsity)), activation='relu'))
    model.add(l.Dense(int(200-(200*sparsity)), activation='relu'))
    model.add(l.Dense(num_classes, activation='softmax'))
    return model

# The architectures are the same, but we are initializing 2 different sequential
# models. One is for MNIST, and one is for FMNIST
mnist_model_base = build_model_arch(input_shape, num_classes)
fmnist_model_base = build_model_arch(input_shape, num_classes)
We can also set up TensorBoard to monitor the training process.
import tempfile

logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)
Now that we have our dataset loader and logging set up, we can train the non-sparse model that we will prune post-training. Wrapping the compile-and-train steps in a function like this lets us reuse them for several models in experiments like this one (this is less a machine learning fact and more a coding practice I’d love to encourage).
def make_nosparse_model(model, x_train, y_train, batch_size,
                        epochs, x_test, y_test):
    # Log the training run to TensorBoard (logdir was created above)
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]

    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam', metrics=['accuracy'])

    model.fit(x_train, y_train, batch_size=batch_size,
              epochs=epochs, verbose=1,
              callbacks=callbacks,
              validation_data=(x_test, y_test))

    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    return model, score

batch_size = 128
epochs = 10

mnist_model, mnist_score = make_nosparse_model(mnist_model_base,
                                               mnist_x_train,
                                               mnist_y_train,
                                               batch_size,
                                               epochs,
                                               mnist_x_test,
                                               mnist_y_test)
print(mnist_model.summary())

fmnist_model, fmnist_score = make_nosparse_model(fmnist_model_base,
                                                 fmnist_x_train,
                                                 fmnist_y_train,
                                                 batch_size,
                                                 epochs,
                                                 fmnist_x_test,
                                                 fmnist_y_test)
print(fmnist_model.summary())
So how do our non-pruned models perform?
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 3s 46us/sample - loss: 0.2092 - acc: 0.9361 - val_loss: 0.1327 - val_acc: 0.9551
Epoch 2/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0857 - acc: 0.9740 - val_loss: 0.0882 - val_acc: 0.9738
Epoch 3/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0626 - acc: 0.9804 - val_loss: 0.0906 - val_acc: 0.9738
Epoch 4/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0443 - acc: 0.9862 - val_loss: 0.0633 - val_acc: 0.9814
Epoch 5/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0361 - acc: 0.9887 - val_loss: 0.0962 - val_acc: 0.9738
Epoch 6/10
60000/60000 [==============================] - 2s 28us/sample - loss: 0.0331 - acc: 0.9900 - val_loss: 0.0813 - val_acc: 0.9791
Epoch 7/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0270 - acc: 0.9920 - val_loss: 0.0864 - val_acc: 0.9783
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0226 - acc: 0.9924 - val_loss: 0.1024 - val_acc: 0.9751
Epoch 9/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0206 - acc: 0.9939 - val_loss: 0.0785 - val_acc: 0.9814
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.0192 - acc: 0.9942 - val_loss: 0.0762 - val_acc: 0.9820
Test loss: 0.07615731712152446
Test accuracy: 0.982
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 1000) 785000
_________________________________________________________________
dense_1 (Dense) (None, 1000) 1001000
_________________________________________________________________
dense_2 (Dense) (None, 500) 500500
_________________________________________________________________
dense_3 (Dense) (None, 200) 100200
_________________________________________________________________
dense_4 (Dense) (None, 10) 2010
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 2s 35us/sample - loss: 0.4828 - acc: 0.8260 - val_loss: 0.4383 - val_acc: 0.8391
Epoch 2/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3591 - acc: 0.8671 - val_loss: 0.3768 - val_acc: 0.8594
Epoch 3/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.3217 - acc: 0.8805 - val_loss: 0.3598 - val_acc: 0.8722
Epoch 4/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2959 - acc: 0.8906 - val_loss: 0.3294 - val_acc: 0.8814
Epoch 5/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2843 - acc: 0.8936 - val_loss: 0.3463 - val_acc: 0.8761
Epoch 6/10
60000/60000 [==============================] - 2s 31us/sample - loss: 0.2657 - acc: 0.9000 - val_loss: 0.3267 - val_acc: 0.8815
Epoch 7/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2528 - acc: 0.9049 - val_loss: 0.3491 - val_acc: 0.8763
Epoch 8/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2383 - acc: 0.9091 - val_loss: 0.3388 - val_acc: 0.8828
Epoch 9/10
60000/60000 [==============================] - 2s 29us/sample - loss: 0.2273 - acc: 0.9126 - val_loss: 0.3320 - val_acc: 0.8855
Epoch 10/10
60000/60000 [==============================] - 2s 30us/sample - loss: 0.2207 - acc: 0.9154 - val_loss: 0.3231 - val_acc: 0.8885
Test loss: 0.3230560186743736
Test accuracy: 0.8885
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_5 (Dense) (None, 1000) 785000
_________________________________________________________________
dense_6 (Dense) (None, 1000) 1001000
_________________________________________________________________
dense_7 (Dense) (None, 500) 500500
_________________________________________________________________
dense_8 (Dense) (None, 200) 100200
_________________________________________________________________
dense_9 (Dense) (None, 10) 2010
=================================================================
Total params: 2,388,710
Trainable params: 2,388,710
Non-trainable params: 0
_________________________________________________________________
None
So our models do pretty well. On MNIST we get 98.2% test accuracy with a cross-entropy loss of 0.076, and on FMNIST we get 88.85% accuracy with a loss of 0.323.
Not bad, but this is done with 2,388,710 trainable parameters, and we used that many parameters on what is effectively the “Hello world!” of machine learning. If we’re running an application that relies on multiple, repeated evaluations, this could get computationally expensive, and that is before considering what would happen if we wanted to train convolutional networks or expand to multi-channel classification tasks.
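The 2,388,710 figure follows directly from the layer widths: each Dense layer has (inputs × outputs) kernel weights plus one bias per output. The sketch below (the helper names `dense_param_count` and `hidden_widths` are hypothetical, not from the notebook) reproduces the count from the model summary, and also shows what a physically narrower network would cost, using the same width formula as the `sparsity` argument of `build_model_arch`:

```python
def dense_param_count(input_dim, widths):
    """Parameters in a stack of Dense layers: kernel (in * out) plus bias (out) per layer."""
    total = 0
    fan_in = input_dim
    for width in widths:
        total += fan_in * width + width  # kernel + bias
        fan_in = width
    return total


def hidden_widths(sparsity=0.0):
    # Mirrors the width formula used in build_model_arch: int(n - n * sparsity)
    return [int(n - n * sparsity) for n in (1000, 1000, 500, 200)]


NUM_CLASSES = 10
INPUT_DIM = 28 * 28  # flattened MNIST/FMNIST image

full = dense_param_count(INPUT_DIM, hidden_widths(0.0) + [NUM_CLASSES])
half = dense_param_count(INPUT_DIM, hidden_widths(0.5) + [NUM_CLASSES])
print(f"{full:,}")  # 2,388,710
print(f"{half:,}")  # 794,360
```

Halving every hidden layer cuts the parameter count to roughly a third, because the cost of a Dense layer scales with the product of its input and output widths.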
Pruning our Keras Model
How many parameters are actually needed at minimum to solve this task? We can’t quite get a provable answer for a dataset like this, but we can do the next best thing: set subsets of the weights to 0.0 and see how that impacts performance (aka ‘pruning’). We will use the two pruning methods described above: weight pruning, which zeroes the individual weights with the smallest magnitudes, and unit/neuron pruning, which zeroes the columns (output neurons) with the smallest L2-norms.
We will prune $k\%$ of the weights using both weight and unit pruning for $k \in \{0, 25, 50, 60, 70, 80, 90, 95, 97, 99\}$. Neither pruning method will touch the kernel and bias of the final softmax layer.
First, we define our function for weight-pruning, which takes in matrices of kernel and bias weights (for a dense layer) and returns the weight-pruned versions of each:
import numpy as np

def weight_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get the indices that sort them by |w|
    kernel_weights = np.copy(k_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(kernel_weights),
            axis=None),
        kernel_weights.shape)

    # Number of indices to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indices in the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff], ind[1][0:cutoff])
    kernel_weights[sparse_cutoff_inds] = 0.

    # Copy the bias weights and get the indices that sort them by |b|
    bias_weights = np.copy(b_weights)
    ind = np.unravel_index(
        np.argsort(
            np.abs(bias_weights),
            axis=None),
        bias_weights.shape)

    # Number of indices to set to 0
    cutoff = int(len(ind[0])*k_sparsity)
    # The indices in the 1D bias vector to set to 0
    sparse_cutoff_inds = (ind[0][0:cutoff])
    bias_weights[sparse_cutoff_inds] = 0.

    return kernel_weights, bias_weights
We can also set up a similar function for unit/neuron-pruning. It takes in the kernel and bias weights (again, for a dense layer) and returns the unit/neuron-pruned versions of each:
import numpy as np
from numpy import linalg as LA

def unit_prune_dense_layer(k_weights, b_weights, k_sparsity):
    # Copy the kernel weights and get the indices that sort the
    # columns by their L2 norm
    kernel_weights = np.copy(k_weights)
    ind = np.argsort(LA.norm(kernel_weights, axis=0))

    # Number of columns to set to 0
    cutoff = int(len(ind)*k_sparsity)
    # The columns of the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = ind[0:cutoff]
    kernel_weights[:, sparse_cutoff_inds] = 0.

    # Copy the bias weights; the entries to set to 0 are the same
    # indices as the kernel columns that were removed
    bias_weights = np.copy(b_weights)
    bias_weights[sparse_cutoff_inds] = 0.

    return kernel_weights, bias_weights
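Unit pruning is what makes physical compression possible. Zeroing column j of a layer’s kernel together with bias entry j means that neuron’s pre-activation is 0 for every input, so its ReLU output is also 0 and it contributes nothing to the next layer. A minimal NumPy check with small random matrices (all shapes and the choice of pruned neurons are arbitrary, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical two-layer slice: x -> ReLU(x @ W1 + b1) -> h @ W2 + b2
x = rng.normal(size=(4, 6))
W1, b1 = rng.normal(size=(6, 5)), rng.normal(size=5)
W2, b2 = rng.normal(size=(5, 3)), rng.normal(size=3)

# Unit-prune two neurons of the first layer the same way unit_prune_dense_layer does:
# zero their kernel columns and their bias entries
pruned = np.array([1, 3])
W1p, b1p = W1.copy(), b1.copy()
W1p[:, pruned] = 0.0
b1p[pruned] = 0.0


def forward(x, W1, b1, W2, b2):
    h = np.maximum(x @ W1 + b1, 0.0)  # ReLU hidden layer
    return h, h @ W2 + b2


h, out_sparse = forward(x, W1p, b1p, W2, b2)
assert np.all(h[:, pruned] == 0.0)  # pruned neurons output exactly 0 for every input

# So they can be deleted: drop those columns of W1/b1 and the matching ROWS of W2
keep = np.setdiff1d(np.arange(W1.shape[1]), pruned)
_, out_small = forward(x, W1p[:, keep], b1p[keep], W2[keep, :], b2)
print(np.allclose(out_sparse, out_small))  # True
```

This relies on ReLU(0) = 0. With an activation such as the sigmoid, a zeroed neuron would output a constant 0.5 instead, so removing it would also require folding that constant into the next layer’s bias.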
Pruning across the entire model
The previous functions only prune individual layers. Let’s create a function that takes the entire model and applies our pruning algorithm of choice to every layer except the final softmax layer:
def sparsify_model(model, x_test, y_test, k_sparsity, pruning='weight'):
    # Copying a temporary sparse model from our original
    sparse_model = tf.keras.models.clone_model(model)
    sparse_model.set_weights(model.get_weights())

    # Getting a list of the names of each component (w + b) of each layer
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]
    # Getting the list of the weights for each component (w + b) of each layer
    weights = sparse_model.get_weights()

    # Initializing list that will contain the new sparse weights
    newWeightList = []

    # Iterate over the (kernel, bias) pairs of every layer except the last
    # (softmax) layer, whose kernel and bias are the final 2 arrays
    for i in range(0, len(weights)-2, 2):
        if pruning == 'weight':
            kernel_weights, bias_weights = weight_prune_dense_layer(weights[i],
                                                                    weights[i+1],
                                                                    k_sparsity)
        elif pruning == 'unit':
            kernel_weights, bias_weights = unit_prune_dense_layer(weights[i],
                                                                  weights[i+1],
                                                                  k_sparsity)
        else:
            raise ValueError('does not match available pruning methods ( weight | unit )')

        # Append the new weight list with our sparsified kernel weights
        newWeightList.append(kernel_weights)
        # Append the new weight list with our sparsified bias weights
        newWeightList.append(bias_weights)

    # Adding the unchanged kernel and bias of the final (softmax) layer
    for i in range(len(weights)-2, len(weights)):
        unmodified_weight = np.copy(weights[i])
        newWeightList.append(unmodified_weight)

    # Setting the weights of our model to the new ones
    sparse_model.set_weights(newWeightList)

    # Re-compiling the Keras model (necessary for using `evaluate()`)
    sparse_model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'])

    # Printing the associated loss & accuracy for the k% sparsity
    score = sparse_model.evaluate(x_test, y_test, verbose=0)
    print('k% weight sparsity: ', k_sparsity,
          '\tTest loss: {:07.5f}'.format(score[0]),
          '\tTest accuracy: {:05.2f} %%'.format(score[1]*100.))

    return sparse_model, score
Weight-and-unit pruning across all sparsities
Now for our full set of experiments:
import os

# list of sparsities
k_sparsities = [0.0, 0.25, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.97, 0.99]

# The empty lists where we will store our training results
mnist_model_loss_weight = []
mnist_model_accs_weight = []
mnist_model_loss_unit = []
mnist_model_accs_unit = []
fmnist_model_loss_weight = []
fmnist_model_accs_weight = []
fmnist_model_loss_unit = []
fmnist_model_accs_unit = []

# sparse_model.save() needs the output folder to exist
os.makedirs('models', exist_ok=True)

dataset = 'mnist'
pruning = 'weight'
print('\n MNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    mnist_model_loss_weight.append(score[0])
    mnist_model_accs_weight.append(score[1])

    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

pruning = 'unit'
print('\n MNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(mnist_model, x_test=mnist_x_test,
                                         y_test=mnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    mnist_model_loss_unit.append(score[0])
    mnist_model_accs_unit.append(score[1])

    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

dataset = 'fmnist'
pruning = 'weight'
print('\n FMNIST Weight-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    fmnist_model_loss_weight.append(score[0])
    fmnist_model_accs_weight.append(score[1])

    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model

pruning = 'unit'
print('\n FMNIST Unit-pruning\n')
for k_sparsity in k_sparsities:
    sparse_model, score = sparsify_model(fmnist_model, x_test=fmnist_x_test,
                                         y_test=fmnist_y_test,
                                         k_sparsity=k_sparsity,
                                         pruning=pruning)
    fmnist_model_loss_unit.append(score[0])
    fmnist_model_accs_unit.append(score[1])

    # Save entire model to an H5 file
    sparse_model.save('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, k_sparsity, pruning))
    del sparse_model
MNIST Weight-pruning
k% weight sparsity: 0.0 Test loss: 0.07616 Test accuracy: 98.20 %%
k% weight sparsity: 0.25 Test loss: 0.07391 Test accuracy: 98.26 %%
k% weight sparsity: 0.5 Test loss: 0.06690 Test accuracy: 98.29 %%
k% weight sparsity: 0.6 Test loss: 0.06573 Test accuracy: 98.22 %%
k% weight sparsity: 0.7 Test loss: 0.07903 Test accuracy: 98.20 %%
k% weight sparsity: 0.8 Test loss: 0.19978 Test accuracy: 97.72 %%
k% weight sparsity: 0.9 Test loss: 1.22491 Test accuracy: 90.99 %%
k% weight sparsity: 0.95 Test loss: 2.09516 Test accuracy: 37.24 %%
k% weight sparsity: 0.97 Test loss: 2.26086 Test accuracy: 10.60 %%
k% weight sparsity: 0.99 Test loss: 2.30334 Test accuracy: 09.74 %%
MNIST Unit-pruning
k% weight sparsity: 0.0 Test loss: 0.07616 Test accuracy: 98.20 %%
k% weight sparsity: 0.25 Test loss: 0.07003 Test accuracy: 98.20 %%
k% weight sparsity: 0.5 Test loss: 0.10493 Test accuracy: 98.09 %%
k% weight sparsity: 0.6 Test loss: 0.30216 Test accuracy: 97.51 %%
k% weight sparsity: 0.7 Test loss: 0.85557 Test accuracy: 95.71 %%
k% weight sparsity: 0.8 Test loss: 1.83173 Test accuracy: 69.72 %%
k% weight sparsity: 0.9 Test loss: 2.25009 Test accuracy: 25.85 %%
k% weight sparsity: 0.95 Test loss: 2.30324 Test accuracy: 09.74 %%
k% weight sparsity: 0.97 Test loss: 2.30484 Test accuracy: 09.74 %%
k% weight sparsity: 0.99 Test loss: 2.30486 Test accuracy: 09.74 %%
FMNIST Weight-pruning
k% weight sparsity: 0.0 Test loss: 0.32306 Test accuracy: 88.85 %%
k% weight sparsity: 0.25 Test loss: 0.31822 Test accuracy: 89.18 %%
k% weight sparsity: 0.5 Test loss: 0.31417 Test accuracy: 88.97 %%
k% weight sparsity: 0.6 Test loss: 0.31821 Test accuracy: 88.78 %%
k% weight sparsity: 0.7 Test loss: 0.35468 Test accuracy: 88.09 %%
k% weight sparsity: 0.8 Test loss: 0.47173 Test accuracy: 86.76 %%
k% weight sparsity: 0.9 Test loss: 1.06918 Test accuracy: 74.13 %%
k% weight sparsity: 0.95 Test loss: 1.78654 Test accuracy: 46.28 %%
k% weight sparsity: 0.97 Test loss: 2.14777 Test accuracy: 26.22 %%
k% weight sparsity: 0.99 Test loss: 2.30555 Test accuracy: 11.23 %%
FMNIST Unit-pruning
k% weight sparsity: 0.0 Test loss: 0.32306 Test accuracy: 88.85 %%
k% weight sparsity: 0.25 Test loss: 0.32645 Test accuracy: 88.65 %%
k% weight sparsity: 0.5 Test loss: 0.41696 Test accuracy: 85.76 %%
k% weight sparsity: 0.6 Test loss: 0.56718 Test accuracy: 87.20 %%
k% weight sparsity: 0.7 Test loss: 1.01705 Test accuracy: 78.30 %%
k% weight sparsity: 0.8 Test loss: 1.60518 Test accuracy: 45.68 %%
k% weight sparsity: 0.9 Test loss: 2.23773 Test accuracy: 19.64 %%
k% weight sparsity: 0.95 Test loss: 2.29817 Test accuracy: 10.05 %%
k% weight sparsity: 0.97 Test loss: 2.30521 Test accuracy: 09.97 %%
k% weight sparsity: 0.99 Test loss: 2.30528 Test accuracy: 10.00 %%
Nice. We can also turn this into a dataframe of results that we can use for plotting and visualization later:
import numpy as np
import pandas as pd

# Convert the lists to numpy arrays
k_sparsities = np.asarray(k_sparsities)
mnist_model_loss_weight = np.asarray(mnist_model_loss_weight)
mnist_model_accs_weight = np.asarray(mnist_model_accs_weight)
mnist_model_loss_unit = np.asarray(mnist_model_loss_unit)
mnist_model_accs_unit = np.asarray(mnist_model_accs_unit)
fmnist_model_loss_weight = np.asarray(fmnist_model_loss_weight)
fmnist_model_accs_weight = np.asarray(fmnist_model_accs_weight)
fmnist_model_loss_unit = np.asarray(fmnist_model_loss_unit)
fmnist_model_accs_unit = np.asarray(fmnist_model_accs_unit)
# Stack the arrays so they can be used in the DataFrame
sparsity_data = np.stack([k_sparsities,
mnist_model_loss_weight,
mnist_model_accs_weight,
mnist_model_loss_unit,
mnist_model_accs_unit,
fmnist_model_loss_weight,
fmnist_model_accs_weight,
fmnist_model_loss_unit,
fmnist_model_accs_unit])
# Defining the Pandas DataFrame
sparsity_summary = pd.DataFrame(data=sparsity_data.T, # values
columns=['k_sparsity', # Column names
'mnist_loss_weight',
'mnist_acc_weight',
'mnist_loss_unit',
'mnist_acc_unit',
'fmnist_loss_weight',
'fmnist_acc_weight',
'fmnist_loss_unit',
'fmnist_acc_unit'])
sparsity_summary.to_csv('sparsity_summary.csv')
sparsity_summary
| index | k_sparsity | mnist_loss_weight | mnist_acc_weight | mnist_loss_unit | mnist_acc_unit | fmnist_loss_weight | fmnist_acc_weight | fmnist_loss_unit | fmnist_acc_unit |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.00 | 0.076157 | 0.9820 | 0.076157 | 0.9820 | 0.323056 | 0.8885 | 0.323056 | 0.8885 |
| 1 | 0.25 | 0.073905 | 0.9826 | 0.070034 | 0.9820 | 0.318219 | 0.8918 | 0.326450 | 0.8865 |
| 2 | 0.50 | 0.066900 | 0.9829 | 0.104928 | 0.9809 | 0.314172 | 0.8897 | 0.416962 | 0.8576 |
| 3 | 0.60 | 0.065726 | 0.9822 | 0.302163 | 0.9751 | 0.318206 | 0.8878 | 0.567175 | 0.8720 |
| 4 | 0.70 | 0.079034 | 0.9820 | 0.855566 | 0.9571 | 0.354680 | 0.8809 | 1.017053 | 0.7830 |
| 5 | 0.80 | 0.199783 | 0.9772 | 1.831728 | 0.6972 | 0.471729 | 0.8676 | 1.605178 | 0.4568 |
| 6 | 0.90 | 1.224915 | 0.9099 | 2.250092 | 0.2585 | 1.069181 | 0.7413 | 2.237728 | 0.1964 |
| 7 | 0.95 | 2.095160 | 0.3724 | 2.303235 | 0.0974 | 1.786535 | 0.4628 | 2.298169 | 0.1005 |
| 8 | 0.97 | 2.260861 | 0.1060 | 2.304837 | 0.0974 | 2.147765 | 0.2622 | 2.305212 | 0.0997 |
| 9 | 0.99 | 2.303342 | 0.0974 | 2.304857 | 0.0974 | 2.305550 | 0.1123 | 2.305276 | 0.1000 |
Visualizing sparsity
Before we visualize the performance of our newly sparse models, let’s check to make sure the weights were properly pruned. Going index-by-index through the weight matrices would obviously be time-consuming, but there’s a better way. Since we’re still working with 2-dimensional weight matrices and 1-dimensional bias arrays, we can color-code the values of the matrices. For values that are exactly 0.0 (not merely close, like 1E-16), we can use a color that breaks with the colormap scheme of the rest of the weights (in this case, white).
Let’s set up our function to take in an arbitrary Keras model and visualize the weights:
import copy
import matplotlib.pyplot as plt
import numpy as np

def visualize_model_weights(sparse_model):
    weights = sparse_model.get_weights()
    names = [weight.name for layer in sparse_model.layers for weight in layer.weights]

    # Copy the colormap so set_under() doesn't modify the globally registered one;
    # every value below vmin (i.e., exactly 0.0 after np.abs) is drawn white
    my_cmap = copy.copy(plt.get_cmap('rainbow'))
    my_cmap.set_under('w')

    # Iterate over all the weight matrices in the model and visualize them
    for i in range(len(weights)):
        weight_matrix = weights[i]
        layer_name = names[i]
        if weight_matrix.ndim == 1:  # If bias vector
            weight_matrix = np.resize(weight_matrix,
                                      (1, weight_matrix.size))
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       aspect="auto",
                       cmap=my_cmap,
                       vmin=1e-26)  # lower bound is set close to but not at 0
            plt.colorbar()
            plt.title(layer_name)
            plt.show()
        else:  # all other 2D matrices
            plt.imshow(np.abs(weight_matrix),
                       interpolation='none',
                       cmap=my_cmap,
                       vmin=1e-26)
            plt.colorbar()
            plt.title(layer_name)
            plt.show()
If you’re running the Google Colab version of this notebook, you can use the interactive form to see how the weights change: the code cell becomes interactive, and you can select the sparse model you want to retrieve. All weights with values of exactly 0.0 will be colored white. 1D bias vectors will be stretched to fill the width of the 2D plots.
It also helps that the Colab saves the weight files of all the sparse models we created earlier.
So how does our sparse model look?
from tensorflow.keras.models import load_model

dataset = 'mnist'
sparsity = "0.5"
pruning = 'weight'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)
As we can see, the plots become more and more white as more of the smallest individual weights are set to 0.0. What about unit-pruning?
dataset = 'mnist'
sparsity = "0.5"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))
visualize_model_weights(sparse_model)
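A numeric check complements the plots: count how many entries in each array are exactly 0.0. A minimal sketch (the helper name `report_sparsity` is hypothetical, not part of the notebook):

```python
import numpy as np


def report_sparsity(weight_arrays, names=None):
    """Return (name, fraction of entries exactly equal to 0.0) for each array."""
    names = names or [f"array_{i}" for i in range(len(weight_arrays))]
    report = []
    for name, w in zip(names, weight_arrays):
        exact_zeros = np.count_nonzero(w == 0.0)  # exact comparison, not np.isclose
        report.append((name, exact_zeros / w.size))
    return report


# Toy example: 1e-16 is tiny but nonzero, so it must NOT be counted
kernel = np.array([[0.0, 1e-16], [0.5, 0.0]])
bias = np.array([0.0, 2.0])
for name, frac in report_sparsity([kernel, bias], ["kernel", "bias"]):
    print(name, frac)
# kernel 0.5
# bias 0.5
```

With a Keras model, the arrays would come from `sparse_model.get_weights()` and the labels from `[w.name for w in sparse_model.weights]`. For a model pruned at sparsity k, each pruned kernel should report roughly k (the cutoff is rounded down), while the final softmax kernel and bias, which `sparsify_model` leaves untouched, should report close to 0.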
Now that we’ve confirmed that the models are pruning exactly the way we want them to, let’s see how they perform. How much can we prune away from a model before performance starts to fail?
import os
import matplotlib.pyplot as plt

# Visualizing performance on MNIST
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
ax1.grid(False)  # hide grid lines so the two y-axes' grids don't overlap
ax2 = ax1.twinx()
ax2.grid(False)

plt.title('Test Accuracy as a function of k% Sparsity\nfor 4-hidden-layer MLP trained on MNIST')

ax1.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_acc_weight'].values,
         '#008fd5', linestyle=':', label='Weight-pruning Acc')
ax1.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_acc_unit'].values,
         '#008fd5', linestyle='-', label='Unit-pruning Acc')

ax2.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_loss_weight'].values,
         '#fc4f30', linestyle=':', label='Weight-pruning Loss')
ax2.plot(sparsity_summary['k_sparsity'].values,
         sparsity_summary['mnist_loss_unit'].values,
         '#fc4f30', linestyle='-', label='Unit-pruning Loss')

# Accuracy values are stored as fractions in [0, 1]
ax1.set_ylabel('Accuracy', color='#008fd5')
ax2.set_ylabel('Loss (categorical crossentropy)', color='#fc4f30')
ax1.set_xlabel('k% Sparsity')

ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), shadow=True, ncol=2)
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.25), shadow=True, ncol=2)

# savefig() needs the output folder to exist; 'tight' keeps the legends below the axes in frame
os.makedirs('images', exist_ok=True)
plt.savefig('images/MNIST_sparsity_comparisons.png', bbox_inches='tight')

The performance curves for unit and weight pruning differ quite a lot on MNIST. Pruning the weight matrices of the dense layers does not result in dramatic drops in accuracy or increases in loss until around $k = 80\%$ (test loss rises from 0.079 at $k = 70\%$ to 0.200). Even then, the accuracy does not begin to noticeably decrease until $k = 90\%$ (90.99%, down from 98.20%).
For unit-pruning, accuracy begins to fall earlier, between $k = 60\%$ and $k = 70\%$ (97.51% and 95.71% respectively), with loss beginning to increase around $k = 50\%$. Even so, both methods are able to effectively remove more than half of the network weights without any dramatic differences in test classification performance.
How does this work on FMNIST?

Even on FMNIST, which had a lower initial accuracy and higher initial loss, the same pattern emerges.
Again, pruning the weight matrices of the dense layers does not result in dramatic drops in accuracy or increases in loss until around $k = 80\%$ (86.76% accuracy, down from 88.85%). Even then, the accuracy does not begin to noticeably decrease until $k = 90\%$ (74.13%).
For unit-pruning, the degradation comes earlier on FMNIST than on MNIST. Accuracy begins to fall around $k = 50\%$ (85.76%), with loss beginning to increase at the same point (0.417, up from 0.326 at $k = 25\%$).
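Where exactly performance “starts to fail” depends on the tolerance you pick. As a sketch, assuming a hypothetical criterion of staying within 1 percentage point of the unpruned test accuracy, the accuracies from the summary table can be scanned directly (the function name `last_safe_sparsity` and the 0.01 tolerance are illustrative choices, not from the notebook):

```python
# Accuracy columns copied from the sparsity_summary table above
k = [0.00, 0.25, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.97, 0.99]
acc = {
    "mnist_acc_weight":  [0.9820, 0.9826, 0.9829, 0.9822, 0.9820, 0.9772, 0.9099, 0.3724, 0.1060, 0.0974],
    "mnist_acc_unit":    [0.9820, 0.9820, 0.9809, 0.9751, 0.9571, 0.6972, 0.2585, 0.0974, 0.0974, 0.0974],
    "fmnist_acc_weight": [0.8885, 0.8918, 0.8897, 0.8878, 0.8809, 0.8676, 0.7413, 0.4628, 0.2622, 0.1123],
    "fmnist_acc_unit":   [0.8885, 0.8865, 0.8576, 0.8720, 0.7830, 0.4568, 0.1964, 0.1005, 0.0997, 0.1000],
}


def last_safe_sparsity(ks, accs, tolerance=0.01):
    """Largest k reached before accuracy FIRST drops more than `tolerance` below the k=0 baseline."""
    baseline = accs[0]
    safe = ks[0]
    for k_i, a in zip(ks, accs):
        if a < baseline - tolerance:
            break  # stop at the first failure, even if accuracy recovers later
        safe = k_i
    return safe


for column, values in acc.items():
    print(column, last_safe_sparsity(k, values))
# mnist_acc_weight 0.8
# mnist_acc_unit 0.6
# fmnist_acc_weight 0.7
# fmnist_acc_unit 0.25
```

Under this tolerance, weight pruning holds up far longer than unit pruning on both datasets, which matches the qualitative story of the curves; a looser or tighter tolerance shifts the numbers. Note that FMNIST unit-pruning accuracy briefly recovers at $k = 60\%$, which a “first failure” rule deliberately ignores.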
Why are the models behaving this way?
The two methods use different strategies for finding the least useful weights. Weight-pruning ranks individual weights within the weight matrices by their absolute values ($|w_{i,j}|$). Unit-pruning ranks entire columns of the weight matrices by their L2 norms. This difference is in part a difference between fine-grained and coarse-grained pruning.
Both weight-pruning and unit-pruning are forms of saliency-based pruning. The pruning functions go through the weight matrices and identify the weights that would have minimal impact on the layer’s output when multiplied with the input data being fed into the layer. Given that some weights are orders of magnitude smaller than the largest ones, the impact of removing them is minimal. What is surprising is how much of the weight matrices is made up of these low-saliency weights.
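The fine- versus coarse-grained difference can be quantified. At the same sparsity, both methods zero the same number of entries, but weight pruning picks the smallest-magnitude entries anywhere in the matrix, so the squared magnitude it removes is the minimum possible for that many entries; unit pruning must take whole columns, which inevitably include some large weights. A sketch on a random Gaussian stand-in for a trained kernel (an assumption: trained weights are not exactly Gaussian, and the shape is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.normal(size=(64, 32))  # hypothetical stand-in for a trained kernel
k = 0.5


def removed_mass(W, mask):
    """Squared Frobenius norm of the weights that a boolean mask zeroes out."""
    return float(np.sum(W[mask] ** 2))


# Weight pruning: zero the k-fraction of individual entries with the smallest |w|
order = np.argsort(np.abs(W), axis=None)
weight_mask = np.zeros(W.size, dtype=bool)
weight_mask[order[: int(W.size * k)]] = True
weight_mask = weight_mask.reshape(W.shape)

# Unit pruning: zero the k-fraction of columns with the smallest L2 norm
col_order = np.argsort(np.linalg.norm(W, axis=0))
unit_mask = np.zeros(W.shape, dtype=bool)
unit_mask[:, col_order[: int(W.shape[1] * k)]] = True

assert weight_mask.sum() == unit_mask.sum()  # same number of zeros either way
total = float(np.sum(W ** 2))
print("weight pruning removes", removed_mass(W, weight_mask) / total)
print("unit pruning removes  ", removed_mass(W, unit_mask) / total)
```

For Gaussian weights at $k = 50\%$, weight pruning removes only about 7% of the total squared magnitude in expectation, while unit pruning removes a little under half, because column norms vary much less than individual weights do. That gap is one way to see why unit pruning degrades accuracy sooner.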
Why are we able to delete so much of the network without hurting performance?
This demonstration shows that the parameter counts of trained neural networks can be decreased by over 50% without hurting performance. Other researchers have demonstrated pruning techniques that reduce parameter counts even further. How far networks can be pruned, and why, is still an open research question in machine learning.
This bears some similarity to “Optimal Brain Damage”, the pruning method proposed by Yann LeCun and colleagues. In addition to being based on practical observations of how much individual weights matter in neural networks, this bears similarity to ideas of how biological neural networks (like those in the neocortex) specialize.
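For comparison, the original Optimal Brain Damage method (LeCun, Denker & Solla) does not rank parameters by magnitude alone. It approximates the increase in the training objective $E$ caused by deleting a parameter $u_k$ with a second-order Taylor expansion, assuming a diagonal Hessian and a network already trained to a local minimum (so the first-order term vanishes). The resulting saliency is:

$$
s_k = \frac{1}{2}\, h_{kk}\, u_k^2, \qquad h_{kk} = \frac{\partial^2 E}{\partial u_k^2}
$$

Parameters with the smallest $s_k$ are deleted first. Ranking by $|u_k|$, as the weight-pruning function in this post does, gives the same ordering only if every $h_{kk}$ is equal; in that sense, magnitude pruning is a cheaper, curvature-free approximation of OBD.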
[Frankle & Carbin, 2019] proposed the “Lottery Ticket Hypothesis”. This hypothesis states that dense, randomly-initialized, feed-forward networks contain subnetworks (“winning tickets”) that, when trained in isolation, reach test accuracy comparable to the original network in a similar number of iterations. It is named for the fact that these subnetworks just happened to receive random initial weights that make training particularly effective (i.e., they have won the “training lottery”). For the task of MNIST, both weight-pruning and unit-pruning suggest that many of the 2,388,710 trainable parameters are superfluous compared to the truly optimal subnetwork that captures the concise classification function.
In ranking the weights (or columns of weights) and removing those closest to 0, we are trying to filter out all but the weights that make the input features maximally separable. [Zhang et al., 2018] compared this kind of separation to subspace packing on Grassmannian manifolds. Given that there is no universal formula for finding provably optimal subspace packings, the Grassmannian model of neural networks frames them as approximators of such solutions. Since this is a less-than-optimal packing, some of the subspaces of the weights (e.g., those represented either as individual weights or as vectors of weights) can easily be removed without impacting classification performance. At some point, there are no more values to remove save for those satisfying the subspace packing needed for maximum separability of the classes.
The latent-variable view of neural networks suggests that they work by learning the latent variables needed to capture the essence of a given class (David Ha’s blog post is a great example of this). As David Ha’s example shows, very high-resolution images of MNIST digits can be produced by a generator from a small latent vector of real numbers.
When it comes to finding the minimal subnetwork (as framed by the “lottery-ticket hypothesis”) or the optimal subspace packing (as framed by the Grassmannian model), there may be more room for improvement beyond weight pruning. Weight pruning looks at individual weight values in isolation. Unit pruning, which looks at weights in combination, did not produce better results here, but ranking columns by L2-norm is only one possible criterion for choosing which weights to set to 0. Uber AI, for example, tested multiple masking criteria for finding “lottery-ticket” subnetworks.
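To make “one possible criterion” concrete: magnitude of the trained weights, which the weight-pruning experiments in this post use, is just one scoring rule. The sketch below computes a few alternatives of the kind compared in Uber AI’s lottery-ticket work, as I understand their definitions (keep large final weights, large initial weights, weights whose magnitude grew the most, or weights that moved the most). `MASK_CRITERIA` and `top_fraction_mask` are hypothetical names, and `w_final` is a stand-in for trained weights.

```python
import numpy as np

# Each criterion maps (initial weights, final weights) to a score; higher = keep.
MASK_CRITERIA = {
    "large_final": lambda w_i, w_f: np.abs(w_f),
    "large_init": lambda w_i, w_f: np.abs(w_i),
    "magnitude_increase": lambda w_i, w_f: np.abs(w_f) - np.abs(w_i),
    "movement": lambda w_i, w_f: np.abs(w_f - w_i),
}


def top_fraction_mask(scores, keep_fraction):
    """Boolean mask keeping the `keep_fraction` of entries with the highest score."""
    n_keep = int(round(keep_fraction * scores.size))
    mask = np.zeros(scores.size, dtype=bool)
    if n_keep > 0:  # guard: order[-0:] would select every entry
        order = np.argsort(scores, axis=None)
        mask[order[-n_keep:]] = True
    return mask.reshape(scores.shape)


rng = np.random.default_rng(1)
w_init = rng.normal(size=(20, 30))
w_final = w_init + 0.5 * rng.normal(size=w_init.shape)  # stand-in for training

masks = {name: top_fraction_mask(score(w_init, w_final), keep_fraction=0.2)
         for name, score in MASK_CRITERIA.items()}
for name, mask in masks.items():
    overlap = (mask & masks["large_final"]).sum() / mask.sum()
    print(f"{name:>20}: {overlap:.2f} overlap with large_final")
```

All criteria keep the same number of weights; they differ only in which ones, which is why the overlap with plain magnitude pruning is a useful first comparison.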

It is also important to note that this pruning method only sets weight values below a certain magnitude threshold to 0 and then tests the network immediately afterwards. The experiments shown in the code do not demonstrate iterative pruning, i.e. successive rounds of pruning interleaved with further training (which TF 2.0’s pruning API allows for).
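Iterative (gradual) pruning alternates training with pruning and raises the target sparsity each round, so the network can recover between cuts. Here is a minimal NumPy sketch, assuming a cubic ramp in sparsity similar in shape to the polynomial-decay schedule offered by TensorFlow’s pruning API. `train_fn` is a hypothetical placeholder for real fine-tuning, and all function names are illustrative.

```python
import numpy as np


def polynomial_sparsity(step, initial_sparsity, final_sparsity, begin_step, end_step, power=3):
    """Target sparsity ramping from initial to final; it rises fastest early on."""
    progress = min(max((step - begin_step) / (end_step - begin_step), 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


def prune_to_sparsity(w, sparsity):
    """Zero the smallest-|w| entries so that a `sparsity` fraction of w is zero."""
    n_prune = int(round(sparsity * w.size))
    flat = w.ravel().copy()
    flat[np.argsort(np.abs(flat))[:n_prune]] = 0.0
    return flat.reshape(w.shape)


def iterative_prune(w, train_fn, rounds, final_sparsity):
    """Alternate a training phase with progressively stronger magnitude pruning."""
    for r in range(1, rounds + 1):
        w = train_fn(w)  # hypothetical fine-tuning step supplied by the caller
        target = polynomial_sparsity(r, 0.0, final_sparsity, 0, rounds)
        w = prune_to_sparsity(w, target)
    return w


rng = np.random.default_rng(0)
w = rng.normal(size=(40, 25))
# Placeholder "training": a real setup would run a few epochs of SGD here
fake_train = lambda w: w + 0.01 * rng.normal(size=w.shape)
w_pruned = iterative_prune(w, fake_train, rounds=5, final_sparsity=0.9)
print(f"final sparsity: {np.mean(w_pruned == 0):.2f}")
```

Because the mask is recomputed from magnitudes after every training phase, a weight pruned early can in principle come back if training makes it large again; the one-shot pruning in this post never gets that chance.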
Reducing the size and runtimes of our models
One of the added benefits of setting weights to 0 is that we can actually reduce the overall size of the network. With unit/neuron pruning specifically, we can remove entire columns of a layer’s weight matrix, along with the matching rows of the next layer’s matrix. Let’s create a function that, given the list of weight and bias arrays of the layers, compresses every layer except the final softmax layer, whose kernel only loses the rows that correspond to removed neurons:
import numpy as np

def compress_sparse_weigths(sparse_model_weight_list, unique_columns=None):
    """Drop all-zero neuron columns from a [kernel, bias, kernel, bias, ...] list."""
    compressed_weight_list = []
    # Step over (kernel, bias) pairs, skipping the final softmax layer
    for i in range(0, len(sparse_model_weight_list) - 2, 2):
        # For the first layer, every input row is kept. For later layers,
        # only the rows of neurons that survived in the previous layer
        # are kept
        if unique_columns is None:
            kernel_weights = sparse_model_weight_list[i]
        else:
            kernel_weights = sparse_model_weight_list[i][unique_columns, :]
        bias_weights = sparse_model_weight_list[i + 1]
        # np.nonzero returns a tuple of two arrays: 0th is row indices, 1st is cols
        indices = np.nonzero(kernel_weights)
        columns_non_unique = indices[1]
        # Neurons whose kernel column is not entirely 0s
        unique_columns = sorted(set(columns_non_unique))
        kernel_weights = kernel_weights[:, unique_columns]
        bias_weights = bias_weights[unique_columns]
        # Adding the new kernel and bias weights to the new list
        compressed_weight_list.append(kernel_weights)
        compressed_weight_list.append(bias_weights)
    # The softmax layer keeps all of its outputs; its kernel only loses the
    # rows that correspond to neurons removed from the layer before it
    compressed_weight_list.append(sparse_model_weight_list[-2][unique_columns, :])
    compressed_weight_list.append(sparse_model_weight_list[-1])
    return compressed_weight_list
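Why is this compression safe? A neuron whose kernel column is all zeros has a pre-activation equal to its bias, so if that bias is also zero its ReLU output is always 0 and it contributes nothing downstream. The NumPy-only sketch below checks this on a random ReLU MLP, using a compact re-implementation of the same column-dropping idea. It assumes that unit pruning also zeroes the pruned neurons’ biases: a pruned unit with a positive bias would still emit the constant ReLU(b), and dropping it would change the outputs. `forward`, `unit_prune` and `compress` are hypothetical helpers, and the layer widths are arbitrary.

```python
import numpy as np


def forward(weights, x):
    """ReLU MLP forward pass for [W0, b0, W1, b1, ...]; the last layer returns logits."""
    n_layers = len(weights) // 2
    h = x
    for layer in range(n_layers):
        h = h @ weights[2 * layer] + weights[2 * layer + 1]
        if layer < n_layers - 1:
            h = np.maximum(h, 0.0)
    return h


def unit_prune(weights, sparsity):
    """Zero the lowest-L2-norm kernel columns, and their biases, in every hidden layer."""
    pruned = [w.copy() for w in weights]
    for i in range(0, len(pruned) - 2, 2):
        n_drop = int(sparsity * pruned[i].shape[1])
        drop = np.argsort(np.linalg.norm(pruned[i], axis=0))[:n_drop]
        pruned[i][:, drop] = 0.0
        pruned[i + 1][drop] = 0.0
    return pruned


def compress(weights):
    """Drop all-zero hidden columns and the matching rows of the next kernel."""
    out, keep = [], None
    for i in range(0, len(weights) - 2, 2):
        kernel = weights[i] if keep is None else weights[i][keep, :]
        keep = np.flatnonzero(np.any(kernel != 0, axis=0))
        out += [kernel[:, keep], weights[i + 1][keep]]
    out += [weights[-2][keep, :], weights[-1]]
    return out


rng = np.random.default_rng(0)
widths = [784, 100, 100, 50, 20, 10]  # arbitrary illustrative sizes
weights = []
for n_in, n_out in zip(widths[:-1], widths[1:]):
    weights += [rng.normal(size=(n_in, n_out)) / np.sqrt(n_in), 0.1 * rng.normal(size=n_out)]

pruned = unit_prune(weights, sparsity=0.9)
compact = compress(pruned)
x = rng.normal(size=(5, 784))
print("max |difference|:", np.abs(forward(pruned, x) - forward(compact, x)).max())
print("kernel shapes:", [w.shape for w in compact[::2]])
```

The outputs agree up to floating-point rounding, while every hidden layer keeps only 10% of its units.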
Now let’s load the compressed list of weights into a Keras model with matching (smaller) layer sizes, compile it, and see how it compares to the original. We’ll try this on the 95% unit-pruned MNIST model. Again, the Google Colab version of this post has an interactive version of the code below:
import tensorflow as tf
from tensorflow.keras.models import load_model

# build_model_arch, visualize_model_weights, input_shape and num_classes
# are defined earlier in the notebook
dataset = 'mnist'
sparsity = "0.95"
pruning = 'unit'
sparse_model = load_model('models/sparse_{}-model_k-{}_{}-pruned.h5'.format(dataset, sparsity, pruning))
# List of weights from the loaded file
sparse_weight_list = sparse_model.get_weights()
# Creating a list of dense weights from the sparse weight list
compressed_weight_list = compress_sparse_weigths(sparse_weight_list)
# Building a network with 4 hidden layers whose dimensions match those of
# the new dense weight matrices (e.g., at 50% sparsity, a layer size of 1000
# is reduced to 500)
compressed_model = build_model_arch(input_shape, num_classes, sparsity=float(sparsity))
compressed_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer='adam', metrics=['accuracy'])
compressed_model.set_weights(compressed_weight_list)
# Printing the model summary and comparing the weights of the original sparse
# model and the dense model
print(compressed_model.summary())
print('\nVISUALIZING ORIGINAL SPARSE MODEL')
visualize_model_weights(sparse_model)
print('\nVISUALIZING COMPRESSED MODEL')
visualize_model_weights(compressed_model)
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_10 (Dense)             (None, 50)                39250
_________________________________________________________________
dense_11 (Dense)             (None, 50)                2550
_________________________________________________________________
dense_12 (Dense)             (None, 25)                1275
_________________________________________________________________
dense_13 (Dense)             (None, 10)                260
_________________________________________________________________
dense_14 (Dense)             (None, 10)                110
=================================================================
Total params: 43,445
Trainable params: 43,445
Non-trainable params: 0
_________________________________________________________________
None
VISUALIZING ORIGINAL SPARSE MODEL
Our compressed model contains only 43,445 parameters, versus the 2,388,710 of the original model. Granted, this is one of the more poorly performing models, but it demonstrates how much we can reduce the memory cost of an already-trained model.
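These parameter counts can be checked by hand: a Dense layer with n_in inputs and n_out units has n_in·n_out + n_out parameters (kernel plus bias). The short Python sketch below reproduces the 43,445 from the summary above and the original model’s count, assuming the original hidden widths were 1000, 1000, 500 and 200 (inferred from the 95%-sparsity widths of 50, 50, 25 and 10 being 5% of them) and a flattened 784-pixel MNIST input. `dense_param_count` is a hypothetical helper.

```python
def dense_param_count(widths):
    """Kernel + bias parameters for a chain of Dense layers with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:]))


# Input, hidden and output widths. The original hidden widths are inferred
# from the compressed ones being 5% of the originals (an assumption).
original_widths = [784, 1000, 1000, 500, 200, 10]
compressed_widths = [784, 50, 50, 25, 10, 10]

original = dense_param_count(original_widths)
compressed = dense_param_count(compressed_widths)
print(f"original: {original:,}  compressed: {compressed:,}  ratio: {compressed / original:.1%}")
```

Note that the first layer dominates both counts, because its 784 input rows are never pruned.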
If you’d like to learn more about neural network pruning, you can look at some of the resources below.
References/Resources/Further Reading
- “Pruning deep neural networks to make them fast and small.” Jacob’s Computer Vision and Machine Learning blog.
- Molchanov, Pavlo, et al. “Pruning convolutional neural networks for resource efficient inference.” arXiv preprint arXiv:1611.06440 (2016).
- Li, Hao, et al. “Pruning filters for efficient convnets.” arXiv preprint arXiv:1608.08710 (2016).
- “Saliency map.” Wikipedia. https://en.wikipedia.org/wiki/Saliency_map
- LeCun, Yann, John S. Denker, and Sara A. Solla. “Optimal brain damage.” Advances in Neural Information Processing Systems 2 (1990). http://yann.lecun.com/exdb/publis/pdf/lecun-90b.pdf
- Frankle, Jonathan, and Michael Carbin. “The lottery ticket hypothesis: Finding sparse, trainable neural networks.” arXiv preprint arXiv:1803.03635 (2018).
- Zhang, Jiayao, Guangxu Zhu, and Robert W. Heath Jr. “Grassmannian Learning: Embedding Geometry Awareness in Shallow and Deep Learning.” arXiv preprint arXiv:1808.02229 (2018).
- Ha, David. “Generating Large Images from Latent Vectors.” David Ha’s blog.
Cited as:
@article{mcateer2019obd,
title = "Optimal Brain Damage",
author = "McAteer, Matthew",
journal = "matthewmcateer.me",
year = "2019",
url = "https://matthewmcateer.me/blog/optimal-brain-damage/"
}
If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.
See you in the next post 😄