Gaussian KDE from Scratch

Understanding KDE inside and out

In unsupervised learning, traditionally known as density modeling, one usually constructs a probabilistic model $p(x)$. The model is fit on a training set, and its generalization performance is evaluated on a separate test set. The hyper-parameters of the model are typically tuned on a validation set. Compared with supervised learning, unsupervised learning is arguably more challenging, as $p(x)$ is typically much more complicated than $p(y \vert x)$.

Among many possible choices of $p(x)$, one of the simplest is the well-known, good old-fashioned “kernel density estimator”. It is non-parametric in the sense that $p(x)$ “memorizes” the entire training set. The scoring function is usually defined by a Gaussian kernel. This work borrows that basic idea from the standard kernel density estimator and formulates it as a mixture of Gaussian distributions.

Kernel density estimation (KDE) is in some senses an algorithm which takes the “mixture-of-Gaussians” idea to its logical extreme: it uses a mixture consisting of one Gaussian component per point, resulting in an essentially non-parametric estimator of density.
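To make the "one Gaussian per point" picture concrete, here is a minimal 1D sketch using only the standard library. The function name `kde_1d`, the sample points, and the bandwidth are illustrative assumptions, not part of the code used later in this post.

```python
import math

def kde_1d(x, points, sigma):
    """Density at x from a mixture with one Gaussian per training point."""
    norm = 1.0 / (len(points) * sigma * math.sqrt(2.0 * math.pi))
    return norm * sum(math.exp(-((x - p) ** 2) / (2.0 * sigma ** 2)) for p in points)

# Hypothetical 1D "training set" and bandwidth, chosen only for illustration
points = [-1.0, 0.0, 0.5, 2.0]
sigma = 0.4

# Riemann sum over a wide grid: a valid density should integrate to about 1
step = 0.01
grid = [-10.0 + i * step for i in range(2001)]
area = sum(kde_1d(x, points, sigma) for x in grid) * step
print(round(area, 3))
```

Each training point contributes one bump of width `sigma`, and the density is the average of those bumps. The image version later in the post does the same thing with 784 or 3072 dimensions per point.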

Simplified 1D demonstration of KDE, which you are probably used to seeing

I’m a big believer in the idea that the best way to learn about machine learning and probabilistic programming algorithms is to be able to implement them from scratch. Given the relative ease with which one can construct KDE for 1D data, I’m going to one-up those tutorials by showing how to apply this to image data.

import argparse
import csv
import decimal
import logging
import os
import pickle
import sys
import time
import typing # Typing introduces slight startup time penalties
from decimal import Decimal, getcontext
from functools import reduce
from math import pi
from operator import mul
from typing import Any, Tuple
import matplotlib.cm as cm
import numpy as np
import numpy.matlib
from matplotlib import pyplot as plt

logging.basicConfig(level=os.environ.get("LOGLEVEL", "DEBUG"))
logger = logging.getLogger(__name__)
np.random.seed(1573368177) # Numpy Random Seed for reproducibility
# Seed derived from timestamp for Saturday, November 9, 2019 10:42:57 PM GMT-08:00

#@markdown Input dataset
dataset = "MNIST" #@param ["MNIST", "CIFAR100"]

For reproducibility purposes, the Unix timestamp of Saturday, November 9, 2019 10:42:57 PM GMT-08:00 was used as the random seed for NumPy. This seed was preferred over values more commonly used in machine learning literature (0, 1, 5, 42, 99, 123, 1234, 1337, 12321, 1234567890, or other numbers that might be deemed inappropriate by reviewers).
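If you want to check how that date maps to the seed, the conversion can be sketched with the standard library. The variable names here are illustrative.

```python
from datetime import datetime, timedelta, timezone

# Saturday, November 9, 2019 10:42:57 PM at UTC-08:00
pst = timezone(timedelta(hours=-8))
moment = datetime(2019, 11, 9, 22, 42, 57, tzinfo=pst)
seed = int(moment.timestamp())
print(seed)  # 1573368177
```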

We use two commonly-used datasets. The first, MNIST, is a collection of $28 \times 28 \times 1$ gray-scale images belonging to 10 categories [1]. The second, CIFAR100, is a collection of $32 \times 32 \times 3$ RGB images that can be split into either 20 coarse or 100 fine categories [2]. The visualization of these datasets confirms not only that we’ve downloaded the correct data, but also that we’ve done the preprocessing correctly.

!tar xvzf cifar-100-python.tar.gz
!gunzip mnist.pkl.gz
def extract_pickle(file):
    # encoding="latin1" lets Python 3 read byte strings from Python 2 pickles
    with open(file, "rb") as f:
        return pickle.load(f, encoding="latin1")

def load_mnist(data_dir):
    unpickled_data = extract_pickle(data_dir)
    x_train_raw = unpickled_data[0][0] # Training images; no rescaling is applied here
    np.random.shuffle(x_train_raw) # Shuffling the original training set
    x_train = x_train_raw[0:10000]
    x_val = x_train_raw[10000:20000] # Creating the validation set from the split
    x_test = unpickled_data[2][0] # Using the original 10K test set as it is.
    # print(np.ptp(x_test))
    # x_test values already have a range of 0.99609375, and thus do not need rescaling
    return x_train, x_val, x_test

def load_cifar(train_dir, test_dir):
    unpickled_train, unpickled_test = (extract_pickle(train_dir),
                                       extract_pickle(test_dir))
    x_train_raw = unpickled_train["data"].astype(np.float64)
    x_train_raw = x_train_raw / 255  # Scaling pixel values between 0 and 1
    np.random.shuffle(x_train_raw) # Shuffling the original training set
    x_train = x_train_raw[0:10000]
    x_val = x_train_raw[10000:20000] # Creating the validation set from the split
    x_test = unpickled_test["data"].astype(np.float64)
    # Using the original 10K test set as it is.
    x_test = x_test / 255  # Scaling pixel values between 0 and 1
    return x_train, x_val, x_test

Visualization Check

def vis_check(image_data, num_img_edge=20, pixel_rows=28, pixel_cols=28,
              image_channels=1):
    num_img_square = num_img_edge ** 2
    dataset_examples = (
        image_data[0:num_img_square]  # Assumed: the first num_img_square images
        .reshape(num_img_square, image_channels, pixel_rows, pixel_cols)
        .reshape(num_img_edge, num_img_edge, image_channels, pixel_rows, pixel_cols)
        .transpose(0, 1, 3, 4, 2)  # Move channels last for plotting
    )
    if image_channels == 3:
        img = dataset_examples.swapaxes(1, 2).reshape(
            pixel_rows * num_img_edge, pixel_cols * num_img_edge, image_channels
        )
    elif image_channels == 1:
        img = dataset_examples.swapaxes(1, 2).reshape(
            pixel_rows * num_img_edge, pixel_cols * num_img_edge
        )
    fig = plt.imshow(img, cmap="gray")  # cmap is ignored for RGB input
    return plt

getcontext().prec = 7  # Precision for decimal

if dataset in ["MNIST", "mnist"]:
    x_train, x_val, x_test = load_mnist(data_dir="mnist.pkl")
    img = vis_check(x_train, num_img_edge=20, pixel_rows=28,
                    pixel_cols=28, image_channels=1)
    img.savefig("mnist.png", dpi=500)
elif dataset in ["CIFAR100", "cifar100", "CIFAR", "cifar"]:
    # Assumes the standard layout of the extracted cifar-100-python archive
    x_train, x_val, x_test = load_cifar(train_dir="cifar-100-python/train",
                                        test_dir="cifar-100-python/test")
    img = vis_check(x_train, num_img_edge=20, pixel_rows=32,
                    pixel_cols=32, image_channels=3)
    img.savefig("cifar100.png", dpi=500)
else:
    logger.error('\tPlease pass dataset_name as arg: MNIST or CIFAR')

The result of running the above code is 400 representative images (a $20 \times 20$ grid) from MNIST and CIFAR100. These represent data post-preprocessing and scaling. This serves not just as a source of visual intuition for the data, but as a confirmation of correct preprocessing.
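The tiling trick behind that grid (reshape, swap two axes, reshape again) is easier to follow on a toy array. The sketch below uses made-up single-channel "images" of $2 \times 3$ pixels; the sizes and variable names are assumptions for illustration only.

```python
import numpy as np

# Hypothetical toy data: 4 flattened single-channel "images" of 2x3 pixels,
# where every pixel of image n holds the value n
rows, cols, edge = 2, 3, 2
images = np.repeat(np.arange(edge ** 2), rows * cols).reshape(edge ** 2, rows * cols)

grid = (
    images.reshape(edge, edge, rows, cols)  # (grid_row, grid_col, pixel_row, pixel_col)
    .swapaxes(1, 2)                         # (grid_row, pixel_row, grid_col, pixel_col)
    .reshape(edge * rows, edge * cols)      # one tiled canvas
)
print(grid)
```

Image 0 lands in the top-left tile, image 1 in the top-right, and so on. Without the `swapaxes` call the final reshape would interleave rows from different images.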



Our Model

Here’s what we’re actually going to be coding:

Given a dataset that contains two splits $\mathcal{D}_A \in R^{k \times d}$ and $\mathcal{D}_B \in R^{m \times d}$, we compute the log-likelihood of $\mathcal{D}_B$ under $\mathcal{D}_A$ with the following probability density function.

$$\log p(x) = \log \sum^{k}_{i=1} p(z_i)p(x \vert z_i) \tag{1}$$

where $x \in \mathcal{R}^d$ and $z_i$ is discrete.

The above formulation expresses the probability of $x$ as a mixture of conditional distributions. Here we call $p(z_i)$ the probability of the $i^\mathrm{th}$ mixing component, and $p(x \vert z_i)$ the probability of $x$ under the $i^\mathrm{th}$ component.

To simplify even more, let us further assume the following

$$p(z_i) = \frac{1}{k} \tag{2}$$

and

$$p(x \vert z_i) = \prod^{d}_{j=1}p(x_j \vert z_i) \tag{3}$$

where

$$p(x_j \vert z_i) = \frac{1}{\sqrt{2 \pi \sigma^2_i}} \exp\left(-\frac{(x_{j}- \mu_{i,j})^2}{2 \sigma^2_i}\right) \tag{4}$$

and $\mu \in R^{k \times d}$. To simplify further, we also assume that all $p(\cdot \vert z_i)$ Gaussian components share the same $\sigma$. Therefore Equ. (1) can be written as

$$\log p(x) = \log \sum^{k}_{i=1} \exp\left\{\log \frac{1}{k} + \sum^{d}_{j=1}\left[-\frac{(x_{j}- \mu_{i,j})^2}{2 \sigma^2} - \frac{1}{2} \log(2 \pi \sigma^2)\right]\right\} \tag{5}$$
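For readers who want the intermediate steps, here is one way to spell out that rewrite. It substitutes the uniform mixing weights and the shared-$\sigma$ Gaussian factors into the mixture, moves everything inside an exponential, and then pulls out the terms that do not depend on $i$. The last line is my own rearrangement, shown because it separates the constant terms from the sum over components.

```latex
\begin{aligned}
\log p(x)
  &= \log \sum_{i=1}^{k} \frac{1}{k} \prod_{j=1}^{d}
     \frac{1}{\sqrt{2\pi\sigma^2}}
     \exp\left(-\frac{(x_j - \mu_{i,j})^2}{2\sigma^2}\right) \\
  &= \log \sum_{i=1}^{k} \exp\left\{ \log\frac{1}{k}
     + \sum_{j=1}^{d} \left[ -\frac{(x_j - \mu_{i,j})^2}{2\sigma^2}
     - \frac{1}{2}\log(2\pi\sigma^2) \right] \right\} \\
  &= -\log k - \frac{d}{2}\log(2\pi\sigma^2)
     + \log \sum_{i=1}^{k} \exp\left( -\sum_{j=1}^{d}
     \frac{(x_j - \mu_{i,j})^2}{2\sigma^2} \right)
\end{aligned}
```

Only the final log-sum depends on the data, so the two leading terms can be computed once per value of $\sigma$.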

With Equ. (5), one can compute the log-probability of each example in $\mathcal{D}_B$ by considering all $k$ examples in $\mathcal{D}_A$, with $\mu_{i,j} \equiv x^{A}_{i,j}$ where $x^A \equiv \mathcal{D}_A \in R^{k \times d}$.

Finally, the mean of the log-probability on $\mathcal{D}_B$ can be written as

$$\mathcal{L}_{\mathcal{D}_B} = \frac{1}{m} \log \prod^{m}_{i=1} p(x^B_i) = \frac{1}{m} \sum^{m}_{i=1} \log p(x^B_i) \tag{6}$$
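As a cross-check on the equations, the mean log-probability can also be sketched in fully vectorized NumPy. This is not the implementation used in this post: the function name `mean_log_likelihood` is hypothetical, and it swaps in the log-sum-exp trick (subtracting the largest exponent before exponentiating) in place of the Decimal arithmetic used later.

```python
import numpy as np

def mean_log_likelihood(sigma, D_A, D_B):
    """Mean log-density of the rows of D_B under a Gaussian KDE centred on D_A."""
    D_A = np.asarray(D_A, dtype=np.float64)
    D_B = np.asarray(D_B, dtype=np.float64)
    k, d = D_A.shape
    # Squared distances between every D_B row and every D_A row: shape (m, k)
    sq_dist = ((D_B[:, None, :] - D_A[None, :, :]) ** 2).sum(axis=2)
    exponents = -sq_dist / (2.0 * sigma ** 2)
    # Log-sum-exp: subtract the row maximum so the largest term is exp(0) = 1
    row_max = exponents.max(axis=1, keepdims=True)
    log_sum = row_max[:, 0] + np.log(np.exp(exponents - row_max).sum(axis=1))
    log_p = log_sum - np.log(k) - 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
    return log_p.mean()

# Hypothetical toy data, only to show the call
rng = np.random.default_rng(0)
D_A = rng.random((5, 3))
D_B = rng.random((4, 3))
print(mean_log_likelihood(0.5, D_A, D_B))
```

The broadcasted difference builds an array of shape (m, k, d), so for full-size splits you would want to process the rows of `D_B` in chunks to keep memory in check.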

The code below implements these equations in a Python 3.7 environment with just the standard library and NumPy [3, 4].

Now, at first, the presence of the $\sum$ and $\prod$ operators may make the runtime seem daunting. If you are used to leetcode-style interviews, many of these just scream $O(n^2)$ or $O(n^3)$ or $O(n^4)$ or something nastier. This is why we add NumPy to the mix of standard libraries we will be using. At multiple steps we can operate on whole arrays in lieu of writing multi-nested for loops. The amount of arithmetic stays the same, but the innermost loops run in compiled code rather than in the Python interpreter. We also avoid carrying the extraneous digits of numpy.float64 or numpy.float32 values by using the standard Decimal module with the precision set to the bare minimum needed, 7 digits. In many ways, this is similar to the principle behind using TPUs instead of GPUs for machine learning algorithms. Note that Decimal is not hardware floating point and is generally slower than float64, so what it buys us here is a very wide exponent range for the exp() and ln() calls rather than raw speed.
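To see why the exponent range matters, compare what happens to a very negative exponent in float64 and in Decimal. The value below is hypothetical, picked to be on the order of the log-likelihoods that show up for small $\sigma$.

```python
import math
from decimal import Decimal, localcontext

exponent = -13000  # hypothetical, roughly a squared distance divided by 2 * sigma^2

as_float = math.exp(exponent)      # underflows to exactly 0.0 in float64
with localcontext() as ctx:
    ctx.prec = 7                   # same 7-digit precision used in this post
    as_decimal = Decimal(exponent).exp()

print(as_float)    # 0.0
print(as_decimal)  # a tiny but nonzero Decimal
```

Taking the logarithm of the float64 result would fail, while the Decimal result can still be summed and passed to `ln()`.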

def kde_scratch(sigma, D_A, D_B):
    getcontext().prec = 7
    mu, prob_x = D_A.astype(np.float64), 0
    len_D_A, len_D_B, d = len(D_A), len(D_B), len(D_A[0])
    # Normalization term: -(d / 2) * log(2 * pi * sigma^2)
    t_1 = -Decimal(0.5 * d) * Decimal(2 * pi * (sigma ** 2)).ln()
    log_k = Decimal(len_D_A).ln()
    for i in range(len_D_B):  # One log-probability per example in D_B
        # Broadcasting subtracts D_B[i] from every row of mu (every example in D_A)
        t_0 = np.sum(-((D_B[i].astype(np.float64) - mu) ** 2) / (2 * (sigma ** 2)), axis=1)
        elements_sum = Decimal(0)
        for j in range(len_D_A):  # Sum over the k mixture components
            elements_sum += Decimal(float(t_0[j])).exp()
        prob_x += t_1 - log_k + elements_sum.ln()
    return prob_x / len_D_B

All these experiments are runnable in a Google Colab environment, or for that matter on any n1-highmem-2 instance with an Intel(R) Xeon(R) 2.30GHz CPU and 16GB DDR4 RAM.

While it is possible to optimize the CPU-code further to use all of the available 6 CPUs using the multiprocessing module, these further optimizations were omitted for the sake of being able to run the code on other machines (like, whatever environment you have).

Running the actual training

To summarize what we covered above:

A KDE was trained on both the MNIST and CIFAR100 datasets, with relevant preprocessing beforehand (shuffling and train/validation splits, along with scaling the pixel values to be in the range $[0, 1]$). To find the optimal $\sigma$ value for the model, we set up a roughly log-scaled 1-dimensional grid search. The training of a KDE was repeated for the set of values $\sigma = \{0.05, 0.08, 0.1, 0.2, 0.5, 1.0, 1.5, 2.0\}$ on both datasets. For computing the log-likelihood of $\mathcal{D}_B$ under $\mathcal{D}_A$ with the previously-described probability density function, we used the split ($\mathcal{D}_A = \mathrm{MNIST}_{\mathrm{train}}$, $\mathcal{D}_B = \mathrm{MNIST}_{\mathrm{valid}}$) for MNIST and ($\mathcal{D}_A = \mathrm{CIFAR100}_{\mathrm{train}}$, $\mathcal{D}_B = \mathrm{CIFAR100}_{\mathrm{valid}}$) for CIFAR100. Note that the listing below passes only the first 10 examples of each split to kde_scratch, which keeps the run time short.

logger.info("\tWorking on {} dataset".format(dataset.upper()))
L_valid = [] # Initializing log-likelihood list
sigma_list = [0.05, 0.08, 0.10, 0.20, 0.50, 1.00, 1.50, 2.00]
for sigma in sigma_list:
    logger.debug("\tKDE with Gaussian kernel using \u03C3 = {}".format(sigma))
    kde_prob = kde_scratch(sigma, x_train[0:10], x_val[0:10])
    L_valid.append(kde_prob)
    logger.info("\tWhere \u03C3 = {}, L_D_valid = {}".format(sigma, kde_prob))
sigma_optimal = sigma_list[np.argmax(L_valid)] # Optimal sigma value
logger.debug("\tPredicting model with optimal \u03C3")
begin_time = time.time()
logger.info("\tOptimal \u03C3 from training = {}".format(sigma_optimal))
L_test = kde_scratch(sigma_optimal, x_train[0:10], x_test[0:10])
logger.info("\tL_D_test with optimal \u03C3 = {}".format(L_test))
run_time = time.time() - begin_time
logger.debug("\tExecution time on test dataset: {} seconds".format(run_time))

INFO:__main__: Working on MNIST dataset
DEBUG:__main__: KDE with Gaussian kernel using σ = 0.05
INFO:__main__: Where σ = 0.05, L_D_valid = -13044.42
DEBUG:__main__: KDE with Gaussian kernel using σ = 0.08
INFO:__main__: Where σ = 0.08, L_D_valid = -4473.172
DEBUG:__main__: KDE with Gaussian kernel using σ = 0.1
INFO:__main__: Where σ = 0.1, L_D_valid = -2585.103
DEBUG:__main__: KDE with Gaussian kernel using σ = 0.2
INFO:__main__: Where σ = 0.2, L_D_valid = -377.8456
DEBUG:__main__: KDE with Gaussian kernel using σ = 0.5
INFO:__main__: Where σ = 0.5, L_D_valid = -326.022
DEBUG:__main__: KDE with Gaussian kernel using σ = 1.0
INFO:__main__: Where σ = 1.0, L_D_valid = -759.3353
DEBUG:__main__: KDE with Gaussian kernel using σ = 1.5
INFO:__main__: Where σ = 1.5, L_D_valid = -1056.648
DEBUG:__main__: KDE with Gaussian kernel using σ = 2.0
INFO:__main__: Where σ = 2.0, L_D_valid = -1274.786
DEBUG:__main__: Predicting model with optimal σ
INFO:__main__: Optimal σ from training = 0.5
INFO:__main__: L_D_test with optimal σ = -319.5088
DEBUG:__main__: Execution time on test dataset: 0.014196634292602539 seconds
INFO:__main__: Saving results to file "MNIST_kde_results.csv"


Results with numpy.random.seed = 1573368177


Starting from $\sigma = 0.05$, increases in $\sigma$ result in increases in the log-likelihood. The log-likelihood eventually peaks (not necessarily at the same value for all datasets), and then begins to decrease as $\sigma$ increases beyond this optimal value ($\sigma_{\mathrm{optimal}}$). The assumption behind treating this as a maximization problem is that the largest log-likelihood will correspond to the probability distribution with the most generalizable pattern for a given dataset. In the case of MNIST, the maximum log-likelihood reached was $\mathcal{L}_{D^{\mathrm{MNIST}}_{\mathrm{valid}}} = -326.022$ at $\sigma = 0.50$. For CIFAR100, it was $\mathcal{L}_{D^{\mathrm{CIFAR100}}_{\mathrm{valid}}} = -909.8296$ at $\sigma = 0.20$. These values for $\sigma$ were then used to calculate the log-likelihoods for the train-test pairings: $\mathcal{L}_{D_{B}}$ on test data where ($\mathcal{D}_A = \mathrm{MNIST}_{\mathrm{train}}$, $\mathcal{D}_B = \mathrm{MNIST}_{\mathrm{test}}$) and ($\mathcal{D}_A = \mathrm{CIFAR100}_{\mathrm{train}}$, $\mathcal{D}_B = \mathrm{CIFAR100}_{\mathrm{test}}$).
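The selection step itself is just an argmax over the grid. Here it is replayed on the MNIST validation values copied from the log output above, using only the standard library; `best_index` is an illustrative name.

```python
sigma_list = [0.05, 0.08, 0.10, 0.20, 0.50, 1.00, 1.50, 2.00]
# L_D_valid values copied from the MNIST log output
L_valid = [-13044.42, -4473.172, -2585.103, -377.8456,
           -326.022, -759.3353, -1056.648, -1274.786]

# Index of the largest (least negative) validation log-likelihood
best_index = max(range(len(L_valid)), key=L_valid.__getitem__)
sigma_optimal = sigma_list[best_index]
print(sigma_optimal, L_valid[best_index])
```

The values rise steeply up to $\sigma = 0.5$ and fall off more slowly afterwards, which is the peak described in the paragraph above.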

| Dataset | $\sigma_{\mathrm{optimal}}$ | $\mathcal{L}_{\mathcal{D}_{\mathrm{test}}}$ | Average Run-time |
| --- | --- | --- | --- |
| $\mathcal{D}^{\mathrm{MNIST}}_{\mathrm{test}}$ | $0.50$ | $-319.5088$ | $1.0949$ milliseconds |
| $\mathcal{D}^{\mathrm{CIFAR100}}_{\mathrm{test}}$ | $0.20$ | $-310.7053$ | $2.1043$ milliseconds |

When computed with the respective $\sigma_{\mathrm{optimal}}$ values, the test log-likelihoods were $\mathcal{L}_{D^{\mathrm{MNIST}}_{\mathrm{test}}} = -319.5088$ at $\sigma = 0.50$ and $\mathcal{L}_{D^{\mathrm{CIFAR100}}_{\mathrm{test}}} = -310.7053$ at $\sigma = 0.20$. For MNIST, the train-validation value ($-326.022$) and the train-test value ($-319.5088$) are close, which suggests that the grid search found a standard deviation for the Gaussian kernel that generalizes to unseen data. For CIFAR100 the two values ($-909.8296$ and $-310.7053$) are further apart, so that conclusion is less certain there.

These results are summarized in Table 2, along with the benchmarking results: the average run-time over 5 runs of the KDE for each dataset. The $1.92 \times$ speedup of the MNIST model over the CIFAR100 model is likely due to the fact that, even after scaling to the $[0, 1]$ range, each MNIST image is smaller than each CIFAR100 image by a factor of $49 : 192$. In other words, switching from gray-scale to RGB inputs adds significant computational cost.
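Both ratios quoted above can be checked with a few lines of arithmetic. The run-times are the ones reported in the table; the variable names are illustrative.

```python
from fractions import Fraction

mnist_ms, cifar_ms = 1.0949, 2.1043     # average run-times reported in the table
mnist_dim = 28 * 28 * 1                 # 784 values per MNIST image
cifar_dim = 32 * 32 * 3                 # 3072 values per CIFAR100 image

print(round(cifar_ms / mnist_ms, 2))    # 1.92
print(Fraction(mnist_dim, cifar_dim))   # 49/192
```

Note that a CIFAR100 image holds about 3.92 times as many values as an MNIST image, while the run-time only grows by 1.92 times, so the cost here does not scale in direct proportion to the input size.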

Quality Control

At this point you may be asking, “Wait, how do we know this model is functioning correctly?”

It is true that our visual analysis of the sample images confirmed that we performed our data imports correctly, but determining the performance of the KDE model is trickier.

A critical step in machine learning workflows (that is sadly underutilized) is error-checking. In our case we have several other packages, like scikit-learn and SciPy, that have KDE modules we can compare our results against.
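Before reaching for other packages, one cheap check needs nothing beyond the standard library: compare against a case with a known answer. The sketch below transcribes the mixture density directly with plain loops (the name `log_p_reference` is hypothetical) and tests it with a single training point, where the KDE reduces to one Gaussian.

```python
import math

def log_p_reference(x, D_A, sigma):
    """Direct transcription of the mixture: log of the mean of k Gaussian products."""
    k = len(D_A)
    total = 0.0
    for mu in D_A:
        component = 1.0
        for x_j, mu_j in zip(x, mu):
            gauss = math.exp(-((x_j - mu_j) ** 2) / (2 * sigma ** 2))
            component *= gauss / math.sqrt(2 * math.pi * sigma ** 2)
        total += component / k
    return math.log(total)

# Known answer: with a single training point at the origin and sigma = 1,
# the density at the origin is (2 * pi) ** (-d / 2)
d = 3
value = log_p_reference([0.0] * d, [[0.0] * d], 1.0)
expected = -0.5 * d * math.log(2 * math.pi)
print(abs(value - expected) < 1e-12)
```

This direct form multiplies raw densities, so it will underflow for image-sized inputs. It is only meant as a reference for small hand-checkable cases.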

from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from scipy.stats import gaussian_kde
from statsmodels.nonparametric.kde import KDEUnivariate
from statsmodels.nonparametric.kernel_density import KDEMultivariate

if dataset in ["MNIST", "mnist"]:
    x_train, x_val, x_test = load_mnist(data_dir="mnist.pkl")
elif dataset in ["CIFAR100", "cifar100", "CIFAR", "cifar"]:
    # Assumes the standard layout of the extracted cifar-100-python archive
    x_train, x_val, x_test = load_cifar(train_dir="cifar-100-python/train",
                                        test_dir="cifar-100-python/test")

def kde_sklearn(sigma, D_A, D_B):
    """Kernel Density Estimation with Scikit-learn"""
    kde_skl = KernelDensity(bandwidth=sigma)  # Gaussian kernel by default
    kde_skl.fit(D_A)  # The estimator must be fit on D_A before scoring
    # score_samples() returns the log-density of each sample in D_B
    log_pdf = kde_skl.score_samples(D_B)
    return np.mean(log_pdf)

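def kde_scipy(sigma, D_A, D_B):
    """Kernel Density Estimation with SciPy"""
    # Caveats: gaussian_kde expects data shaped (n_dims, n_points) and scales the
    # full sample covariance by the bandwidth factor. It also returns densities,
    # not log-densities, so this sum is not directly comparable to kde_scratch.
    kde = gaussian_kde(D_A, bw_method=sigma / D_A.std(ddof=1))
    return np.sum(kde.evaluate(D_B))
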
def kde_statsmodels_u(sigma, D_A, D_B):
    """Univariate Kernel Density Estimation with Statsmodels"""
    kde = KDEUnivariate(D_A)  # Expects 1D data
    kde.fit(bw=sigma)  # The estimator must be fit before it can be evaluated
    return kde.evaluate(D_B)

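def kde_statsmodels_m(sigma, D_A, D_B):
    """Multivariate Kernel Density Estimation with Statsmodels"""
    n_vars = D_A.shape[1]
    # Assumed arguments: one continuous ("c") variable and one bandwidth per column
    kde = KDEMultivariate(D_A, bw=sigma * np.ones(n_vars),
                          var_type="c" * n_vars)
    return kde.pdf(D_B)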

# testing our results
print(kde_scratch(0.2, x_train[0:10], x_val[0:10]))
print(kde_sklearn(0.2, x_train[0:10], x_val[0:10]))
print(kde_scipy(0.2, x_train[0:10], x_val[0:10]))

Misconceptions about KDE and KL-divergence

I’d like to conclude this tutorial with a very important concept. The mean log-likelihood we maximized earlier is closely tied to KL divergence: averaged over held-out samples, it estimates the negative KL divergence from the data distribution to our Gaussian mixture, plus a constant (the negative entropy of the data distribution) that does not depend on the model. KL divergence is usually described as a measure of the overlap, or lack thereof, between two probability distributions. It is extremely tempting to think of KL divergence as being analogous to some sort of distance between the distributions, but you should avoid this for one important reason: KL divergence is not symmetrical.
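The asymmetry is easy to see on two small discrete distributions. This is a minimal sketch with made-up probabilities; `kl_divergence` is an illustrative helper, not something used elsewhere in this post.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]
print(round(kl_divergence(p, q), 4))  # 0.5108
print(round(kl_divergence(q, p), 4))  # 0.3681
```

Swapping the arguments changes the answer, which a true distance would never do.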

For example, what if we reverse the positions of our training and validation datasets in our custom and out-of-the-box Gaussian KDE functions?

print(kde_scratch(0.2, x_val[0:10], x_train[0:10]))
print(kde_sklearn(0.2, x_val[0:10], x_train[0:10]))
print(kde_scipy(0.2, x_val[0:10], x_train[0:10]))

Bit different from what we had before, isn’t it?


  1. Y. LeCun, C. Cortes, C. Burges, MNIST handwritten digit database, ATT Labs [Online].
  2. A. Krizhevsky, Learning multiple layers of features from tiny images, Tech. rep. (2009).
  3. T. E. Oliphant, A guide to NumPy, Vol. 1, Trelgol Publishing USA, 2006.
  4. S. Van Der Walt, S. C. Colbert, G. Varoquaux, The NumPy array: a structure for efficient numerical computation, Computing in Science & Engineering 13 (2) (2011) 22.
  5. J. D. Hunter, Matplotlib: A 2D graphics environment, Computing in Science & Engineering 9 (3) (2007) 90.
  6. L. McInnes, J. Healy, J. Melville, UMAP: Uniform manifold approximation and projection for dimension reduction, arXiv preprint arXiv:1802.03426.
  7. S. Park, E. Serpedin, K. Qaraqe, Gaussian assumption: The least favorable but the most useful [lecture notes], IEEE Signal Processing Magazine 30 (3) (2013) 183–186.

Cited as:

  title   = "Gaussian KDE from Scratch",
  author  = "McAteer, Matthew",
  journal = "",
  year    = "2019",
  url     = ""

If you notice mistakes and errors in this post, don’t hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
