Commit 4c14b496 authored by Bharath Ramsundar's avatar Bharath Ramsundar
Browse files

Continuing refactor of tensorflow models

parent c3833ed1
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -79,7 +79,7 @@ class MultiTaskDNN(Graph):

    loss_dict = {}
    for task in range(self.n_tasks):
      taskname = "task%d" % ind
      taskname = "task%d" % task 
      if self.task_type == "classification":
        loss_dict[taskname] = "binary_crossentropy"
      elif self.task_type == "regression":
+88 −77
Original line number Diff line number Diff line
@@ -19,6 +19,13 @@ from deepchem.models.tensorflow_models import utils as tf_utils
from deepchem.utils.save import log

class TensorflowGraph(object):
  """Lightweight container pairing a TensorFlow graph with its session.

  Attributes:
    graph: the tf.Graph holding the model's ops.
    session: session used to run ops in `graph`; may be None until the
      owner lazily creates one.
    name_scopes: dict caching TensorFlow name scopes, used to prevent
      '_1'-suffixed duplicate scope names.
  """

  def __init__(self, graph, session, name_scopes):
    # Pure data holder: store the three handles, no validation.
    self.name_scopes = name_scopes
    self.session = session
    self.graph = graph
class TensorflowGraphModel(object):
  """Thin wrapper holding a tensorflow graph and a few vars.

  Notes:
@@ -105,47 +112,53 @@ class TensorflowGraph(object):
    self.train = train
    self.verbosity = verbosity


    self.graph = tf.Graph() 
    self.logdir = logdir

    # Lazily created by _get_shared_session().
    self._shared_session = None

    # Guard variable to make sure we don't Restore() this model
    # from a disk checkpoint more than once.
    self._restored_model = False

    # Cache of TensorFlow scopes, to prevent '_1' appended scope names
    # when subclass-overridden methods use the same scopes.
    self._name_scopes = {}

    # Path to save checkpoint files, which matches the
    # replicated supervisor's default path.
    self._save_path = os.path.join(logdir, 'model.ckpt')

    with self.graph.as_default():
    self.train_graph = self.construct_graph(train=True)
    self.eval_graph = self.construct_graph(train=False)


  def construct_graph(self, train):
    """Returns a TensorflowGraph object."""
    graph = tf.Graph() 

    # Lazily created by _get_shared_session().
    shared_session = None

    # Cache of TensorFlow scopes, to prevent '_1' appended scope names
    # when subclass-overridden methods use the same scopes.
    name_scopes = {}

    if train:
      with graph.as_default():
        model_ops.set_training(train)
      self.placeholder_root = 'placeholders'
      with tf.name_scope(self.placeholder_root) as scope:
        self.placeholder_scope = scope

    self.setup()
    # Setup graph
    with graph.as_default():
      with tf.name_scope('core_model'):
        self.build(graph, name_scopes)
      self.add_label_placeholders(graph, name_scopes)
      self.add_weight_placeholders(graph, name_scopes)

    if train:
      self.add_training_cost()
      self.merge_updates()
      self.add_training_cost(graph, name_scopes)
    else:
      self.add_output_ops()  # add softmax heads
      self.add_output_ops(graph)  # add softmax heads
    return TensorflowGraph(graph, shared_session, name_scopes)

  def setup(self):
    """Add ops common to training/eval to the graph."""
    with self.graph.as_default():
      with tf.name_scope('core_model'):
        self.build()
      self.add_label_placeholders()
      self.add_weight_placeholders()
  def _get_placeholder_scope(self, graph, name_scopes):
    """Gets placeholder scope."""
    placeholder_root = "placeholders"
    with graph.as_default():
      with tf.name_scope(placeholder_root) as scope:
        return scope

  def _shared_name_scope(self, name):
  def _shared_name_scope(self, name, graph, name_scopes):
    """Returns a singleton TensorFlow scope with the given name.

    Used to prevent '_1'-appended scopes when sharing scopes with child classes.
