Face-Generating GAN

Using a good face-classifier to teach generators how to make nonexistent people.

UPDATE 08/14/2020: The code is now on GitHub.

This notebook creates a Generative Adversarial Network (GAN). The GAN is designed to produce realistic human faces after being trained on a large dataset of human faces. Before that, its functionality is tested by training on a database of handwriting samples and producing realistic images of handwritten digits.

Getting started is relatively straightforward. All that's really needed are the Python packages os, glob, matplotlib, numpy, tqdm, warnings, and tensorflow.
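Something like the following import block covers everything used below. Note that helper is the project's own utility module from the accompanying GitHub repo, not a pip package:

import os
import warnings
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import helper  # utility module from the accompanying GitHub repo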

The Data

This GAN utilizes the MNIST and CelebA datasets. The MNIST dataset is far simpler than the CelebA dataset, and its images are black and white instead of RGB. Running the GAN on MNIST allows for testing how well the model trains before moving on to the more complex data.

data_dir = '/input'

helper.download_extract('mnist', data_dir)
helper.download_extract('celeba', data_dir)

The MNIST and CelebA images will be 28x28, with pixel values scaled to the range -0.5 to 0.5. The CelebA images will be cropped to remove the parts of each image that don't include a face, then resized down to 28x28.

The MNIST images are black and white images with a single color channel while the CelebA images have 3 color channels.
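The cropping, resizing, and scaling all happen inside the helper module. A rough sketch of that preprocessing, with a hypothetical 108x108 center-crop box standing in for the helper's actual crop, could look like this:

import numpy as np
from PIL import Image

def get_image(image_path, width, height, mode):
    """Center-crop an image to the face region, resize it, and return it as an array."""
    image = Image.open(image_path)
    # Hypothetical center crop; the real crop box lives in the helper module
    face_width = face_height = 108
    j = (image.size[0] - face_width) // 2
    i = (image.size[1] - face_height) // 2
    image = image.crop([j, i, j + face_width, i + face_height])
    image = image.resize([width, height], Image.BILINEAR)
    return np.array(image.convert(mode))

def scale(images):
    # Map 8-bit pixel values from [0, 255] into [-0.5, 0.5]
    return images / 255.0 - 0.5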

show_n_images = 25
mnist_images = helper.get_batch(glob(os.path.join(data_dir, 'mnist/*.jpg'))[:show_n_images], 28, 28, 'L')
plt.imshow(helper.images_square_grid(mnist_images, 'L'), cmap='gray')

The CelebFaces Attributes Dataset (CelebA) contains 202,599 celebrity images with annotations. For the purposes of generating new faces, the annotations are irrelevant.

show_n_images = 25

celeba_images = helper.get_batch(glob(os.path.join(data_dir, 'img_align_celeba/*.jpg'))[:show_n_images], 28, 28, 'RGB')
plt.imshow(helper.images_square_grid(celeba_images, 'RGB'))

Constructing the Neural Network

For building our GAN, we need the following parts:

  • Model Inputs
  • Discriminator
  • Generator
  • Model Loss
  • Model Optimizer
  • Training functions

Input

The function model_inputs() creates the model inputs. It returns a tuple of (tensor of real input images, tensor of z data, learning rate).

  • image_width: The input image width
  • image_height: The input image height
  • image_channels: The number of color channels in the input images
  • z_dim: The dimension of Z

def model_inputs(image_width, image_height, image_channels, z_dim):
    tensor_real_inputs = tf.placeholder(dtype = tf.float32, shape = (None, image_height, image_width, image_channels))
    tensor_z_data = tf.placeholder(dtype = tf.float32, shape = (None, z_dim))
    learning_rate = tf.placeholder(dtype = tf.float32, shape = ())

    return tensor_real_inputs, tensor_z_data, learning_rate
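
As a quick sanity check, the training code further down will call this with 28x28 RGB CelebA images and a 128-dimensional z vector:

input_real, input_z, lr = model_inputs(28, 28, 3, 128)
print(input_real.shape)  # (?, 28, 28, 3)
print(input_z.shape)     # (?, 128)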

Discriminator

The discriminator() function creates the discriminator network. It returns a tuple of the tensor output of the discriminator and the tensor logits of the discriminator.

  • images: Tensor of input image(s)
  • reuse: Boolean if the weights should be reused

def discriminator(images, reuse = False):
    with tf.variable_scope('discriminator', reuse = reuse):
        # 28x28xC -> 14x14x32
        x = tf.layers.conv2d(images, filters = 32,
                             kernel_size = 5,
                             strides = 2,
                             padding = 'same',
                             kernel_initializer = tf.contrib.layers.xavier_initializer())
        x = tf.maximum(0.2 * x, x)  # leaky ReLU
        
        # 14x14x32 -> 7x7x64 (this layer takes x, not the raw images)
        x = tf.layers.conv2d(x, filters = 64,
                             kernel_size = 5,
                             strides = 2,
                             padding = 'same',
                             kernel_initializer = tf.contrib.layers.xavier_initializer())
        x = tf.layers.batch_normalization(x, training = True)
        x = tf.maximum(0.2 * x, x)
        x = tf.nn.dropout(x, keep_prob = 0.5)
        
        # 7x7x64 -> 4x4x128
        x = tf.layers.conv2d(x, filters = 128,
                             kernel_size = 5,
                             strides = 2,
                             padding = 'same',
                             kernel_initializer = tf.contrib.layers.xavier_initializer())
        x = tf.layers.batch_normalization(x, training = True)
        x = tf.maximum(0.2 * x, x)
        x = tf.nn.dropout(x, keep_prob = 0.5)
        
        # 4x4x128 -> 2x2x256
        x = tf.layers.conv2d(x, filters = 256,
                             kernel_size = 5,
                             strides = 2,
                             padding = 'same',
                             kernel_initializer = tf.contrib.layers.xavier_initializer())
        x = tf.layers.batch_normalization(x, training = True)
        x = tf.maximum(0.2 * x, x)
        
        # Flatten the 2x2x256 feature map and map it to a single real/fake logit
        flattened = tf.reshape(x, [-1, 2 * 2 * 256])
        logits = tf.layers.dense(flattened, 1, activation = None)
        output = tf.sigmoid(logits)

    return output, logits

Generator

The generator() function creates the generator network. It returns the tensor output of the generator.

  • z: Input z tensor
  • out_channel_dim: The number of channels in the output image
  • is_train: Boolean if the generator is being used for training

def generator(z, out_channel_dim, is_train = True):
    
    with tf.variable_scope('generator', reuse = not is_train):
        # Project z and reshape it into a 4x4x512 feature map
        x = tf.layers.dense(z, units = 4 * 4 * 512)
        x = tf.reshape(x, (-1, 4, 4, 512))
        x = tf.layers.batch_normalization(x, training = is_train)
        x = tf.maximum(0.2 * x, x)  # leaky ReLU
        
        # 4x4x512 -> 7x7x128
        x = tf.layers.conv2d_transpose(x, filters = 128, kernel_size = 4, strides = 1, padding = 'valid')
        x = tf.layers.batch_normalization(x, training = is_train)
        x = tf.maximum(0.2 * x, x)
        
        # 7x7x128 -> 14x14x64
        x = tf.layers.conv2d_transpose(x, filters = 64, kernel_size = 5, strides = 2, padding = 'same')
        x = tf.layers.batch_normalization(x, training = is_train)
        x = tf.maximum(0.2 * x, x)
        
        # 14x14x64 -> 28x28x32
        x = tf.layers.conv2d_transpose(x, filters = 32, kernel_size = 5, strides = 2, padding = 'same')
        x = tf.layers.batch_normalization(x, training = is_train)
        x = tf.maximum(0.2 * x, x)
        
        # 28x28x32 -> 28x28xout_channel_dim, squashed into [-1, 1]
        logits = tf.layers.conv2d_transpose(x, filters = out_channel_dim, kernel_size = 3, strides = 1,
                                            padding = 'same')
        out = tf.tanh(logits)
        
    return out

Loss

The function model_loss() gets the loss for the discriminator and generator. It returns a tuple of the discriminator loss and the generator loss.

  • input_real: Images from the real dataset
  • input_z: Z input
  • out_channel_dim: The number of channels in the output image

def model_loss(input_real, input_z, out_channel_dim):
    
    generator_model = generator(input_z, out_channel_dim, is_train = True)
    discriminator_model_real, discriminator_logits_real = discriminator(input_real, reuse = False)
    discriminator_model_fake, discriminator_logits_fake = discriminator(generator_model, reuse = True)
    
    # Real images: smoothed target labels (0.9 instead of 1.0) keep the discriminator from becoming overconfident
    discriminator_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits = discriminator_logits_real, labels = tf.ones_like(discriminator_model_real) * (1 - 0.1)))
    
    # Fake images: target label 0
    discriminator_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits = discriminator_logits_fake, labels = tf.zeros_like(discriminator_model_fake)))
    
    # The generator wins when the discriminator labels its fakes as real (target label 1)
    generator_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits = discriminator_logits_fake, labels = tf.ones_like(discriminator_model_fake)))
    
    discriminator_loss = discriminator_loss_real + discriminator_loss_fake
    
    return discriminator_loss, generator_loss
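
Written out, with $D$ the discriminator's sigmoid output and $G$ the generator (the 0.9 comes from the one-sided label smoothing above), the two losses are:

$$\mathcal{L}_D = -\mathbb{E}_{x}\big[\,0.9 \log D(x) + 0.1 \log(1 - D(x))\,\big] - \mathbb{E}_{z}\big[\log(1 - D(G(z)))\big]$$

$$\mathcal{L}_G = -\mathbb{E}_{z}\big[\log D(G(z))\big]$$

The generator loss is the non-saturating variant from the original GAN paper: instead of minimizing $\log(1 - D(G(z)))$, the generator maximizes $\log D(G(z))$, which provides stronger gradients early in training when the discriminator easily rejects the fakes.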

Optimization

The function model_opt() gets the optimization operations. It returns a tuple of the discriminator training operation and the generator training operation.

  • d_loss: Discriminator loss Tensor
  • g_loss: Generator loss Tensor
  • learning_rate: Learning Rate Placeholder
  • beta1: The exponential decay rate for the 1st moment in the optimizer

def model_opt(d_loss, g_loss, learning_rate, beta1):
    training_vars = tf.trainable_variables()
    discriminator_vars = [var for var in training_vars if var.name.startswith('discriminator')]
    generator_vars = [var for var in training_vars if var.name.startswith('generator')]
    
    # Batch normalization keeps moving-average statistics in UPDATE_OPS; they must run alongside each optimizer step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    discriminator_updates = [opt for opt in update_ops if opt.name.startswith('discriminator')]
    generator_updates = [opt for opt in update_ops if opt.name.startswith('generator')]

    with tf.control_dependencies(discriminator_updates):
        discriminator_opt = tf.train.AdamOptimizer(
            learning_rate = learning_rate, beta1 = beta1).minimize(d_loss, var_list = discriminator_vars)

    with tf.control_dependencies(generator_updates):
        generator_opt = tf.train.AdamOptimizer(
            learning_rate = learning_rate, beta1 = beta1).minimize(g_loss, var_list = generator_vars)
            
    return discriminator_opt, generator_opt

Training

Showing Generator Output

The show_generator_output() function shows an example output for the generator.

  • sess: TensorFlow session
  • n_images: Number of images to display
  • input_z: Input Z Tensor
  • out_channel_dim: The number of channels in the output image
  • image_mode: The mode to use for images ("RGB" or "L")

def show_generator_output(sess, n_images, input_z, out_channel_dim, image_mode):
    cmap = None if image_mode == 'RGB' else 'gray'
    z_dim = input_z.get_shape().as_list()[-1]
    example_z = np.random.uniform(-1, 1, size=[n_images, z_dim])

    # Run the generator in inference mode (is_train = False); reuse = not is_train picks up the trained weights
    samples = sess.run(
        generator(input_z, out_channel_dim, False),
        feed_dict={input_z: example_z})

    images_grid = helper.images_square_grid(samples, image_mode)
    plt.imshow(images_grid, cmap=cmap)
    plt.show()

Building and Training the GANs

The train() function trains the GAN.

  • epoch_count: Number of epochs
  • batch_size: Batch size
  • z_dim: Z dimension
  • learning_rate: Learning rate
  • beta1: The exponential decay rate for the 1st moment in the optimizer
  • get_batches: Function to get batches
  • data_shape: Shape of the data
  • data_image_mode: The image mode to use for images ("RGB" or "L")

def train(epoch_count, batch_size, z_dim, learning_rate, beta1, get_batches, data_shape, data_image_mode):
    n_samples, width, height, channels = data_shape
    input_real, input_z, lr = model_inputs(width, height, channels, z_dim)
    discriminator_loss, generator_loss = model_loss(input_real, input_z, channels)
    discriminator_train_opt, generator_train_opt = model_opt(discriminator_loss, generator_loss, lr, beta1)
    
    i = 0    
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):
            for batch_images in get_batches(batch_size):           
                # Batches arrive scaled to [-0.5, 0.5]; rescale to [-1, 1] to match the generator's tanh output
                batch_images = batch_images * 2.0
                batch_z = np.random.uniform(-1, 1, size = (batch_size, z_dim))
                
                sess.run(discriminator_train_opt, feed_dict = {input_real: batch_images,
                                                               input_z: batch_z,
                                                               lr: learning_rate})
                sess.run(generator_train_opt, feed_dict = {input_real: batch_images,
                                                           input_z: batch_z,
                                                           lr: learning_rate})
                
                i += 1
                if i % 10 == 0:
                    train_loss_discriminator = discriminator_loss.eval({input_z: batch_z, input_real: batch_images})
                    train_loss_generator = generator_loss.eval({input_z: batch_z})
                    print('Epoch %d/%d Discriminator loss %.4f Generator loss %.4f' % (epoch_i + 1,
                                                                                       epoch_count,
                                                                                       train_loss_discriminator,
                                                                                       train_loss_generator))
                if i % 100 == 0:
                    show_generator_output(sess, 50, input_z, channels, data_image_mode)
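
helper.Dataset isn't shown in this post (it lives in the repo's helper module), but a minimal sketch of what its get_batches function does, assuming the images have already been preprocessed into an array scaled to [-0.5, 0.5], might look like this:

import numpy as np

def make_get_batches(images):
    """Build a get_batches function over a preprocessed image array.

    images is a hypothetical stand-in for what helper.Dataset prepares:
    a numpy array of shape (n_samples, height, width, channels) with
    values already scaled to [-0.5, 0.5].
    """
    def get_batches(batch_size):
        # Drop the final partial batch so every batch has exactly batch_size images
        n_full = images.shape[0] // batch_size * batch_size
        for start in range(0, n_full, batch_size):
            yield images[start:start + batch_size]
    return get_batches

# Example with random stand-in data:
fake_images = np.random.uniform(-0.5, 0.5, size = (256, 28, 28, 1)).astype(np.float32)
get_batches = make_get_batches(fake_images)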

Testing our GAN on MNIST

batch_size = 64
z_dim = 128
learning_rate = 0.0005
beta1 = 0.1
epochs = 2

# After 2 epochs, the GANs are usually able to generate images that look like handwritten digits.
# When running, make sure the loss of the generator is lower than the loss of the discriminator or close to 0.

mnist_dataset = helper.Dataset('mnist', glob(os.path.join(data_dir, 'mnist/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, mnist_dataset.get_batches,
          mnist_dataset.shape, mnist_dataset.image_mode)

Beginning of Epoch 1/2: Looks more like TV static than digits

Just after Beginning of Epoch 1/2: Looks like some very out-of-focus digits

End of Epoch 1/2: Better. Looking like a Parkinson’s patient scribbling on a chalkboard

Beginning of Epoch 2/2: Some recognizable digits starting to appear

End of Epoch 2/2: Still got that chalkboard look, but it’s come a long way in only two epochs

Testing our GAN on CelebA

batch_size = 32
z_dim = 128
learning_rate = 0.0002
beta1 = 0.5
epochs = 1
# ~20 minutes for typical GPU to run one epoch

celeba_dataset = helper.Dataset('celeba', glob(os.path.join(data_dir, 'img_align_celeba/*.jpg')))
with tf.Graph().as_default():
    train(epochs, batch_size, z_dim, learning_rate, beta1, celeba_dataset.get_batches,
          celeba_dataset.shape, celeba_dataset.image_mode)

As a result, we get a series of generated celebrity faces.

Beginning of Epoch 1/1: Looks nothing like a face
Slightly after Beginning of Epoch 1/1: Looking more like heads, but a long way to go

About 40 percent through Epoch 1/1: Some serious nightmare fuel here

About 60 percent through Epoch 1/1: not quite as nightmarish as before, but definitely still in the uncanny valley

About 80 percent through Epoch 1/1: Getting better, but many still getting heads sheared off by artefacts

End of Epoch 1/1: Looking like faces you would see in a PS1 game, maybe slightly better

Future Steps


Cited as:

@article{mcateer2017facegan,
    title = "Face-Generating GAN",
    author = "McAteer, Matthew",
    journal = "matthewmcateer.me",
    year = "2017",
    url = "https://matthewmcateer.me/blog/face-generating-gan/"
}

If you notice mistakes and errors in this post, don't hesitate to contact me at [contact at matthewmcateer dot me] and I will be very happy to correct them right away! Alternatively, you can follow me on Twitter and reach out to me there.

See you in the next post 😄
