Commit c96af58e authored by Joseph Gomes's avatar Joseph Gomes
Browse files

Add self.pad_batches parameter to Tensorflow models; tests passing

parent 4bf83b1d
Loading
Loading
Loading
Loading
+3 −10
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ import sklearn

from deepchem.data import Dataset, pad_features
from deepchem.trans import undo_transforms
from deepchem.trans import undo_grad_transforms
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import log
@@ -58,7 +57,7 @@ class Model(object):
    raise NotImplementedError(
        "Each model is responsible for its own fit_on_batch method.")

  def predict_on_batch(self, X, pad_batches=False):
  def predict_on_batch(self, X):
    """
    Makes predictions on given batch of new data.

@@ -66,14 +65,11 @@ class Model(object):
    ----------
    X: np.ndarray
      Features
    pad_batch: bool, optional
      Ignored for Sklearn Model. Only used for Tensorflow models
      with rigid batch-size requirements.
    """
    raise NotImplementedError(
        "Each model is responsible for its own predict_on_batch method.")

  def predict_proba_on_batch(self, X, pad_batches=False):
  def predict_proba_on_batch(self, X):
    """
    Makes predictions of class probabilities on given batch of new data.

@@ -81,9 +77,6 @@ class Model(object):
    ----------
    X: np.ndarray
      Features
    pad_batch: bool, optional
      Ignored for Sklearn Model. Only used for Tensorflow models
      with rigid batch-size requirements.
    """
    raise NotImplementedError(
        "Each model is responsible for its own predict_on_batch method.")
@@ -126,7 +119,7 @@ class Model(object):
      log("Starting epoch %s" % str(epoch+1), self.verbose)
      losses = []
      for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
          batch_size, pad_batches=pad_batches):
          batch_size):
        losses.append(self.fit_on_batch(X_batch, y_batch, w_batch))
      log("Avg loss for epoch %d: %f"
          % (epoch+1,np.array(losses).mean()),self.verbose)
+19 −18
Original line number Diff line number Diff line
@@ -17,7 +17,9 @@ from deepchem.models import Model
from deepchem.metrics import from_one_hot
from deepchem.nn import model_ops
from deepchem.models.tensorflow_models import utils as tf_utils
from deepchem.trans import undo_transforms
from deepchem.utils.save import log
from deepchem.utils.evaluate import Evaluator
from deepchem.data import pad_features
from tensorflow.contrib.layers.python.layers import batch_norm

@@ -109,7 +111,7 @@ class TensorflowGraphModel(Model):
               weight_init_stddevs=[.02], bias_init_consts=[1.], penalty=0.0,
               penalty_type="l2", dropouts=[0.5], learning_rate=.001,
               momentum=.9, optimizer="adam", batch_size=50, n_classes=2,
               verbose=True, seed=None, **kwargs):
               pad_batches=False, verbose=True, seed=None, **kwargs):
    """Constructs the computational graph.

    This function constructs the computational graph for the model. It relies
@@ -166,6 +168,7 @@ class TensorflowGraphModel(Model):
    self.optimizer = optimizer
    self.batch_size = batch_size
    self.n_classes = n_classes
    self.pad_batches = pad_batches
    self.verbose= verbose
    self.seed = seed
    
@@ -270,8 +273,8 @@ class TensorflowGraphModel(Model):

      return loss 

  def fit(self, dataset, nb_epoch=10, pad_batch=False, 
          max_checkpoints_to_keep=5, log_every_N_batches=50, **kwargs):
  def fit(self, dataset, nb_epoch=10, max_checkpoints_to_keep=5, 
	  log_every_N_batches=50, **kwargs):
    """Fit the model.

    Parameters
@@ -280,8 +283,6 @@ class TensorflowGraphModel(Model):
      Dataset object holding training data 
    nb_epoch: 10
      Number of training epochs.
    pad_batches: bool
      Whether or not to pad each batch to exactly be of size batch_size.
    max_checkpoints_to_keep: int
      Maximum number of checkpoints to keep; older checkpoints will be deleted.
    log_every_N_batches: int
