Unverified Commit 43a79ec9 authored by Bharath Ramsundar's avatar Bharath Ramsundar Committed by GitHub
Browse files

Merge pull request #2584 from kshen3778/valcallback

Added 'transformers' argument to ValidationCallback
parents b25e47fa 7f799c73
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -26,7 +26,8 @@ class ValidationCallback(object):
               output_file=sys.stdout,
               save_dir=None,
               save_metric=0,
               save_on_minimum=True):
               save_on_minimum=True,
               transformers=[]):
    """Create a ValidationCallback.

    Parameters
@@ -49,6 +50,10 @@ class ValidationCallback(object):
      if True, the best model is considered to be the one that minimizes the
      validation metric.  If False, the best model is considered to be the one
      that maximizes it.
    transformers: List[Transformer]
      List of `dc.trans.Transformer` objects. These transformations
      must have been applied to `dataset` previously. The dataset will
      be untransformed for metric evaluation.
    """
    self.dataset = dataset
    self.interval = interval
@@ -58,6 +63,7 @@ class ValidationCallback(object):
    self.save_metric = save_metric
    self.save_on_minimum = save_on_minimum
    self._best_score = None
    self.transformers = transformers

  def __call__(self, model, step):
    """This is invoked by the KerasModel after every step of fitting.
@@ -71,7 +77,7 @@ class ValidationCallback(object):
    """
    if step % self.interval != 0:
      return
    scores = model.evaluate(self.dataset, self.metrics)
    scores = model.evaluate(self.dataset, self.metrics, self.transformers)
    message = 'Step %d validation:' % step
    for key in scores:
      message += ' %s=%g' % (key, scores[key])
+2 −1
Original line number Diff line number Diff line
@@ -35,7 +35,8 @@ class TestCallbacks(unittest.TestCase):
        30, [metric],
        log,
        save_dir=save_dir,
        save_on_minimum=False)
        save_on_minimum=False,
        transformers=transformers)
    model.fit(train_dataset, callbacks=callback)

    # Parse the log to pull out the AUC scores.