Commit 5d7b2a76 authored by Peter Eastman

Created MNIST GAN example

parent 2ddb535b
+21 −0
@@ -645,6 +645,27 @@ def relu(x, alpha=0., max_value=None):
  return x


def lrelu(alpha=0.01):
  """Create a leaky rectified linear unit function.

  This function returns a new function that implements the LReLU with a
  specified alpha.  The returned value can be used as an activation function in
  network layers.

  Parameters
  ----------
  alpha: float
    the slope of the function when x<0

  Returns
  -------
  a function f(x) that returns alpha*x when x<0, and x when x>0.
  """
  def eval(x):
    return relu(x, alpha=alpha)
  return eval


def selu(x):
  """Scaled Exponential Linear unit.

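Since `lrelu(alpha)` returns an ordinary function, it can be passed anywhere an activation function is expected.  A minimal usage sketch (the surrounding `Dense` layer and the `prev_layer` name are illustrative, not part of this commit):

``` python
# lrelu(0.1) builds f(x) = relu(x, alpha=0.1): a leaky ReLU whose slope
# for x < 0 is 0.1.  Pass the returned function as a layer's activation.
leaky = lrelu(0.1)
hidden = layers.Dense(30, in_layers=prev_layer, activation_fn=leaky)
```
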
+2 −0
%% Cell type:markdown id: tags:

Conditional Generative Adversarial Network
----------------------------------------

*Note: This example implements a GAN from scratch.  The same model could be implemented much more easily with the `dc.models.GAN` class.  See the MNIST GAN notebook for an example of using that class.  It can still be useful to know how to implement a GAN from scratch for advanced situations that are beyond the scope of what the standard GAN class supports.*

A Generative Adversarial Network (GAN) is a type of generative model.  It consists of two parts called the "generator" and the "discriminator".  The generator takes random values as input and transforms them into an output that (hopefully) resembles the training data.  The discriminator takes a set of samples as input and tries to distinguish the real training samples from the ones created by the generator.  The two are trained together: the discriminator tries to get better and better at telling real data from fake, while the generator tries to get better and better at fooling the discriminator.
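
Formally, the two networks play the standard minimax game $\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$, where $D(x)$ is the probability the discriminator assigns to a sample being real and $G(z)$ is the generator's output for random input $z$.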

A Conditional GAN (CGAN) gives the generator and discriminator additional inputs, and their outputs are conditioned on those inputs.  For example, the extra input might be a class label, and the GAN tries to learn how the data distribution varies between classes.
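
In that case the objective above becomes $\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z \mid y) \mid y))]$, with the conditioning input $y$ (here, the class label) fed to both networks.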

For this example, we will create a data distribution consisting of a set of ellipses in 2D, each with a random position, shape, and orientation.  Each class corresponds to a different ellipse.  Let's randomly generate the ellipses.

%% Cell type:code id: tags:

``` python
import deepchem as dc
import numpy as np
import tensorflow as tf

n_classes = 4
class_centers = np.random.uniform(-4, 4, (n_classes, 2))
class_transforms = []
for i in range(n_classes):
    # Give each class's ellipse random axis scales and a random rotation.
    xscale = np.random.uniform(0.5, 2)
    yscale = np.random.uniform(0.5, 2)
    angle = np.random.uniform(0, np.pi)
    # Scale along the two axes, then rotate by angle.
    m = [[xscale*np.cos(angle), -yscale*np.sin(angle)],
         [xscale*np.sin(angle), yscale*np.cos(angle)]]
    class_transforms.append(m)
class_transforms = np.array(class_transforms)
```

%% Cell type:markdown id: tags:

This function generates random data from the distribution.  For each point it chooses a random class, then a random position inside that class's ellipse.

%% Cell type:code id: tags:

``` python
def generate_data(n_points):
    # Choose a random class for each point.
    classes = np.random.randint(n_classes, size=n_points)
    # Sample a random point in the unit disk...
    r = np.random.random(n_points)
    angle = 2*np.pi*np.random.random(n_points)
    points = (r*np.array([np.cos(angle), np.sin(angle)])).T
    # ...then map it into the chosen class's ellipse: the einsum applies
    # each class's transform matrix, and the centers shift the points.
    points = np.einsum('ijk,ik->ij', class_transforms[classes], points)
    points += class_centers[classes]
    return classes, points
```

%% Cell type:markdown id: tags:

Let's plot a bunch of random points drawn from this distribution to see what it looks like.  Points are colored based on their class label.

%% Cell type:code id: tags:

``` python
%matplotlib inline
import matplotlib.pyplot as plot
classes, points = generate_data(1000)
plot.scatter(x=points[:,0], y=points[:,1], c=classes)
```

%% Output

    <matplotlib.collections.PathCollection at 0x11fc3d860>


%% Cell type:markdown id: tags:

Now let's create the model for our CGAN.

%% Cell type:code id: tags:

``` python
import deepchem.models.tensorgraph.layers as layers
model = dc.models.TensorGraph(learning_rate=1e-4, use_queue=False)

# Inputs to the model

random_in = layers.Feature(shape=(None, 10)) # Random input to the generator
generator_classes = layers.Feature(shape=(None, n_classes)) # The classes of the generated samples
real_data_points = layers.Feature(shape=(None, 2)) # The training samples
real_data_classes = layers.Feature(shape=(None, n_classes)) # The classes of the training samples
is_real = layers.Weights(shape=(None, 1)) # Flags to distinguish real from generated samples

# The generator

gen_in = layers.Concat([random_in, generator_classes])
gen_dense1 = layers.Dense(30, in_layers=gen_in, activation_fn=tf.nn.relu)
gen_dense2 = layers.Dense(30, in_layers=gen_dense1, activation_fn=tf.nn.relu)
generator_points = layers.Dense(2, in_layers=gen_dense2)
model.add_output(generator_points)

# The discriminator

all_points = layers.Concat([generator_points, real_data_points], axis=0)
all_classes = layers.Concat([generator_classes, real_data_classes], axis=0)
discrim_in = layers.Concat([all_points, all_classes])
discrim_dense1 = layers.Dense(30, in_layers=discrim_in, activation_fn=tf.nn.relu)
discrim_dense2 = layers.Dense(30, in_layers=discrim_dense1, activation_fn=tf.nn.relu)
discrim_prob = layers.Dense(1, in_layers=discrim_dense2, activation_fn=tf.sigmoid)
```

%% Cell type:markdown id: tags:

We'll use different loss functions for training the generator and discriminator.  The discriminator outputs its predictions in the form of a probability that each sample is a real sample (that is, that it came from the training set rather than the generator).  Its loss consists of two terms.  The first term tries to maximize the output probability for real data, and the second term tries to minimize the output probability for generated samples.  The loss function for the generator is just a single term: it tries to maximize the discriminator's output probability for generated samples.

For each one, we create a "submodel" specifying a set of layers that will be optimized based on a loss function.
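
Schematically (averaging over the batch and omitting the small $10^{-10}$ offsets that guard against $\log 0$), the losses implemented below are $L_D = -\mathbb{E}[\log D(x_{real})] - \mathbb{E}[\log(1 - D(x_{gen}))]$ and $L_G = -\mathbb{E}[\log D(x_{gen})]$.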

%% Cell type:code id: tags:

``` python
# Discriminator.  is_real is 1 for training samples and 0 for generated
# ones, so each loss term acts on the right subset of the batch.  The
# 1e-10 offsets guard against taking the log of 0.

discrim_real_data_loss = -layers.Log(discrim_prob+1e-10) * is_real
discrim_gen_data_loss = -layers.Log(1-discrim_prob+1e-10) * (1-is_real)
discrim_loss = layers.ReduceMean(discrim_real_data_loss + discrim_gen_data_loss)
discrim_submodel = model.create_submodel(layers=[discrim_dense1, discrim_dense2, discrim_prob], loss=discrim_loss)

# Generator.  Its submodel contains only the generator's layers, so
# optimizing this loss leaves the discriminator unchanged.

gen_loss = -layers.ReduceMean(layers.Log(discrim_prob+1e-10) * (1-is_real))
gen_submodel = model.create_submodel(layers=[gen_dense1, gen_dense2, generator_points], loss=gen_loss)
```

%% Cell type:markdown id: tags:

Now to fit the model.  Here are some important points to notice about the code.

- We use `fit_generator()` to train only a single batch at a time, and we alternate between the discriminator and the generator.  That way, both parts of the model improve together.
- We only train the generator half as often as the discriminator.  On this particular model, that gives much better results.  You will often need to adjust `(# of discriminator steps)/(# of generator steps)` to get good results on a given problem.
- We disable checkpointing by specifying `checkpoint_interval=0`.  Since each call to `fit_generator()` includes only a single batch, it would otherwise save a checkpoint to disk after every batch, which would be very slow.  If this were a real project and not just an example, we would want to occasionally call `model.save_checkpoint()` to write checkpoints at a reasonable interval, as sketched below.
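
%% Cell type:markdown id: tags:

A minimal sketch of that checkpointing pattern (not part of the original example; the 5000-step interval is an arbitrary illustrative choice):

%% Cell type:code id: tags:

``` python
# Sketch only: train one batch at a time as below, but write a checkpoint
# every 5000 steps instead of after every batch.
for step in range(20000):
    pass  # one discriminator/generator update, as in the loop that follows
    if step % 5000 == 4999:
        model.save_checkpoint()
```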

%% Cell type:code id: tags:

``` python
batch_size = model.batch_size
discrim_error = []
gen_error = []
for step in range(20000):
    classes, points = generate_data(batch_size)
    class_flags = dc.metrics.to_one_hot(classes, n_classes)
    # Generated samples come first in the discriminator's input (matching
    # the Concat order above), so is_real is zeros followed by ones.
    feed_dict = {random_in: np.random.random((batch_size, 10)),
                 generator_classes: class_flags,
                 real_data_points: points,
                 real_data_classes: class_flags,
                 is_real: np.concatenate([np.zeros((batch_size,1)), np.ones((batch_size,1))])}
    discrim_error.append(model.fit_generator([feed_dict],
                                             submodel=discrim_submodel,
                                             checkpoint_interval=0))
    if step%2 == 0:
        gen_error.append(model.fit_generator([feed_dict],
                                             submodel=gen_submodel,
                                             checkpoint_interval=0))
    if step%1000 == 999:
        print(step, np.mean(discrim_error), np.mean(gen_error))
        discrim_error = []
        gen_error = []
```

%% Output

    999 1.55168287337 0.0408441077992
    1999 1.09775845635 0.10187220595
    2999 0.645102792382 0.287973743975
    3999 0.680221937269 0.421649519399
    4999 1.01530539939 0.21572620222
    5999 0.355141538292 0.51490073061
    6999 0.342691998512 0.522037278384
    7999 0.668916338205 0.268597296476
    8999 0.716327803612 0.262840693563
    9999 0.713392020047 0.29249474436
    10999 0.701064119875 0.309196961999
    11999 0.69168697983 0.326060062766
    12999 0.688140272975 0.335194783509
    13999 0.687209348738 0.336962600589
    14999 0.685669041574 0.343026666999
    15999 0.686265172005 0.343531137526
    16999 0.686260136843 0.346471596062
    17999 0.687051311135 0.347908975482
    18999 0.689062111676 0.344198456705
    19999 0.691419941247 0.344754773915

%% Cell type:markdown id: tags:

Have the trained model generate some data, and see how well it matches the training distribution we plotted before.

%% Cell type:code id: tags:

``` python
classes, points = generate_data(1000)
feed_dict = {random_in: np.random.random((1000, 10)),
             generator_classes: dc.metrics.to_one_hot(classes, n_classes)}
gen_points = model.predict_on_generator([feed_dict])
plot.scatter(x=gen_points[:,0], y=gen_points[:,1], c=classes)
```

%% Output

    <matplotlib.collections.PathCollection at 0x120561a58>

+213 −0

File added.