@@ -311,7 +312,7 @@ class TensorflowGraphModel(Model):
              # Turns out there are valid cases where we don't want pad-batches
              # on by default.
              #dataset.iterbatches(batch_size, pad_batches=True)):
              dataset.iterbatches(self.batch_size, pad_batch=pad_batch)):
              dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            # Run training op.
@@ -453,7 +454,7 @@ class TensorflowGraphModel(Model):
                    last_checkpoint)
      self._restored_model = True

  def predict(self, dataset, transformers=[], pad_batch=False):
  def predict(self, dataset, transformers=[]):
    """
    Uses self to make predictions on provided Dataset object.

@@ -467,7 +468,7 @@ class TensorflowGraphModel(Model):
    for (X_batch, _, _, ids_batch) in dataset.iterbatches(
        self.batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_on_batch(X_batch, pad_batch=pad_batch)
      y_pred_batch = self.predict_on_batch(X_batch)
      # Discard any padded predictions
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks))
@@ -484,7 +485,7 @@ class TensorflowGraphModel(Model):
      y_pred = np.reshape(y_pred, (n_samples,)) 
    return y_pred

  def predict_proba(self, dataset, transformers=[], n_classes=2, pad_batch=False):
  def predict_proba(self, dataset, transformers=[], n_classes=2):
    """
    TODO: Do transformers even make sense here?

@@ -497,7 +498,7 @@ class TensorflowGraphModel(Model):
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
        self.batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_proba_on_batch(X_batch, pad_batch=pad_batch)
      y_pred_batch = self.predict_proba_on_batch(X_batch)
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks, n_classes))
      y_pred_batch = undo_transforms(y_pred_batch, transformers)
@@ -510,7 +511,7 @@ class TensorflowGraphModel(Model):
    y_pred = np.reshape(y_pred, (n_samples, n_tasks, n_classes))
    return y_pred

  def evaluate(self, dataset, metrics, transformers=[], pad_batch=False):
  def evaluate(self, dataset, metrics, transformers=[]):
    """
    Evaluates the performance of this model on specified dataset.
  
@@ -529,7 +530,7 @@ class TensorflowGraphModel(Model):
      Maps tasks to scores under metric.
    """
    evaluator = Evaluator(self, dataset, transformers)
    scores = evaluator.compute_model_performance(metrics, pad_batch=pad_batch)
    scores = evaluator.compute_model_performance(metrics)
    return scores

  def _find_last_checkpoint(self):
@@ -595,7 +596,7 @@ class TensorflowClassifier(TensorflowGraphModel):
                             name='labels_%d' % task)))
      return labels

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X):
    """Return model output for the provided input.

    Restore(checkpoint) must have previously been called on this object.
@@ -616,7 +617,7 @@ class TensorflowClassifier(TensorflowGraphModel):
      ValueError: If output and labels are not both 3D or both 2D.
    """
    len_unpadded = len(X)
    if pad_batch:
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    
    if not self._restored_model:
@@ -651,7 +652,7 @@ class TensorflowClassifier(TensorflowGraphModel):
    outputs = outputs[:len_unpadded]
    return outputs

  def predict_proba_on_batch(self, X, pad_batch=False):
  def predict_proba_on_batch(self, X):
    """Return model output for the provided input.

    Restore(checkpoint) must have previously been called on this object.
@@ -669,7 +670,7 @@ class TensorflowClassifier(TensorflowGraphModel):
      AssertionError: If model is not in evaluation mode.
      ValueError: If output and labels are not both 3D or both 2D.
    """
    if pad_batch:
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    if not self._restored_model:
      self.restore()
@@ -744,7 +745,7 @@ class TensorflowRegressor(TensorflowGraphModel):
                             name='labels_%d' % task)))
    return labels

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X):
    """Return model output for the provided input.

    Restore(checkpoint) must have previously been called on this object.
@@ -765,7 +766,7 @@ class TensorflowRegressor(TensorflowGraphModel):
      ValueError: If output and labels are not both 3D or both 2D.
    """
    len_unpadded = len(X)
    if pad_batch:
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    
    if not self._restored_model:
