Commit 4bf83b1d authored by Joseph Gomes's avatar Joseph Gomes
Browse files

Move batch padding funcs from Model to TensorflowGraphModel

parent 7010d1a8
Loading
Loading
Loading
Loading
+10 −25
Original line number Diff line number Diff line
@@ -23,7 +23,6 @@ from deepchem.trans import undo_grad_transforms
from deepchem.utils.save import load_from_disk
from deepchem.utils.save import save_to_disk
from deepchem.utils.save import log
from deepchem.data import pad_batch
from deepchem.utils.evaluate import Evaluator


@@ -32,7 +31,7 @@ class Model(object):
  Abstract base class for different ML models.
  """
  def __init__(self, model_instance=None, model_dir=None,
               fit_transformers=None, verbose=True, **kwargs):
               verbose=True, **kwargs):
    """Abstract class for all models.
    Parameters:
    -----------
@@ -49,7 +48,6 @@ class Model(object):
    self.model_dir = model_dir
    self.model_instance = model_instance
    self.model_class = model_instance.__class__
    self.fit_transformers = fit_transformers

    self.verbose = verbose

@@ -60,7 +58,7 @@ class Model(object):
    raise NotImplementedError(
        "Each model is responsible for its own fit_on_batch method.")

  def predict_on_batch(self, X, pad_batch=False):
  def predict_on_batch(self, X, pad_batches=False):
    """
    Makes predictions on given batch of new data.

@@ -75,7 +73,7 @@ class Model(object):
    raise NotImplementedError(
        "Each model is responsible for its own predict_on_batch method.")

  def predict_proba_on_batch(self, X, pad_batch=False):
  def predict_proba_on_batch(self, X, pad_batches=False):
    """
    Makes predictions of class probabilities on given batch of new data.

@@ -118,7 +116,7 @@ class Model(object):
    """
    raise NotImplementedError

  def fit(self, dataset, nb_epoch=10, batch_size=50, pad_batches=False, **kwargs):
  def fit(self, dataset, nb_epoch=10, batch_size=50, **kwargs):
    """
    Fits a model on data in a Dataset object.
    """
@@ -129,19 +127,11 @@ class Model(object):
      losses = []
      for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
          batch_size, pad_batches=pad_batches):
        if self.fit_transformers:
          X_batch, y_batch, w_batch = self.transform_on_batch(X_batch, y_batch,
                                            w_batch)
        if pad_batches:
          X_batch, y_batch, w_batch, ids_batch = pad_batch(
              batch_size, X_batch, y_batch, w_batch, ids_batch)
        
        losses.append(self.fit_on_batch(X_batch, y_batch, w_batch))
      log("Avg loss for epoch %d: %f"
          % (epoch+1,np.array(losses).mean()),self.verbose)

  def predict(self, dataset, transformers=[], batch_size=None,
              pad_batch=False):
  def predict(self, dataset, transformers=[], batch_size=None):
    """
    Uses self to make predictions on provided Dataset object.

@@ -152,15 +142,10 @@ class Model(object):
    n_tasks = self.get_num_tasks()
    ind = 0

    try:
      batch_size = self.batch_size
    except:
      batch_size = batch_size

    for (X_batch, _, _, ids_batch) in dataset.iterbatches(
        batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_on_batch(X_batch, pad_batch=pad_batch)
      y_pred_batch = self.predict_on_batch(X_batch)
      # Discard any padded predictions
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks))
@@ -177,7 +162,7 @@ class Model(object):
      y_pred = np.reshape(y_pred, (n_samples,)) 
    return y_pred

  def evaluate(self, dataset, metrics, transformers=[], pad_batch=False):
  def evaluate(self, dataset, metrics, transformers=[]):
    """
    Evaluates the performance of this model on specified dataset.
  
@@ -196,11 +181,11 @@ class Model(object):
      Maps tasks to scores under metric.
    """
    evaluator = Evaluator(self, dataset, transformers)
    scores = evaluator.compute_model_performance(metrics, pad_batch=pad_batch)
    scores = evaluator.compute_model_performance(metrics)
    return scores

  def predict_proba(self, dataset, transformers=[], batch_size=None,
                    n_classes=2, pad_batch=False):
                    n_classes=2):
    """
    TODO: Do transformers even make sense here?

@@ -212,7 +197,7 @@ class Model(object):
    for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
        batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_proba_on_batch(X_batch, pad_batch=pad_batch)
      y_pred_batch = self.predict_proba_on_batch(X_batch)
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks, n_classes))
      y_pred_batch = undo_transforms(y_pred_batch, transformers)
+81 −2
Original line number Diff line number Diff line
@@ -270,7 +270,7 @@ class TensorflowGraphModel(Model):

      return loss 

  def fit(self, dataset, nb_epoch=10, pad_batches=False, 
  def fit(self, dataset, nb_epoch=10, pad_batch=False, 
          max_checkpoints_to_keep=5, log_every_N_batches=50, **kwargs):
    """Fit the model.