@@ -155,23 +168,25 @@ class TensorflowGraph(object):
    Returns:
      tf.name_scope with the provided name.
    """
    if name not in self._name_scopes:
      with self.graph.as_default():
    with graph.as_default():
      if name not in name_scopes:
        with tf.name_scope(name) as scope:
          self._name_scopes[name] = scope
    return tf.name_scope(self._name_scopes[name])
          name_scopes[name] = scope
      placeholder_scope = tf.name_scope(name_scopes[name])
      return placeholder_scope

  def add_training_cost(self):
    with self.graph.as_default():
  def add_training_cost(self, graph, name_scopes):
    with graph.as_default():
      self.require_attributes(['output', 'labels', 'weights'])
      epsilon = 1e-3  # small float to avoid dividing by zero
      weighted_costs = []  # weighted costs for each example
      gradient_costs = []  # costs used for gradient calculation

      with self._shared_name_scope('costs'):
      with self._shared_name_scope('costs', graph, name_scopes):
        for task in xrange(self.n_tasks):
          task_str = str(task).zfill(len(str(self.n_tasks)))
          with self._shared_name_scope('cost_{}'.format(task_str)):
          with self._shared_name_scope(
              'cost_{}'.format(task_str), graph, name_scopes):
            with tf.name_scope('weighted'):
              weighted_cost = self.cost(self.output[task], self.labels[task],
                                        self.weights[task])
@@ -187,7 +202,7 @@ class TensorflowGraph(object):
              gradient_costs.append(gradient_cost)

        # aggregated costs
        with self._shared_name_scope('aggregated'):
        with self._shared_name_scope('aggregated', graph, name_scopes):
          with tf.name_scope('gradient'):
            loss = tf.add_n(gradient_costs)

@@ -201,15 +216,6 @@ class TensorflowGraph(object):

      return weighted_costs

  def merge_updates(self):
    """Group all ops in the graph's 'updates' collection into a single op.

    Sets `self.updates` to one op named 'updates' that runs every op
    registered in the graph collection 'updates' (presumably populated by
    layers with internal state updates — TODO confirm what registers them).
    When the collection is empty, a no-op is installed instead so callers
    can always depend on `self.updates` existing and being runnable.
    """
    with self.graph.as_default():
      updates = tf.get_default_graph().get_collection('updates')
      if updates:
        # tf.group executes all collected update ops together.
        self.updates = tf.group(*updates, name='updates')
      else:
        # Empty collection: placeholder op keeps the attribute usable
        # unconditionally (fit() requires an 'updates' attribute).
        self.updates = tf.no_op(name='updates')

  def fit(self, dataset, nb_epoch=10, pad_batches=False, shuffle=False,
          max_checkpoints_to_keep=5, log_every_N_batches=50):
    """Fit the model.
@@ -232,7 +238,7 @@ class TensorflowGraph(object):
    with self.graph.as_default():
      self.require_attributes(['loss', 'updates'])
      train_op = self.get_training_op()
      with self._get_shared_session() as sess:
      with self._get_shared_session(train=True) as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