+4 −4
Original line number Diff line number Diff line
@@ -157,8 +157,8 @@ class TensorflowLogisticRegression(TensorflowGraphModel):
            (self.batch_size,)) 
    return TensorflowGraph.get_feed_dict(orig_dict)
  
  def predict_proba_on_batch(self, X, pad_batch=False):
    if pad_batch:
  def predict_proba_on_batch(self, X):
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    if not self._restored_model:
      self.restore()
@@ -190,9 +190,9 @@ class TensorflowLogisticRegression(TensorflowGraphModel):

    return np.copy(outputs)

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X):
    
    if pad_batch:
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    
    if not self._restored_model:
+5 −7
Original line number Diff line number Diff line
@@ -199,8 +199,8 @@ class ProgressiveJointRegressor(TensorflowMultiTaskRegressor):
        name="U_layer_%d_task%d" % (i, task), dtype=tf.float32)
    return tf.matmul(prev_layer, U)

  def fit(self, dataset, nb_epoch=10, pad_batches=False, 
          max_checkpoints_to_keep=5, log_every_N_batches=50, **kwargs):
  def fit(self, dataset, nb_epoch=10, max_checkpoints_to_keep=5, 
	  log_every_N_batches=50, **kwargs):
    """Fit the model.

    Parameters
@@ -209,8 +209,6 @@ class ProgressiveJointRegressor(TensorflowMultiTaskRegressor):
      Dataset object holding training data 
    nb_epoch: 10
      Number of training epochs.
    pad_batches: bool
      Whether or not to pad each batch to exactly be of size batch_size.
    max_checkpoints_to_keep: int
      Maximum number of checkpoints to keep; older checkpoints will be deleted.
    log_every_N_batches: int
@@ -240,7 +238,7 @@ class ProgressiveJointRegressor(TensorflowMultiTaskRegressor):
              # Turns out there are valid cases where we don't want pad-batches
              # on by default.
              #dataset.iterbatches(batch_size, pad_batches=True)):
              dataset.iterbatches(self.batch_size, pad_batches=pad_batches)):
              dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            # Run training op.
@@ -344,7 +342,7 @@ class ProgressiveJointRegressor(TensorflowMultiTaskRegressor):
            (self.batch_size,)) 
    return TensorflowGraph.get_feed_dict(orig_dict)

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X):
    """Return model output for the provided input.

    Restore(checkpoint) must have previously been called on this object.
@@ -365,7 +363,7 @@ class ProgressiveJointRegressor(TensorflowMultiTaskRegressor):
      ValueError: If output and labels are not both 3D or both 2D.
    """
    len_unpadded = len(X)
    if pad_batch:
    if self.pad_batches:
      X = pad_features(self.batch_size, X)
    
    if not self._restored_model:
+2 −4
Original line number Diff line number Diff line
@@ -523,7 +523,7 @@ class ProgressiveMultitaskRegressor(TensorflowMultiTaskRegressor):
      return self.eval_graph.session


  def fit_task(self, sess, dataset, task, task_train_op, nb_epoch=10, pad_batches=False,
  def fit_task(self, sess, dataset, task, task_train_op, nb_epoch=10,
               log_every_N_batches=50):
    """Fit the model.

@@ -540,8 +540,6 @@ class ProgressiveMultitaskRegressor(TensorflowMultiTaskRegressor):
      The index of the task to train on.
    nb_epoch: 10
      Number of training epochs.
    pad_batches: bool
      Whether or not to pad each batch to exactly be of size batch_size.
    max_checkpoints_to_keep: int
      Maximum number of checkpoints to keep; older checkpoints will be deleted.
    log_every_N_batches: int
@@ -564,7 +562,7 @@ class ProgressiveMultitaskRegressor(TensorflowMultiTaskRegressor):
          # Turns out there are valid cases where we don't want pad-batches
          # on by default.
          #dataset.iterbatches(batch_size, pad_batches=True)):
          dataset.iterbatches(self.batch_size, pad_batches=pad_batches)):
          dataset.iterbatches(self.batch_size, pad_batches=self.pad_batches)):
        if ind % log_every_N_batches == 0:
          log("On batch %d" % ind, self.verbose)
        feed_dict = self.construct_task_feed_dict(task, X_b, y_b, w_b, ids_b)
Loading