@@ -311,7 +311,7 @@ class TensorflowGraphModel(Model):
              # Turns out there are valid cases where we don't want pad-batches
              # on by default.
              #dataset.iterbatches(batch_size, pad_batches=True)):
              dataset.iterbatches(self.batch_size, pad_batches=pad_batches)):
              dataset.iterbatches(self.batch_size, pad_batch=pad_batch)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            # Run training op.
@@ -453,6 +453,85 @@ class TensorflowGraphModel(Model):
                    last_checkpoint)
      self._restored_model = True

  def predict(self, dataset, transformers=[], pad_batch=False):
    """
    Uses self to make predictions on provided Dataset object.

    Returns:
      y_pred: numpy ndarray of shape (n_samples,)
    """
    y_preds = []
    n_tasks = self.get_num_tasks()
    ind = 0

    for (X_batch, _, _, ids_batch) in dataset.iterbatches(
        self.batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_on_batch(X_batch, pad_batch=pad_batch)
      # Discard any padded predictions
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks))
      y_pred_batch = undo_transforms(y_pred_batch, transformers)
      y_preds.append(y_pred_batch)
    y_pred = np.vstack(y_preds)
  
    # The iterbatches does padding with zero-weight examples on the last batch.
    # Remove padded examples.
    n_samples = len(dataset)
    y_pred = np.reshape(y_pred, (n_samples, n_tasks))
    # Special case to handle singletasks.
    if n_tasks == 1:
      y_pred = np.reshape(y_pred, (n_samples,)) 
    return y_pred

  def predict_proba(self, dataset, transformers=[], n_classes=2, pad_batch=False):
    """
    TODO: Do transformers even make sense here?

    Returns:
      y_pred: numpy ndarray of shape (n_samples, n_classes*n_tasks)
    """
    y_preds = []
    n_tasks = self.get_num_tasks()

    for (X_batch, y_batch, w_batch, ids_batch) in dataset.iterbatches(
        self.batch_size, deterministic=True):
      n_samples = len(X_batch)
      y_pred_batch = self.predict_proba_on_batch(X_batch, pad_batch=pad_batch)
      y_pred_batch = y_pred_batch[:n_samples]
      y_pred_batch = np.reshape(y_pred_batch, (n_samples, n_tasks, n_classes))
      y_pred_batch = undo_transforms(y_pred_batch, transformers)
      y_preds.append(y_pred_batch)
    y_pred = np.vstack(y_preds)
    # The iterbatches does padding with zero-weight examples on the last batch.
    # Remove padded examples.
    n_samples = len(dataset)
    y_pred = y_pred[:n_samples]
    y_pred = np.reshape(y_pred, (n_samples, n_tasks, n_classes))
    return y_pred

  def evaluate(self, dataset, metrics, transformers=[], pad_batch=False):
    """
    Evaluates the performance of this model on specified dataset.
  
    Parameters
    ----------
    dataset: dc.data.Dataset
      Dataset object.
    metric: deepchem.metrics.Metric
      Evaluation metric
    transformers: list
      List of deepchem.transformers.Transformer

    Returns
    -------
    dict
      Maps tasks to scores under metric.
    """
    evaluator = Evaluator(self, dataset, transformers)
    scores = evaluator.compute_model_performance(metrics, pad_batch=pad_batch)
    return scores

  def _find_last_checkpoint(self):
    """Finds last saved checkpoint."""
    highest_num, last_checkpoint = -np.inf, None
+4 −2
Original line number Diff line number Diff line
@@ -79,11 +79,13 @@ class Evaluator(object):
    else:
      mode = metrics[0].mode
    if mode == "classification":
      y_pred = self.model.predict_proba(self.dataset, self.output_transformers, pad_batch=pad_batch)
      y_pred = self.model.predict_proba(self.dataset, self.output_transformers, 
				  	pad_batch=pad_batch)
      y_pred_print = self.model.predict(
          self.dataset, self.output_transformers).astype(int)
    else:
      y_pred = self.model.predict(self.dataset, self.output_transformers, pad_batch=pad_batch)
      y_pred = self.model.predict(self.dataset, self.output_transformers, 
				  pad_batch=pad_batch)
      y_pred_print = y_pred
    multitask_scores = {}