@@ -304,9 +310,9 @@ class TensorflowGraph(object):
      n_tasks = self.n_tasks
      output = []
      start = time.time()
      with self._get_shared_session().as_default():
      with self._get_shared_session(train=False).as_default():
        feed_dict = self.construct_feed_dict(X)
        data = self._get_shared_session().run(
        data = self._get_shared_session(train=False).run(
            self.output, feed_dict=feed_dict)
        batch_output = np.asarray(data[:n_tasks], dtype=float)
        # reshape to batch_size x n_tasks x ...
@@ -325,16 +331,16 @@ class TensorflowGraph(object):

    return np.copy(outputs)

  def add_output_ops(self):
  def add_output_ops(self, graph):
    """Replace logits with softmax outputs."""
    with self.graph.as_default():
    with graph.as_default():
      softmax = []
      with tf.name_scope('inference'):
        for i, logits in enumerate(self.output):
          softmax.append(tf.nn.softmax(logits, name='softmax_%d' % i))
      self.output = softmax

  def build(self):
  def build(self, graph):
    """Define the core graph.

    NOTE(user): Operations defined here should be in their own name scope to
@@ -353,7 +359,7 @@ class TensorflowGraph(object):
    raise NotImplementedError('Must be overridden by concrete subclass')


  def add_label_placeholders(self):
  def add_label_placeholders(self, graph, name_scopes):
    """Add Placeholders for labels for each task.

    This method creates the following Placeholders for each task:
@@ -366,7 +372,7 @@ class TensorflowGraph(object):
    """
    raise NotImplementedError('Must be overridden by concrete subclass')

  def add_weight_placeholders(self):
  def add_weight_placeholders(self, graph, name_scopes):
    """Add Placeholders for example weights for each task.

    This method creates the following Placeholders for each task:
@@ -376,8 +382,9 @@ class TensorflowGraph(object):
    feeding and fetching the same tensor.
    """
    weights = []
    placeholder_scope = self._get_placeholder_scope(graph, name_scopes)
    for task in xrange(self.n_tasks):
      with tf.name_scope(self.placeholder_scope):
      with tf.name_scope(placeholder_scope):
        weights.append(tf.identity(
            tf.placeholder(tf.float32, shape=[None],
                           name='weights_%d' % task)))
@@ -409,13 +416,19 @@ class TensorflowGraph(object):
    opt = model_ops.Optimizer(self.optimizer, self.learning_rate, self.momentum)
    return opt.minimize(self.loss, name='train')

  def _get_shared_session(self):
    if not self._shared_session:
  def _get_shared_session(self, train):
    # allow_soft_placement=True allows ops without a GPU implementation
    # to run on the CPU instead.
    if train:
      if not self.train_graph.session:
        config = tf.ConfigProto(allow_soft_placement=True)
      self._shared_session = tf.Session(config=config)
    return self._shared_session
        self.train_graph.session = tf.Session(config=config)
      return self.train_graph.session
    else:
      if not self.eval_graph.session:
        config = tf.ConfigProto(allow_soft_placement=True)
        self.eval_graph.session = tf.Session(config=config)
      return self.eval_graph.session

  def _get_feed_dict(self, named_values):
    feed_dict = {}
@@ -435,8 +448,9 @@ class TensorflowGraph(object):
      assert not model_ops.is_training()
      last_checkpoint = self._find_last_checkpoint()

      # TODO(rbharath): Is setting train=False right here?
      saver = tf.train.Saver()
      saver.restore(self._get_shared_session(),
      saver.restore(self._get_shared_session(train=False),
                    last_checkpoint)
      self._restored_model = True

@@ -470,7 +484,7 @@ class TensorflowGraph(object):
        raise AssertionError(
            'self.%s must be defined by a concrete subclass' % attr)
  
class TensorflowClassifier(TensorflowGraph):
class TensorflowClassifier(TensorflowGraphModel):
  """Classification model.

  Subclasses must set the following attributes:
@@ -502,7 +516,7 @@ class TensorflowClassifier(TensorflowGraph):
    return tf.mul(tf.nn.softmax_cross_entropy_with_logits(logits, labels),
                  weights)

  def add_label_placeholders(self):
  def add_label_placeholders(self, graph, name_scopes):
    """Add Placeholders for labels for each task.

    This method creates the following Placeholders for each task:
@@ -511,19 +525,20 @@ class TensorflowClassifier(TensorflowGraph):
    Placeholders are wrapped in identity ops to avoid the error caused by
    feeding and fetching the same tensor.
    """
    with self.graph.as_default():
    placeholder_scope = self._get_placeholder_scope(graph, name_scopes)
    with graph.as_default():
      batch_size = self.batch_size 
      n_classes = self.n_classes
      labels = []
      for task in xrange(self.n_tasks):
        with tf.name_scope(self.placeholder_scope):
        with tf.name_scope(placeholder_scope):
          labels.append(tf.identity(
              tf.placeholder(tf.float32, shape=[None, n_classes],
                             name='labels_%d' % task)))
      self.labels = labels


class TensorflowRegressor(TensorflowGraph):
class TensorflowRegressor(TensorflowGraphModel):
  """Regression model.

  Subclasses must set the following attributes:
@@ -539,7 +554,7 @@ class TensorflowRegressor(TensorflowGraph):
  def get_task_type(self):
    return "regressor"

  def add_output_ops(self):
  def add_output_ops(self, graph):
    """No-op for regression models since no softmax."""
    pass

@@ -557,7 +572,7 @@ class TensorflowRegressor(TensorflowGraph):
    """
    return tf.mul(0.5 * tf.square(output - labels), weights)

  def add_label_placeholders(self):
  def add_label_placeholders(self, graph, placeholder_scope):
    """Add Placeholders for labels for each task.

    This method creates the following Placeholders for each task:
@@ -566,11 +581,11 @@ class TensorflowRegressor(TensorflowGraph):
    Placeholders are wrapped in identity ops to avoid the error caused by
    feeding and fetching the same tensor.
    """
    with self.graph.as_default():
    with graph.as_default():
      batch_size = self.batch_size
      labels = []
      for task in xrange(self.n_tasks):
        with tf.name_scope(self.placeholder_scope):
        with tf.name_scope(placeholder_scope):
          labels.append(tf.identity(
              tf.placeholder(tf.float32, shape=[None],
                             name='labels_%d' % task)))
@@ -581,11 +596,7 @@ class TensorflowModel(Model):
  Abstract base class shared across all Tensorflow models.
  """

  def __init__(self, tf_class, logdir, verbosity=None):
    '''
    Args:
      tf_class: Class that inherits from TensorflowGraph
    ''' 
  def __init__(self, model, logdir, verbosity=None):
    assert verbosity in [None, "low", "high"]
    self.verbosity = verbosity
    if tf_class is None:
+7 −6
Original line number Diff line number Diff line
@@ -37,16 +37,17 @@ def softmax(x):
class TensorflowMultiTaskClassifier(TensorflowClassifier):
  """Implements an icml model as configured in a model_config.proto."""

  def build(self):
  def build(self, graph, name_scopes):
    """Constructs the graph architecture as specified in its config.

    This method creates the following Placeholders:
      mol_features: Molecule descriptor (e.g. fingerprint) tensor with shape
        batch_size x n_features.
    """
    placeholder_scope = self._get_placeholder_scope(graph, name_scopes)
    n_features = self.n_features
    with self.graph.as_default():
      with tf.name_scope(self.placeholder_scope):
    with graph.as_default():
      with tf.name_scope(placeholder_scope):
        self.mol_features = tf.placeholder(
            tf.float32,
            shape=[None, n_features],
@@ -132,7 +133,7 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
    """
    if not self._restored_model:
      self.restore()
    with self.graph.as_default():
    with self.eval_graph.graph.as_default():
      assert not model_ops.is_training()
      self.require_attributes(['output'])

@@ -163,7 +164,7 @@ class TensorflowMultiTaskClassifier(TensorflowClassifier):
class TensorflowMultiTaskRegressor(TensorflowRegressor):
  """Implements an icml model as configured in a model_config.proto."""

  def build(self):
  def build(self, graph):
    """Constructs the graph architecture as specified in its config.

    This method creates the following Placeholders:
@@ -171,7 +172,7 @@ class TensorflowMultiTaskRegressor(TensorflowRegressor):
        batch_size x n_features.
    """
    n_features = self.n_inputs
    with self.graph.as_default():
    with graph.as_default():
      with tf.name_scope(self.placeholder_scope):
        self.mol_features = tf.placeholder(
            tf.float32,
+0 −507

File deleted.

Preview size limit exceeded, changes collapsed.

+20 −123
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ from deepchem.models.tensorflow_models.fcnet import TensorflowMultiTaskClassifie
from deepchem.splits import ScaffoldSplitter
from deepchem.splits import SpecifiedSplitter
from deepchem.models.keras_models.fcnet import MultiTaskDNN
from deepchem.models.keras_models import KerasModel 
import tensorflow as tf
from keras import backend as K

@@ -45,7 +46,6 @@ class TestModelAPI(TestAPI):
    """Test of singletask RF ECFP regression API."""
    splittype = "scaffold"
    featurizer = CircularFingerprint(size=1024)
    model_params = {}
    tasks = ["log-solubility"]
    task_type = "regression"
    task_types = {task: task_type for task in tasks}
@@ -64,14 +64,12 @@ class TestModelAPI(TestAPI):
    output_transformers = [
        NormalizationTransformer(transform_y=True, dataset=train_dataset)]
    transformers = input_transformers + output_transformers
    model_params["data_shape"] = train_dataset.get_data_shape()
    regression_metrics = [Metric(metrics.r2_score),
                          Metric(metrics.mean_squared_error),
                          Metric(metrics.mean_absolute_error)]

    model = SklearnModel(tasks, task_types, model_params, self.model_dir,
                         mode="regression",
                         model_instance=RandomForestRegressor())
    sklearn_model = RandomForestRegressor()
    model = SklearnModel(sklearn_model, self.model_dir)

    # Fit trained model
    model.fit(train_dataset)
@@ -89,7 +87,6 @@ class TestModelAPI(TestAPI):
    """Test of singletask RF USF regression API."""
    splittype = "specified"
    featurizer = UserDefinedFeaturizer(["user-specified1", "user-specified2"])
    model_params = {}
    tasks = ["log-solubility"]
    task_type = "regression"
    task_types = {task: task_type for task in tasks}
@@ -112,14 +109,12 @@ class TestModelAPI(TestAPI):
      for transformer in transformers:
        transformer.transform(dataset)

    model_params["data_shape"] = train_dataset.get_data_shape()
    regression_metrics = [Metric(metrics.r2_score),
                          Metric(metrics.mean_squared_error),
                          Metric(metrics.mean_absolute_error)]

    model = SklearnModel(tasks, task_types, model_params, self.model_dir,
                         mode="regression",
                         model_instance=RandomForestRegressor())
    sklearn_model = RandomForestRegressor()
    model = SklearnModel(sklearn_model, self.model_dir)

    # Fit trained model
    model.fit(train_dataset)
@@ -137,7 +132,6 @@ class TestModelAPI(TestAPI):
    """Test of singletask RF ECFP regression API: sharded edition."""
    splittype = "scaffold"
    featurizer = CircularFingerprint(size=1024)
    model_params = {}
    tasks = ["label"]
    task_type = "regression"
    task_types = {task: task_type for task in tasks}
@@ -162,14 +156,12 @@ class TestModelAPI(TestAPI):
        transformer.transform(dataset)
    # We set shard size above to force the creation of multiple shards of the data.
    # pdbbind_core has ~200 examples.
    model_params["data_shape"] = train_dataset.get_data_shape()
    regression_metrics = [Metric(metrics.r2_score),
                          Metric(metrics.mean_squared_error),
                          Metric(metrics.mean_absolute_error)]

    model = SklearnModel(tasks, task_types, model_params, self.model_dir,
                         mode="regression",
                         model_instance=RandomForestRegressor())
    sklearn_model = RandomForestRegressor()
    model = SklearnModel(sklearn_model, self.model_dir)

    # Fit trained model
    model.fit(train_dataset)
@@ -190,7 +182,6 @@ class TestModelAPI(TestAPI):
    tasks = ["log-solubility"]
    task_type = "regression"
    task_types = {task: task_type for task in tasks}
    model_params = {}
    input_file = os.path.join(self.current_dir, "example.csv")
    loader = DataLoader(tasks=tasks,
                        smiles_field=self.smiles_field,
@@ -213,15 +204,12 @@ class TestModelAPI(TestAPI):
      for transformer in transformers:
        transformer.transform(dataset)

    model_params["data_shape"] = train_dataset.get_data_shape()
    regression_metrics = [Metric(metrics.r2_score),
                          Metric(metrics.mean_squared_error),
                          Metric(metrics.mean_absolute_error)]

    model = SklearnModel(tasks, task_types, model_params, self.model_dir,
                         mode="regression",
                         model_instance=RandomForestRegressor())
  
    sklearn_model = RandomForestRegressor()
    model = SklearnModel(sklearn_model, self.model_dir)

    # Fit trained model
    model.fit(train_dataset)
@@ -235,64 +223,6 @@ class TestModelAPI(TestAPI):
    evaluator = Evaluator(model, test_dataset, transformers, verbosity=True)
    _ = evaluator.compute_model_performance(regression_metrics)


  #### TODO(rbharath): This test is being disabled since deepchem no longer
  #### accepts this format of input. Decide whether this test should be deleted
  #### altogether or replaced.
  #def test_singletask_keras_mlp_USF_regression_API(self):
  #  """Test of singletask MLP User Specified Features regression API."""
  #  from deepchem.models.keras_models.fcnet import SingleTaskDNN
  #  featurizer = UserDefinedFeaturizer(["evals"])
  #  tasks = ["u0"]
  #  task_type = "regression"
  #  task_types = {task: task_type for task in tasks}
  #  model_params = {"nb_hidden": 10, "activation": "relu",
  #                  "dropout": .5, "learning_rate": .01,
  #                  "momentum": .9, "nesterov": False,
  #                  "decay": 1e-4, "batch_size": 5,
  #                  "nb_epoch": 2, "init": "glorot_uniform",
  #                  "nb_layers": 1, "batchnorm": False}

  #  input_file = os.path.join(self.current_dir, "gbd3k.pkl.gz")
  #  loader = DataLoader(tasks=tasks,
  #                      smiles_field=self.smiles_field,
  #                      featurizer=featurizer,
  #                      verbosity="low")
  #  dataset = loader.featurize(input_file, self.data_dir)

  #  splitter = ScaffoldSplitter()
  #  train_dataset, test_dataset = splitter.train_test_split(
  #      dataset, self.train_dir, self.test_dir)

  #  input_transformers = [
  #    NormalizationTransformer(transform_X=True, dataset=train_dataset),
  #    ClippingTransformer(transform_X=True, dataset=train_dataset)]
  #  output_transformers = [
  #    NormalizationTransformer(transform_y=True, dataset=train_dataset)]
  #  transformers = input_transformers + output_transformers

  #  for dataset in [train_dataset, test_dataset]:
  #    for transformer in transformers:
  #      transformer.transform(dataset)

  #  model_params["data_shape"] = train_dataset.get_data_shape()
  #  regression_metrics = [Metric(metrics.r2_score),
  #                        Metric(metrics.mean_squared_error),
  #                        Metric(metrics.mean_absolute_error)]

  #  # Fit trained model
  #  model.fit(train_dataset)
  #  model.save()

  #  # Eval model on train
  #  evaluator = Evaluator(model, train_dataset, transformers, verbosity=True)
  #  _ = evaluator.compute_model_performance(regression_metrics)

  #  # Eval model on test
  #  evaluator = Evaluator(model, test_dataset, transformers, verbosity=True)
  #  _ = evaluator.compute_model_performance(regression_metrics)


  def test_multitask_keras_mlp_ECFP_classification_API(self):
    """Straightforward test of Keras multitask deepchem classification API."""
    g = tf.Graph()
@@ -300,24 +230,13 @@ class TestModelAPI(TestAPI):
    K.set_session(sess)
    with g.as_default():
      task_type = "classification"
      # TODO(rbharath): There should be some automatic check to ensure that all
      # required model_params are specified.
      # TODO(rbharath): Turning off dropout to make tests behave.
      model_params = {"nb_hidden": 10, "activation": "relu",
                      "dropout": .0, "learning_rate": .01,
                      "momentum": .9, "nesterov": False,
                      "decay": 1e-4, "batch_size": 5,
                      "nb_epoch": 2, "init": "glorot_uniform",
                      "nb_layers": 1, "batchnorm": False}

      input_file = os.path.join(self.current_dir, "multitask_example.csv")
      tasks = ["task0", "task1", "task2", "task3", "task4", "task5", "task6",
               "task7", "task8", "task9", "task10", "task11", "task12",
               "task13", "task14", "task15", "task16"]
      task_types = {task: task_type for task in tasks}

      featurizer = CircularFingerprint(size=1024)

      n_features = 1024
      featurizer = CircularFingerprint(size=n_features)
      loader = DataLoader(tasks=tasks,
                          smiles_field=self.smiles_field,
                          featurizer=featurizer,
@@ -328,13 +247,14 @@ class TestModelAPI(TestAPI):
          dataset, self.train_dir, self.test_dir)

      transformers = []
      model_params["data_shape"] = train_dataset.get_data_shape()
      classification_metrics = [Metric(metrics.roc_auc_score),
                                Metric(metrics.matthews_corrcoef),
                                Metric(metrics.recall_score),
                                Metric(metrics.accuracy_score)]
      
      model = MultiTaskDNN(tasks, task_types, model_params, self.model_dir)
      keras_model = MultiTaskDNN(len(tasks), n_features, "classification",
                                 dropout=0.)
      model = KerasModel(keras_model, self.model_dir)

      # Fit trained model
      model.fit(train_dataset)
@@ -350,16 +270,10 @@ class TestModelAPI(TestAPI):

  def test_singletask_tf_mlp_ECFP_classification_API(self):
    """Straightforward test of Tensorflow singletask deepchem classification API."""
    splittype = "scaffold"
    output_transformers = []
    input_transformers = []
    task_type = "classification"

    featurizer = CircularFingerprint(size=1024)
    n_features = 1024
    featurizer = CircularFingerprint(size=n_features)

    tasks = ["outcome"]
    task_type = "classification"
    task_types = {task: task_type for task in tasks}
    input_file = os.path.join(self.current_dir, "example_classification.csv")

    loader = DataLoader(tasks=tasks,
@@ -372,38 +286,21 @@ class TestModelAPI(TestAPI):
    train_dataset, test_dataset = splitter.train_test_split(
        dataset, self.train_dir, self.test_dir)
    
    input_transformers = []
    output_transformers = [
    transformers = [
        NormalizationTransformer(transform_y=True, dataset=train_dataset)]
    transformers = input_transformers + output_transformers

    for dataset in [train_dataset, test_dataset]:
      for transformer in transformers:
        transformer.transform(dataset)

    model_params = {
      "batch_size": 2,
      "num_classification_tasks": 1,
      "num_features": 1024,
      "layer_sizes": [1024],
      "weight_init_stddevs": [1.],
      "bias_init_consts": [0.],
      "dropouts": [.5],
      "num_classes": 2,
      "nb_epoch": 1,
      "penalty": 0.0,
      "optimizer": "adam",
      "learning_rate": .001,
      "data_shape": train_dataset.get_data_shape()
    }
    classification_metrics = [Metric(metrics.roc_auc_score),
                              Metric(metrics.matthews_corrcoef),
                              Metric(metrics.recall_score),
                              Metric(metrics.accuracy_score)]

    model = TensorflowModel(
        tasks, task_types, model_params, self.model_dir,
        tf_class=TensorflowMultiTaskClassifier)
    tensorflow_model = TensorflowMultiTaskClassifier(
        len(tasks), n_features, self.model_dir)
    model = TensorflowModel(tensorflow_model, self.model_dir)

    # Fit trained model
    model.fit(train_dataset)