Influence Functions from scratch
Finding which data instances stand out the most
Debugging machine learning models is still a very important problem. One of the issues is that black-box models obviously don’t return a full stack trace of which components are going wrong. Even the long stack traces of C++ would be preferable to the cloudiness of figuring out why a classifier isn’t working.
One of the most popular papers at ICML 2017 described using influence functions, a common tool in robust statistics, for making machine learning models more interpretable. This post is an attempt at explaining the paper further by re-implementing these influence functions from scratch.
Before going into this article, I recommend taking at least a cursory glance at the original paper. It may be intimidating at first, but at least reading it will make the clarification that comes next all the more mind-blowing.
What are Influence Functions?
Influence functions are a kind of example-based statistical tool. Given a data point, a model, and the rest of the data that went into creating the model, they answer the question: how much of a positive or negative impact did that data point have on the model’s current state?
An Ultra-simple example
First, let’s define a linear model. Normally we could just use a simple Keras model, but writing it out by hand shows the crucial components of the model that we need in order to use influence functions.
import matplotlib.pyplot as plt
import numpy as np
# The code in this post uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session, ...)
import tensorflow as tf
%config InlineBackend.figure_format = 'retina'
class LinearModel:
    def __init__(self, data, target):
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None
        self._gradients = None
        self._hessians = None
        self._params = None
    @property
    def params(self):
        if self._params is None:
            data_dim = int(self.data.get_shape()[1])
            target_dim = int(self.target.get_shape()[1])
            # we construct one variable for both weight and bias
            self._params = tf.get_variable(name='params', shape=[data_dim+target_dim])
        return self._params
    @property
    def prediction(self):
        if self._prediction is None:
            data_dim = int(self.data.get_shape()[1])
            W = tf.reshape(self.params[:-1], [data_dim,1])
            b = self.params[-1]
            self._prediction = tf.matmul(self.data, W) + b
        return self._prediction
    @property
    def error(self):
        if self._error is None:
            self._error = tf.losses.mean_squared_error(labels = self.target,
                                                       predictions = self.prediction)
        return self._error
    @property
    def optimize(self, lr = 0.1):
        if self._optimize is None:
            train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(self.error)
            self._optimize = train_op
        return self._optimize
    @property
    def gradients(self):
        if self._gradients is None:
            self._gradients = tf.gradients(self.error, self.params)
        return self._gradients
    @property
    def hessians(self):
        if self._hessians is None:
            self._hessians = tf.hessians(self.error, self.params)
        return self._hessians
As we can see, we have data, parameters, predictions, and errors. We also have gradients. This all seems like pretty standard stuff, but I want to draw your attention to the last method: the Hessians. For our influence evaluation to work, we need to be absolutely certain we can compute these second derivatives of the loss with respect to the parameters.
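For a single training point $(x, y)$, the loss `LinearModel` computes is $(wx + b - y)^2$, so its Hessian with respect to $(w, b)$ has the closed form $2\begin{pmatrix} x^2 & x \\ x & 1 \end{pmatrix}$, which is what `tf.hessians` should return for one point. One way to be sure of second derivatives is to check them against finite differences. This NumPy sketch (the helper names and parameter values are hypothetical) does that for one of the perturbed points used below:

```python
import numpy as np


def loss(params, x, y):
    """Squared error of the linear model w*x + b on a single point (x, y)."""
    w, b = params
    return (w * x + b - y) ** 2


def numerical_hessian(f, params, h=1e-2):
    """Central finite-difference Hessian of a scalar function f.
    For a quadratic f (as here) it is exact up to floating-point rounding."""
    params = np.asarray(params, dtype=float)
    k = params.size
    H = np.zeros((k, k))
    for a in range(k):
        for c in range(k):
            ea = np.zeros(k)
            ea[a] = h
            ec = np.zeros(k)
            ec[c] = h
            H[a, c] = (f(params + ea + ec) - f(params + ea - ec)
                       - f(params - ea + ec) + f(params - ea - ec)) / (4 * h * h)
    return H


x, y = -4.5, 100.0                  # one of the perturbed training points
params = np.array([-4.0, 12.0])     # hypothetical (w, b)
analytic = 2 * np.array([[x * x, x], [x, 1.0]])
numeric = numerical_hessian(lambda p: loss(p, x, y), params)
print(np.allclose(analytic, numeric, atol=1e-5))
```

This Hessian has eigenvalues $0$ and $2(x^2 + 1)$, so for $x$ near $\pm 5$ its largest eigenvalue is around 50. Iterative inverse-Hessian estimates such as LiSSA therefore need the Hessian divided by a scale factor larger than that to stay stable.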
Now that we have our model, let’s define our data. We’re going to use the simple linear function $y = -5x + 5$, plus a small amount of Gaussian noise (standard deviation 0.1).
However, our data isn’t going to fit the equation above to the letter. We’re going to add some artificial perturbations to it: we’ll set the targets at $x = -4.5$ and $x = 2.5$ to 100, far outside the range of the other targets (roughly -17.5 to 30).
def true_function(x, noise = True):
    y = -5*x+5
    if noise:
        y += np.random.normal(scale=0.1, size = x.shape)
    return y
X_data = np.arange(-5,5,0.5).reshape((-1,1))
Y_data = true_function(X_data) # linear function
#perturbation
Y_data[1] = 100.0
Y_data[-5] = 100.0
Y_data = Y_data.reshape((-1,1))
X_test = 3.2*np.ones((1,1))
Y_test = true_function(X_test, noise=True)
Let’s see how our basic linear model performs on this.
EPOCHS = 100
R = 200
# LiSSA scale factor: per-point Hessians here have eigenvalues up to about 52,
# so dividing them by SCALE keeps the recursion below from diverging
SCALE = 100.0
num_train_points = X_data.shape[0]
tf.reset_default_graph()
x = tf.placeholder(dtype=tf.float32, shape=(None, 1))
y_true = tf.placeholder(dtype=tf.float32, shape=(None, 1))
model = LinearModel(x, y_true)
train_op = model.optimize
loss_op = model.error
param_op = model.params
gradient_op = model.gradients
hessian_op = model.hessians
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # Full-batch gradient descent on the training data
    for e in range(EPOCHS):
        fd = {x: X_data, y_true: Y_data}
        _, loss_epoch = sess.run([train_op, loss_op], feed_dict = fd)
    p = sess.run(param_op)
    # LiSSA estimate of s_test = H^{-1} v, where v is the gradient of the test loss
    s_test = 0
    for r in range(R):
        v = sess.run(gradient_op, feed_dict = {x:X_test, y_true:Y_test})[0]
        s_test_j = v
        for j in range(num_train_points):
            fd = {x:X_data[j].reshape((-1,1)), y_true:Y_data[j].reshape((-1,1))}
            # Full 2x2 Hessian of the loss on training point j
            hess_param = sess.run(hessian_op, feed_dict = fd)[0]
            s_test_j = v + np.matmul(np.identity(2) - hess_param / SCALE, s_test_j)
        s_test += s_test_j
    # Average over the R repetitions and undo the scaling
    s_test = s_test / (R * SCALE)
    # Influence of each training point on the test loss: -s_test . grad L(z_j)
    importance = []
    for j in range(num_train_points):
        fd = {x:X_data[j].reshape((-1,1)), y_true:Y_data[j].reshape((-1,1))}
        grad_param = sess.run(gradient_op, feed_dict = fd)[0]
        importance.append(-np.matmul(s_test,grad_param))
importance = np.asarray(importance)
print('Loss: {}'.format(((p[0]*X_test+p[1]-Y_test)**2)[0]))
Loss: [75.19171343]
We’ve trained our model, and the test loss comes out at about 75, which is pretty bad. In this case we’d like to know where we went wrong with creating our linear model.
This is where the influence functions come in. We have our list of per-example influence scores (importance), and we can use it to mark points based on how heavily they influence our linear model.
cm = plt.get_cmap('RdYlBu')
sc = plt.scatter(X_data.flatten(), Y_data.flatten(), label='train',c=-importance, cmap=cm)
plt.scatter(X_test.flatten(), Y_test.flatten(),marker='+',label='test',c='r')
plt.plot(X_data, p[0]*X_data+p[1])
plt.colorbar(sc)
plt.legend()
plt.show()

The intercept of our model is heavily skewed by the presence of both of our points at 100. However, the effects of these two points are not equal. One of them (the red point) has by far the most negative influence of any of the training data, while the other (the blue mark) has the largest positive influence. Most of the other dots have very neutral colors. What can we interpret from this? The plot suggests that our regular data and the blue outlier had the biggest effects on pulling the linear model toward the correct underlying equation. The red one, by contrast, influenced our model to be skewed away from the true distribution.
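Because this model has only two parameters, the influence scores can also be computed exactly, with a direct 2×2 solve instead of an iterative estimate, as a sanity check on the picture above. This NumPy sketch is an independent re-implementation, not the post's TensorFlow code: it regenerates the data the same way (with a fixed seed of our choosing, so the noise differs from the run shown), fits the line by ordinary least squares instead of gradient descent, and uses the Hessian of the mean training loss. Positive scores mean that upweighting the point would increase the test loss.

```python
import numpy as np

rng = np.random.default_rng(42)
X_data = np.arange(-5, 5, 0.5).reshape((-1, 1))
Y_data = -5 * X_data + 5 + rng.normal(scale=0.1, size=X_data.shape)
Y_data[1] = 100.0    # perturbation at x = -4.5
Y_data[-5] = 100.0   # perturbation at x = 2.5
X_test = np.array([[3.2]])
Y_test = -5 * X_test + 5 + rng.normal(scale=0.1, size=X_test.shape)


def design(x):
    """Append a column of ones so the parameters are (w, b)."""
    return np.hstack([x, np.ones_like(x)])


A, A_test = design(X_data), design(X_test)
n = A.shape[0]
theta, *_ = np.linalg.lstsq(A, Y_data.ravel(), rcond=None)  # least-squares fit


def grad(a, y):
    """Gradient of the squared error (a @ theta - y) ** 2 with respect to theta."""
    return 2 * a * (a @ theta - y)


H = (2.0 / n) * A.T @ A                                     # Hessian of the mean squared error
s_test = np.linalg.solve(H, grad(A_test[0], Y_test[0, 0]))  # exact H^{-1} grad L(z_test)
importance = np.array([-s_test @ grad(A[j], Y_data[j, 0]) for j in range(n)])
test_loss = (A_test[0] @ theta - Y_test[0, 0]) ** 2

print("test loss:", test_loss)
print("most harmful:", importance.argmax(), "most helpful:", importance.argmin())
```

In this sketch the outlier at $x = 2.5$, close to the test point, comes out as the most harmful training point, while the outlier at $x = -4.5$ comes out as the most helpful, so the two perturbations do have very unequal effects.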
Being able to tell which data points are causing trouble is exactly the kind of debugging tool you’ve probably yearned for if you’ve worked in ML for any length of time. This is precisely what Pang Wei Koh and Percy Liang demonstrated in their paper.
An image-processing example
However, influence functions aren’t quite ready to be used on complicated machine learning models right out of the box. Influence functions have been around for decades, but there’s a reason it wasn’t until 2017 that a team gave serious consideration to using them for machine learning models: some of the requisite quantities, above all the inverse Hessian of the loss, are expensive to compute. Many machine learning models also use loss functions that are not differentiable everywhere.
# Load the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Gets number of 7 and 1 labels in MNIST
num_test = len(np.argwhere(y_test == 7)) + len(np.argwhere(y_test == 1))
# Preallocate zero-filled arrays for the flattened images and their labels
test = np.zeros((num_test, 784))
test_label = np.zeros((num_test, 1))
We then loop through the test set, keeping only the examples whose true label is 1 or 7 (7s get the label 1.0 and 1s get the label 0.0). Each image is also divided by its L2 norm, which puts every pixel value in the range [0, 1] instead of [0, 255]. (This is not relevant to the influence functions specifically; it’s just a recommended practice for image processing.)
count = 0
for i in range(len(x_test)):
    if y_test[i] == 7:
        test[count] += (x_test[i] / np.linalg.norm(x_test[i])).reshape(784)
        test_label[count] += np.array([1.0])
        count += 1
        if count == num_test:
            break
    elif y_test[i] == 1:
        test[count] += (x_test[i] / np.linalg.norm(x_test[i])).reshape(784)
        test_label[count] += np.array([0.0])
        count += 1
        if count == num_test:
            break
# We repeat this process on the training data
num_train = len(np.argwhere(y_train == 7)) + len(np.argwhere(y_train == 1))
data = np.zeros((num_train, 784))
data_label = np.zeros((num_train, 1))
count = 0
for i in range(len(x_train)):
    if y_train[i] == 7:
        data[count] += (x_train[i] / np.linalg.norm(x_train[i])).reshape(784)
        data_label[count] += np.array([1.0])
        count += 1
        if count == num_train:
            break
    elif y_train[i] == 1:
        data[count] += (x_train[i] / np.linalg.norm(x_train[i])).reshape(784)
        data_label[count] += np.array([0.0])
        count += 1
        if count == num_train:
            break
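The two loops above can also be written with NumPy boolean masks. This is a sketch, assuming the same MNIST arrays (uint8 images of shape (N, 28, 28)) and that no image is all zeros (the loops make the same assumption when they divide by the norm); the helper name `filter_and_normalize` is ours. Because the mask keeps examples in their original order, it should produce the same arrays as the loops.

```python
import numpy as np


def filter_and_normalize(images, labels, pos=7, neg=1):
    """Keep only digits `pos` and `neg`, flatten each to a 784-vector scaled to
    unit L2 norm, and encode `pos` as 1.0 and `neg` as 0.0."""
    mask = (labels == pos) | (labels == neg)
    flat = images[mask].reshape(-1, 784).astype(np.float64)
    flat /= np.linalg.norm(flat, axis=1, keepdims=True)
    targets = (labels[mask] == pos).astype(np.float64).reshape(-1, 1)
    return flat, targets


# Usage with the arrays loaded earlier:
# test, test_label = filter_and_normalize(x_test, y_test)
# data, data_label = filter_and_normalize(x_train, y_train)
```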
# We're going to set a test index. This is an index within the test data
# for the example whose prediction we want to explain
test_index = 157
plt.imshow(test[test_index].reshape(28, 28),cmap="inferno", interpolation="nearest")

We need to define a few important functions first.
First, we have our hinge loss function. For an intended output $t = \pm 1$ and a classifier score $y$, the hinge loss of the prediction $y$ is defined as
$$\ell(y) = \max(0, 1 - t \cdot y)$$
Hinge loss is commonly used for support vector machines (SVMs), though it isn’t exclusive to them. It is a convex function, so many of the usual convex optimizers used in machine learning can work with it.
def hinge_loss(logits, labels, dty=tf.float64):
    margin = tf.multiply(tf.cast(labels, dtype=dty), logits)
    losses = tf.maximum(tf.constant(0, dtype=dty), 1 - margin)
    return tf.reduce_mean(losses)
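However, our influence function methods only work on loss functions that are twice differentiable, since they need the Hessian. Unfortunately for us, hinge loss is not differentiable at the margin, where the label times the classifier score equals 1.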
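The authors point out a solution to this: we can use a very close, smooth approximation of hinge loss that is differentiable right at the margin,
$$\ell_\tau(y) = \tau \log\left(1 + \exp\left(\frac{1 - t \cdot y}{\tau}\right)\right),$$
where $t \cdot y$ is the margin (intended output times classifier score) and $\tau$ is a temperature, the parameter t in the code below. As $\tau \to 0$, this approaches the hinge loss.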
def smooth_hinge_loss(logits, labels, t=1e-3, dty=tf.float64):
    margin = tf.multiply(tf.cast(labels, dtype=dty), logits)
    exponents = (1 - margin) / t
    # Stable t * log(exp(exponents) + exp(0)) via the log-sum-exp (max-subtraction) trick
    max_elems = tf.maximum(exponents, tf.zeros_like(exponents))
    log_loss = t * (max_elems + tf.log(tf.exp(exponents - max_elems) + tf.exp(tf.zeros_like(exponents) - max_elems)))
    return tf.reduce_mean(log_loss)
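To see how close the approximation is, here is a small NumPy check (our own sketch, independent of the TensorFlow code) using the same formula as `smooth_hinge_loss`. With margin m = label × score, t·log(1 + exp((1 − m)/t)) is never below the hinge loss and exceeds it by at most t·log 2, with the largest gap exactly at the margin m = 1, so shrinking `t` tightens the approximation. `np.logaddexp` plays the role of the max-subtraction trick in the TensorFlow version.

```python
import numpy as np


def hinge(margin):
    """Hinge loss max(0, 1 - margin), where margin = label * score."""
    return np.maximum(0.0, 1.0 - margin)


def smooth_hinge(margin, t=1e-3):
    """t * log(1 + exp((1 - margin) / t)); np.logaddexp avoids overflow for small t."""
    return t * np.logaddexp(0.0, (1.0 - margin) / t)


margins = np.linspace(-2.0, 3.0, 10001)
for t in (1e-1, 1e-2, 1e-3):
    gap = smooth_hinge(margins, t) - hinge(margins)
    print(f"t={t:g}: largest gap {gap.max():.2e} (bound t*log 2 = {t * np.log(2):.2e})")
```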
Our accuracy operation tells us what fraction of our predictions are correctly labelled. Again, standard classifier stuff.
def get_accuracy_op(logits, labels, sigmoid=True, dty=tf.float64):
    if sigmoid:
        # logits are sigmoid outputs here, so threshold them at 0.5
        correct_prediction = tf.equal(tf.cast(logits > 0.5, tf.int32), tf.cast(labels, tf.int32))
        accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
        return accuracy / tf.shape(labels)[0]
    else:
        preds = tf.sign(logits)
        correct = tf.reduce_sum(tf.cast(tf.equal(preds, tf.cast(labels, dty)), tf.int32))
        return correct / tf.shape(labels)[0]
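Now we get to the math of the influence functions, starting with the Hessian-vector product (HVP). It computes $Hv$ without ever forming $H$, using the identity $Hv = \nabla_\theta (\nabla_\theta L \cdot v)$: backpropagate once to get the gradient, multiply it by $v$ (held constant), and backpropagate again.
The code below was borrowed from the Hessian-vector product code in TensorFlow. Strangely enough, this is not a function you would find in the TensorFlow documentation, because it is a private helper (_hessian_vector_product).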
def hessian_vector_product(ys, xs, v, do_not_sum_up=True):
    # Validate the input
    length = len(xs)
    if len(v) != length:
        raise ValueError("xs and v must have the same length.")
    # First backprop
    grads = tf.gradients(ys, xs)
    assert len(grads) == length
    # Multiply each gradient by v; stop_gradient keeps v constant during the second backprop
    elemwise_products = [
        tf.multiply(grad_elem, tf.stop_gradient(v_elem))
        for grad_elem, v_elem in zip(grads, v) if grad_elem is not None
    ]
    # Second backprop
    if do_not_sum_up:
        separate = []
        for i in range(length):
            separate.append(tf.gradients(elemwise_products[i], xs[i])[0])
        grads_with_none = separate
    else:
        grads_with_none = tf.gradients(elemwise_products, xs)
    return_grads = [
        grad_elem if grad_elem is not None else tf.zeros_like(x)
        for x, grad_elem in zip(xs, grads_with_none)
    ]
    return return_grads
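The same double-backprop idea can be written out by hand for the model used below (L2-regularized logistic regression), where the gradient has a closed form. This NumPy sketch is ours, not part of TensorFlow: it computes $Hv$ directly from the data, without building the $d \times d$ matrix, and checks it against the explicit Hessian. The synthetic data, the damping value, and the function names are hypothetical. As in the post's `cost`, the regularizer is damping·‖w‖²/2, though the sketch drops the 1e-6 inside the logs.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def hvp_logistic(w, X, v, damping):
    """H v for mean cross-entropy + damping * ||w||^2 / 2, without forming H.
    Uses H = X^T diag(p (1 - p)) X / n + damping * I, so H v costs O(n d)."""
    p = sigmoid(X @ w)
    return X.T @ (p * (1 - p) * (X @ v)) / X.shape[0] + damping * v


def hessian_logistic(w, X, damping):
    """Explicit d x d Hessian, only for checking the HVP on a small problem."""
    p = sigmoid(X @ w)
    return (X.T * (p * (1 - p))) @ X / X.shape[0] + damping * np.eye(X.shape[1])


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 20))       # synthetic inputs
w = rng.normal(scale=0.1, size=20)   # hypothetical weights
v = rng.normal(size=20)
damping = 1e-2
print(np.allclose(hvp_logistic(w, X, v, damping), hessian_logistic(w, X, damping) @ v))
```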
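Let’s set up our classifier. This time, instead of plain linear regression, we’re using an L2-regularized logistic regression model: a single weight vector followed by a sigmoid, trained with cross-entropy loss.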
tf.reset_default_graph()
nb_clases = 1
dty = tf.float64
scale = 1e0
damping = 1e-2
I = tf.eye(784, dtype=dty)
we = {}
we[0] = 784 * 1
w1 = tf.get_variable("w1", [we[0]], initializer=tf.initializers.truncated_normal, dtype=dty)
w1 = w1 / tf.norm(w1)
w1 = w1 / tf.constant(1e6, dtype=dty)
params = [w1]
Hess = tf.placeholder(dty, shape=[w1.get_shape()[0], w1.get_shape()[0]], name="inverse")
cur_in = tf.placeholder(dty, shape=[w1.get_shape()[0], w1.get_shape()[0]], name="inverse")
v_cur_est = [tf.placeholder(dty, shape=a.get_shape(), name="v_cur_est" + str(i)) for i, a in enumerate(params)]
hessian_vector_val_place = [tf.placeholder(dty, shape=a.get_shape()[0], name="hessian_vector_val_place" + str(i)) for i, a in enumerate(params)]
Test = [tf.placeholder(dty, shape=a.get_shape(), name="v_cur_est" + str(i)) for i, a in enumerate(params)]
X = tf.placeholder(dty, [None, 784], name="X")
Y = tf.placeholder(dty, [None, nb_clases], name="Y")
L = tf.matmul(X, tf.reshape(w1, [-1, 1]))
L = tf.nn.sigmoid(L)
Z = tf.placeholder(dty, [None, 784], name="Z")
Y_of_Z_train = tf.placeholder(dty, [None, nb_clases], name="Y_of_Z_train")
L_Z = tf.matmul(Z, tf.reshape(w1, [-1, 1]))
L_Z = tf.nn.sigmoid(L_Z)
Z_test = tf.placeholder(dty, [None, 784], name="Z_test")
Y_test = tf.placeholder(dty, [None, nb_clases], name="Y_test")
L_test = tf.matmul(Z_test, tf.reshape(w1, [-1, 1]))
L_test = tf.nn.sigmoid(L_test)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(L + 1e-6) + (1 - Y) * tf.log(1 - L + 1e-6), 1))
cost += damping * tf.nn.l2_loss(params) # gradient vanishing
upweighting_loss = tf.reduce_mean(-tf.reduce_sum(Y_of_Z_train * tf.log(L_Z + 1e-6) + (1 - Y_of_Z_train) * tf.log(1 - L_Z + 1e-6), 1))
upweighting_loss += damping * tf.nn.l2_loss(params) # gradient vanishing
Test_loss = tf.reduce_mean(-tf.reduce_sum(Y_test * tf.log(L_test + 1e-6) + (1 - Y_test) * tf.log(1 - L_test + 1e-6), 1))
Test_loss += damping * tf.nn.l2_loss(params) # gradient vanishing
# grads
test_grad = tf.gradients(Test_loss, params)
train_grad = tf.gradients(upweighting_loss, params)
# Hessians
true_hess = tf.hessians(cost, params)
# H dot v
hessian_vector_val = hessian_vector_product(cost, params, v_cur_est, True)
# H inverse
estimation_IHVP = [ g + cur_e - HV / scale for g, HV, cur_e in zip(Test, hessian_vector_val, v_cur_est)]
estimation_inverse = (I + cur_in - tf.matmul(Hess, cur_in) / scale)
train_op = tf.train.AdamOptimizer(1e-2).minimize(cost)
accuracy = get_accuracy_op(L, Y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(1501):
    for i in range(26):
        batch_xs, batch_ys = (data[i * 500 : (i + 1) * 500], data_label[i * 500 : (i + 1) * 500])
        _ = sess.run([train_op], feed_dict={X: batch_xs, Y: batch_ys})
    if epoch % 500 == 0 and epoch > 0:
        c = sess.run(accuracy, feed_dict={X: test, Y: test_label})
        a = sess.run(accuracy, feed_dict={X: data, Y: data_label})
        ccc = sess.run(cost, feed_dict={X: test, Y: test_label})
        print("Train accuracy: ", a, " Test accuracy: ", c, " cost: ", ccc)
print("sum of parameters: ", sess.run(tf.nn.l2_loss(params)).sum())
Train accuracy: 0.9535634658260936 Test accuracy: 0.951918631530282 cost: 0.6931450106795913
Train accuracy: 0.9537941108633813 Test accuracy: 0.9523809523809523 cost: 0.6931450102818402
Train accuracy: 0.9537941108633813 Test accuracy: 0.9523809523809523 cost: 0.6931450102817838
sum of parameters: 5e-13
Hessian Inverse computing
Numpy direct inverse
It’s possible to calculate the Hessian inverse directly using NumPy’s np.linalg.inv:
true_h = sess.run(true_hess[0], feed_dict={X: data, Y: data_label})
inv = np.linalg.inv(true_h)
np.linalg.norm(true_h)
0.312246772578134
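But for a model with many parameters, this is going to take a while: inverting a $p \times p$ Hessian directly takes $O(p^3)$ time and $O(p^2)$ memory. This is why we estimate the inverse Hessian with LiSSA (Linear time Stochastic Second-order Algorithm) instead.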
Lissa Algorithms
import time
cur_estimate = sess.run(I)
start_time = time.time()
# Iterate X <- I + X - H X / scale; its fixed point is X = scale * H^{-1}
for j in range(5001):
    cur_estimate = sess.run(estimation_inverse, feed_dict={Hess: true_h, cur_in: cur_estimate})
inverse = cur_estimate / scale
duration = time.time() - start_time
print("Inverse Hessian by Lissa: took %s minute %s sec" % (duration // 60, duration % 60))
Inverse Hessian by Lissa: took 0.0 minute 12.466225385665894 sec
So our Inverse Hessian by Lissa took a full 12.5 seconds to calculate. What kind of error do we obtain from these processes?
print("Lissa Identity Error: ", abs(np.dot(true_h, inverse) - np.eye(784)).sum())
print("Numpy Identity Error: ", abs(np.dot(true_h, inv) - np.eye(784)).sum())
print("Inverse Error: ", abs(inverse - inv).sum())
Lissa Identity Error: 7.356780947536368e-12
Numpy Identity Error: 1.0263673345784422e-12
Inverse Error: 8.125032173888674e-10
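The estimation_inverse loop above iterates $X_{k+1} = I + (I - H/s)X_k$, where $s$ is `scale`, starting from $X_0 = I$. When the eigenvalues of $H/s$ lie in $(0, 1]$, this is a truncated Neumann series whose fixed point is $X = sH^{-1}$, so $X/s$ converges to $H^{-1}$, with the error shrinking by a factor of $1 - \lambda_{\min}(H)/s$ per step. That is one reason the damping term helps: it adds damping·I to the Hessian and so raises $\lambda_{\min}$. A NumPy sketch on a hypothetical random positive-definite matrix, with `scale` chosen just above the largest eigenvalue:

```python
import numpy as np


def lissa_inverse(H, scale=1.0, num_iters=5000):
    """Approximate H^{-1} with X <- I + (I - H / scale) X, starting from X = I.
    Assumes H is symmetric positive definite with eigenvalues at most `scale`."""
    I = np.eye(H.shape[0])
    X = I.copy()
    for _ in range(num_iters):
        X = I + X - H @ X / scale
    return X / scale


rng = np.random.default_rng(0)
A = rng.normal(size=(50, 50))
H = A @ A.T / 50 + 0.1 * np.eye(50)          # SPD, smallest eigenvalue at least 0.1
scale = 1.1 * np.linalg.eigvalsh(H).max()
approx = lissa_inverse(H, scale=scale)
print(np.abs(approx - np.linalg.inv(H)).max())
```

For a matrix this small a direct inverse is cheap; recursions like this pay off when $H$ is too large to invert, and the HVP-based IHVP estimate that follows avoids storing $H$ at all.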
IHVP (inverse Hessian-vector product) calculation
test_val = sess.run(
test_grad,
feed_dict={
Z_test: test[test_index].reshape((1, 784)),
Y_test: test_label[test_index].reshape((1, 1)),
},
)
IHVP = np.dot(test_val[0], inv)
start_time = time.time()
cur_estimate = test_val.copy()
feed1 = {place: cur for place, cur in zip(Test, test_val)}
for j in range(5001):
    feed2 = {place: cur for place, cur in zip(v_cur_est, cur_estimate)}
    r = np.random.randint(len(data), size=[1024])
    cur_estimate = sess.run(
        estimation_IHVP,
        feed_dict={
            X: data[r],
            Y: data_label[r],
            **feed1,
            **feed2,
        },
    )
    if j % 2500 == 0 and j > 0:
        print(cur_estimate[0][0])
inverse_hvp = [b / scale for b in cur_estimate]
duration = time.time() - start_time
print("Inverse HVP by HVPs+Lissa: took %s minute %s sec" % (duration // 60, duration % 60))
print(abs(IHVP - inverse_hvp[0]).sum())
2.5299487232785195e-13
2.529948723308861e-13
Inverse HVP by HVPs+Lissa: took 0.0 minute 27.68460702896118 sec
1.9075249369356082
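The estimation_IHVP loop applies the same recursion to a vector, $h \leftarrow v + h - Hh/s$, so it only ever needs Hessian-vector products, never the Hessian itself; the post estimates each $Hh$ on a random minibatch of 1024 training examples. This self-contained NumPy sketch for L2-regularized logistic regression (synthetic data; the names, batch size, and iteration counts are hypothetical) compares the full-batch version and a minibatch version with a direct solve. With minibatches, the iterates keep fluctuating around the answer instead of converging exactly, which is one reason the two IHVPs above don’t agree perfectly.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def hvp(w, X, v, damping):
    """H v for mean cross-entropy + damping * ||w||^2 / 2 over the rows of X."""
    p = sigmoid(X @ w)
    return X.T @ (p * (1 - p) * (X @ v)) / X.shape[0] + damping * v


def lissa_ihvp(w, X, v, damping, scale=1.0, num_iters=5000, batch_size=None, seed=0):
    """Estimate H^{-1} v with h <- v + h - H h / scale, starting from h = v.
    If batch_size is given, each H h is estimated on a random minibatch of rows."""
    rng = np.random.default_rng(seed)
    h = v.copy()
    for _ in range(num_iters):
        rows = X if batch_size is None else X[rng.integers(len(X), size=batch_size)]
        h = v + h - hvp(w, rows, h, damping) / scale
    return h / scale


rng = np.random.default_rng(1)
X = rng.normal(size=(2000, 10)) / np.sqrt(10)   # synthetic features
w = rng.normal(size=10)                          # hypothetical trained weights
v = rng.normal(size=10)                          # stands in for the test-loss gradient
damping = 1e-2

p = sigmoid(X @ w)
H = (X.T * (p * (1 - p))) @ X / len(X) + damping * np.eye(10)
exact = np.linalg.solve(H, v)
full = lissa_ihvp(w, X, v, damping)                   # deterministic, full-batch HVPs
mini = lissa_ihvp(w, X, v, damping, batch_size=256)   # stochastic, as in the post
print(np.abs(full - exact).max(), np.abs(mini - exact).max())
```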
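Now we can use both our NumPy IHVP and our Lissa IHVP to score training examples. For each training example $z$ with the same label as the test image, we take the dot product of the IHVP with the gradient of the loss on $z$. Since the IHVP is $\nabla_\theta L(z_{\text{test}})^\top H^{-1}$, this score is the negative of the influence on the test loss, so the most negative scores mark the most harmful examples.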
First, the Lissa IHVP…
s = time.time()
val_lissa = []
for i in range(num_train):
    if data_label[i][0] == test_label[test_index][0]:
        train_grad_loss_val = sess.run(
            train_grad,
            feed_dict={
                Z: data[i].reshape((1, 784)),
                Y_of_Z_train: data_label[i].reshape((1, 1)),
            },
        )
        val_lissa.append([i, np.dot(np.concatenate(inverse_hvp), np.concatenate(train_grad_loss_val))])
duration = time.time() - s
print("Multiplying by %s train examples took %s minute %s sec" % (1, duration // 60, duration % 60))
val_lissa = sorted(val_lissa, key=lambda x: x[1])
Multiplying by 1 train examples took 0.0 minute 10.824917316436768 sec
And then the Numpy IHVP.
s = time.time()
val = []
for i in range(num_train):
    if data_label[i][0] == test_label[test_index][0]:
        train_grad_loss_val = sess.run(
            train_grad,
            feed_dict={
                Z: data[i].reshape((1, 784)),
                Y_of_Z_train: data_label[i].reshape((1, 1)),
            },
        )
        val.append([i, np.dot(IHVP, np.concatenate(train_grad_loss_val))])
duration = time.time() - s
print("Multiplying by %s train examples took %s minute %s sec" % (1, duration // 60, duration % 60))
val = sorted(val, key=lambda x: x[1])
Multiplying by 1 train examples took 0.0 minute 10.686049222946167 sec
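Both loops above call `sess.run` once per training example. For a linear model like this one, the per-example gradients have a closed form, so all the scores can be computed with one matrix product. This sketch assumes the post’s logistic model without the 1e-6 inside the logs: the gradient of the single-example loss on $(x, y)$ is $(\sigma(x \cdot w) - y)\,x + \lambda w$, with $\lambda$ = `damping`. The function names are hypothetical, and `w_value` in the usage comment would have to come from evaluating the normalized tensor `w1`, since that is what the post differentiates with respect to.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def per_example_grads(w, X, y, damping):
    """Row i: gradient of -[y_i log p_i + (1 - y_i) log(1 - p_i)] + damping * ||w||^2 / 2,
    where p_i = sigmoid(X[i] @ w)."""
    p = sigmoid(X @ w)
    return (p - y)[:, None] * X + damping * w


def influence_scores(w, X, y, ihvp, damping):
    """ihvp . grad L(z_i) for every row of X, i.e. the quantity stored in `val`."""
    return per_example_grads(w, X, y, damping) @ ihvp


# Hypothetical usage with the post's arrays (w_value = sess.run(w1)):
# same = data_label[:, 0] == test_label[test_index][0]
# scores = influence_scores(w_value, data[same], data_label[same, 0], IHVP, damping)
# most_harmful_first = np.flatnonzero(same)[np.argsort(scores)]
```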
So we’ve calculated influence scores for all the training examples that share the test image’s label. Let’s look at what our most influential examples look like, and compare the rankings we get from the Lissa IHVP and the NumPy IHVP.
print("Numpy IHVP",
"\nMost Harmful Indexes", [val[i][0] for i in range(0, 6)],
"\nMost Helpful Indexes", [val[i][0] for i in range(-1, -7, -1)])
print("Lissa IHVP",
"\nMost Harmful Indexes", [val_lissa[i][0] for i in range(0, 6)],
"\nMost Helpful Indexes", [val_lissa[i][0] for i in range(-1, -7, -1)])
Numpy IHVP
Most Harmful Indexes [9309, 6463, 11720, 10878, 10648, 7869]
Most Helpful Indexes [3441, 932, 2799, 3200, 1147, 9394]
Lissa IHVP
Most Harmful Indexes [9309, 6463, 11720, 10878, 10648, 7869]
Most Helpful Indexes [3441, 932, 2799, 3200, 1147, 9394]
At this point, you’re probably wondering what these images actually look like. What qualitative differences would we see between the most helpful and the most harmful images?
fig = plt.figure(figsize=(16, 4))
image_details = [
["Test_image", test[test_index]],
["Harmful_image1", data[val[0][0]]],
["Harmful_image2", data[val[1][0]]],
["Harmful_image3", data[val[2][0]]],
["Harmful_image4", data[val[3][0]]],
["Harmful_image5", data[val[4][0]]],
["Harmful_image6", data[val[5][0]]]]
for i in range(1, 8):
    ax = plt.subplot(1, 7, i)
    plt.imshow(image_details[i-1][1].reshape(28, 28), cmap="inferno", interpolation="nearest")
    ax.set_title(image_details[i-1][0])
plt.tight_layout()
plt.show()
We’ve got a pretty wide variety of shapes here. We can certainly see how some of these might have made it hard to distinguish 7s from 1s. This could be due to the slope at the top being irregularly shaped, or perhaps the line at the bottom of a Latin ‘1’ being confused with the line through the middle of a Latin ‘7’. What about the most helpful instances?
fig = plt.figure(figsize=(16, 4))
image_details = [
["Test_image", test[test_index]],
["Positive_image1", data[val[-1][0]]],
["Positive_image2", data[val[-2][0]]],
["Positive_image3", data[val[-3][0]]],
["Positive_image4", data[val[-4][0]]],
["Positive_image5", data[val[-5][0]]],
["Positive_image6", data[val[-6][0]]]]
for i in range(1, 8):
    ax = plt.subplot(1, 7, i)
    plt.imshow(image_details[i-1][1].reshape(28, 28), cmap="inferno", interpolation="nearest")
    ax.set_title(image_details[i-1][0])
plt.tight_layout()
plt.show()
By contrast, the most helpful examples are much more uniform, with far less variance between the items in this list. It appears that in the case of this model.
Concluding thoughts
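So we’ve gone through all the math of how Pang Wei Koh et al.’s influence functions work. In principle, these are very similar to the influence functions that are normally used in robust statistics. The main differences are that we need to make some special accommodations for non-differentiable loss functions, and we need a better method of estimating the IHVP. For finding the most influential training examples for a given machine learning task, the influence calculation at a high level is structured as the function
$$\mathcal{I}_{\text{up,loss}}(z, z_{\text{test}}) = -\nabla_\theta L(z_{\text{test}}, \hat\theta)^\top H_{\hat\theta}^{-1} \nabla_\theta L(z, \hat\theta),$$
where $H_{\hat\theta}$ is the Hessian of the training loss at the learned parameters $\hat\theta$.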
There are a few techniques not covered above that we can use to investigate influence further. We gave the evaluation of confusing digits in MNIST as a good example, but other use-cases include:
- Detecting mislabels: If one were to run the influence functions for each test data instance, and repeat the calculation over the entire model and training set for all the test data, one could conceivably construct an automatic mislabel detector.
- Determining what a model is sensitive to: As Pang Wei Koh et al. demonstrated in their paper, influence function results for support vector machines and CNNs can have different influence scores, even in cases where
- Determining differences in influence between parts of a model: We applied all the parameters in our model to the Influence detector. However,
As useful as these are, there are still quite a few hurdles blocking widespread practical use of influence functions:
- Influence functions for NNs are approximate. That is, they may produce subtly different results and scores for all but the most influential (positive or negative) instances.
- There are still no widely-accepted best-practices for influence functions.
If popular demand is high enough, I may write a few future posts about expanding this technique to other machine learning tasks.
Cited as:
@article{mcateer2018infbasics,
title = "Influence Functions from scratch",
author = "McAteer, Matthew",
journal = "matthewmcateer.me",
year = "2018",
url = "https://matthewmcateer.me/blog/basics-of-influence-functions/"
}
If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me]
and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.
See you in the next post 😄